rag_eta_report_abstract.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292
  1. package rag
  2. import (
  3. "database/sql"
  4. "eta/eta_api/global"
  5. "eta/eta_api/utils"
  6. "fmt"
  7. "time"
  8. )
  9. // RagEtaReportAbstract 报告摘要
  10. type RagEtaReportAbstract struct {
  11. RagEtaReportAbstractId int `gorm:"primaryKey;column:rag_eta_report_abstract_id" description:"-"`
  12. RagEtaReportId int `gorm:"column:rag_eta_report_id" description:"ETA报告id"`
  13. Content string `gorm:"column:content" description:"摘要内容"`
  14. QuestionId int `gorm:"column:question_id" description:"提示词Id"`
  15. QuestionContent string `gorm:"column:question_content" description:"questionContent"`
  16. Version int `gorm:"column:version" description:"版本号"`
  17. Tags string `gorm:"column:tags" description:"标签"`
  18. TagsName string `gorm:"column:tags_name" description:"标签名,多个用英文逗号隔开"`
  19. VectorKey string `gorm:"column:vector_key" description:"向量key标识"`
  20. ModifyTime time.Time `gorm:"column:modify_time" description:"modifyTime"`
  21. CreateTime time.Time `gorm:"column:create_time" description:"createTime"`
  22. }
  23. // TableName get sql table name.获取数据库表名
  24. func (m *RagEtaReportAbstract) TableName() string {
  25. return "rag_eta_report_abstract"
  26. }
  27. // RagEtaReportAbstractColumns get sql column name.获取数据库列名
  28. var RagEtaReportAbstractColumns = struct {
  29. RagEtaReportAbstractId string
  30. RagEtaReportId string
  31. Content string
  32. QuestionId string
  33. QuestionContent string
  34. Version string
  35. Tags string
  36. TagsName string
  37. VectorKey string
  38. ModifyTime string
  39. CreateTime string
  40. }{
  41. RagEtaReportAbstractId: "rag_eta_report_abstract_id",
  42. RagEtaReportId: "rag_eta_report_id",
  43. Content: "content",
  44. QuestionId: "question_id",
  45. QuestionContent: "question_content",
  46. Version: "version",
  47. Tags: "tags",
  48. TagsName: "tags_name",
  49. VectorKey: "vector_key",
  50. ModifyTime: "modify_time",
  51. CreateTime: "create_time",
  52. }
  53. func (m *RagEtaReportAbstract) Create() (err error) {
  54. err = global.DbMap[utils.DbNameAI].Create(&m).Error
  55. return
  56. }
  57. func (m *RagEtaReportAbstract) Update(updateCols []string) (err error) {
  58. err = global.DbMap[utils.DbNameAI].Select(updateCols).Updates(&m).Error
  59. return
  60. }
  61. func (m *RagEtaReportAbstract) Del() (err error) {
  62. err = global.DbMap[utils.DbNameAI].Delete(&m).Error
  63. return
  64. }
  65. func (m *RagEtaReportAbstract) GetById(id int) (item *RagEtaReportAbstract, err error) {
  66. err = global.DbMap[utils.DbNameAI].Where(fmt.Sprintf("%s = ?", RagEtaReportAbstractColumns.RagEtaReportAbstractId), id).First(&item).Error
  67. return
  68. }
  69. func (m *RagEtaReportAbstract) GetByIdList(idList []int) (items []*RagEtaReportAbstract, err error) {
  70. err = global.DbMap[utils.DbNameAI].Where(fmt.Sprintf("%s in (?) ", RagEtaReportAbstractColumns.RagEtaReportAbstractId), idList).Find(&items).Error
  71. return
  72. }
  73. func (m *RagEtaReportAbstract) GetListByQuestionId(questionId int) (items []*RagEtaReportAbstract, err error) {
  74. err = global.DbMap[utils.DbNameAI].Where(fmt.Sprintf("%s = ? ", RagEtaReportAbstractColumns.QuestionId), questionId).Find(&items).Error
  75. return
  76. }
  77. func (m *RagEtaReportAbstract) GetListByCondition(field, condition string, pars []interface{}, startSize, pageSize int) (items []*RagEtaReportAbstract, err error) {
  78. if field == "" {
  79. field = "*"
  80. }
  81. sqlStr := fmt.Sprintf(`SELECT %s FROM %s WHERE 1=1 %s order by rag_eta_report_abstract_id desc LIMIT ?,?`, field, m.TableName(), condition)
  82. pars = append(pars, startSize, pageSize)
  83. err = global.DbMap[utils.DbNameAI].Raw(sqlStr, pars...).Find(&items).Error
  84. return
  85. }
  86. func (m *RagEtaReportAbstract) DelByIdList(idList []int) (err error) {
  87. if len(idList) <= 0 {
  88. return
  89. }
  90. sqlStr := fmt.Sprintf(`delete from %s where %s in (?)`, m.TableName(), RagEtaReportAbstractColumns.RagEtaReportAbstractId)
  91. err = global.DbMap[utils.DbNameAI].Exec(sqlStr, idList).Error
  92. return
  93. }
  94. // GetByRagEtaReportId
  95. // @Description: 根据报告id获取摘要
  96. // @author: Roc
  97. // @receiver m
  98. // @datetime 2025-03-07 10:00:59
  99. // @param id int
  100. // @return item *RagEtaReportAbstract
  101. // @return err error
  102. func (m *RagEtaReportAbstract) GetByRagEtaReportId(id int) (item *RagEtaReportAbstract, err error) {
  103. err = global.DbMap[utils.DbNameAI].Where(fmt.Sprintf("%s = ?", RagEtaReportAbstractColumns.RagEtaReportId), id).Order(fmt.Sprintf(`%s DESC`, RagEtaReportAbstractColumns.RagEtaReportAbstractId)).First(&item).Error
  104. return
  105. }
  106. // GetByRagEtaReportIdAndQuestionId
  107. // @Description: 根据报告id和提示词ID获取摘要
  108. // @author: Roc
  109. // @receiver m
  110. // @datetime 2025-04-17 17:39:27
  111. // @param articleId int
  112. // @param questionId int
  113. // @return item *RagEtaReportAbstract
  114. // @return err error
  115. func (m *RagEtaReportAbstract) GetByRagEtaReportIdAndQuestionId(articleId, questionId int) (item *RagEtaReportAbstract, err error) {
  116. err = global.DbMap[utils.DbNameAI].Where(fmt.Sprintf("%s = ? AND %s = ? ", RagEtaReportAbstractColumns.RagEtaReportId, RagEtaReportAbstractColumns.QuestionId), articleId, questionId).Order(fmt.Sprintf(`%s DESC`, RagEtaReportAbstractColumns.RagEtaReportAbstractId)).First(&item).Error
  117. return
  118. }
  119. type RagEtaReportAbstractView struct {
  120. RagEtaReportAbstractId int `gorm:"primaryKey;column:rag_eta_report_abstract_id" description:"-"`
  121. RagEtaReportId int `gorm:"column:rag_eta_report_id" description:"ETA报告id"`
  122. Abstract string `gorm:"column:abstract;type:longtext;comment:摘要内容;" description:"摘要内容"`
  123. QuestionId int `gorm:"column:question_id" description:"提示词Id"`
  124. QuestionContent string `gorm:"column:question_content" description:"questionContent"`
  125. Version int `gorm:"column:version" description:"版本号"`
  126. Tags string `gorm:"column:tags" description:"标签"`
  127. TagsName string `gorm:"column:tags_name" description:"标签名,多个用英文逗号隔开"`
  128. VectorKey string `gorm:"column:vector_key" description:"向量key标识"`
  129. ModifyTime string `gorm:"column:modify_time;type:datetime;default:NULL;" description:"modify_time"`
  130. CreateTime string `gorm:"column:create_time;type:datetime;default:NULL;" description:"create_time"`
  131. Title string `gorm:"column:title;type:varchar(255);comment:标题;" description:"标题"`
  132. }
  133. type RagEtaReportAbstractItem struct {
  134. RagEtaReportAbstractId int `gorm:"primaryKey;column:rag_eta_report_abstract_id" description:"-"`
  135. RagEtaReportId int `gorm:"column:rag_eta_report_id" description:"ETA报告id"`
  136. Content string `gorm:"column:content" description:"摘要内容"`
  137. QuestionId int `gorm:"column:question_id" description:"提示词Id"`
  138. QuestionContent string `gorm:"column:question_content" description:"questionContent"`
  139. Version int `gorm:"column:version" description:"版本号"`
  140. Tags string `gorm:"column:tags" description:"标签"`
  141. TagsName string `gorm:"column:tags_name" description:"标签名,多个用英文逗号隔开"`
  142. VectorKey string `gorm:"column:vector_key" description:"向量key标识"`
  143. ModifyTime time.Time `gorm:"column:modify_time;type:datetime;default:NULL;" description:"modify_time"`
  144. CreateTime time.Time `gorm:"column:create_time;type:datetime;default:NULL;" description:"create_time"`
  145. Title string `gorm:"column:title;type:varchar(255);comment:标题;" description:"标题"`
  146. }
  147. func (m *RagEtaReportAbstractItem) ToView() RagEtaReportAbstractView {
  148. return RagEtaReportAbstractView{
  149. RagEtaReportAbstractId: m.RagEtaReportAbstractId,
  150. RagEtaReportId: m.RagEtaReportId,
  151. Abstract: m.Content,
  152. Version: m.Version,
  153. VectorKey: m.VectorKey,
  154. ModifyTime: utils.DateStrToDateTimeStr(m.ModifyTime),
  155. CreateTime: utils.DateStrToDateTimeStr(m.CreateTime),
  156. Title: m.Title,
  157. QuestionId: m.QuestionId,
  158. Tags: m.Tags,
  159. TagsName: m.TagsName,
  160. QuestionContent: m.QuestionContent,
  161. }
  162. }
  163. func (m *RagEtaReportAbstract) EtaReportAbstractItem(list []*RagEtaReportAbstractItem) (etaReportAbstractViewList []RagEtaReportAbstractView) {
  164. etaReportAbstractViewList = make([]RagEtaReportAbstractView, 0)
  165. for _, v := range list {
  166. etaReportAbstractViewList = append(etaReportAbstractViewList, v.ToView())
  167. }
  168. return
  169. }
  170. func (m *RagEtaReportAbstract) GetListByPlatformCondition(field, condition string, pars []interface{}, startSize, pageSize int) (items []*RagEtaReportAbstractItem, err error) {
  171. if field == "" {
  172. field = "*"
  173. }
  174. sqlStr := fmt.Sprintf(`SELECT %s FROM %s AS a
  175. WHERE 1=1 %s order by a.modify_time DESC,a.rag_eta_report_abstract_id DESC LIMIT ?,?`, field, m.TableName(), condition)
  176. pars = append(pars, startSize, pageSize)
  177. err = global.DbMap[utils.DbNameAI].Raw(sqlStr, pars...).Find(&items).Error
  178. return
  179. }
  180. func (m *RagEtaReportAbstract) GetCountByPlatformCondition(condition string, pars []interface{}) (total int, err error) {
  181. var intNull sql.NullInt64
  182. sqlStr := fmt.Sprintf(`SELECT COUNT(1) total FROM %s AS a
  183. WHERE 1=1 %s`, m.TableName(), condition)
  184. err = global.DbMap[utils.DbNameAI].Raw(sqlStr, pars...).Scan(&intNull).Error
  185. if err == nil && intNull.Valid {
  186. total = int(intNull.Int64)
  187. }
  188. return
  189. }
  190. func (m *RagEtaReportAbstract) GetPageListByPlatformCondition(condition string, pars []interface{}, startSize, pageSize int) (total int, items []*RagEtaReportAbstractItem, err error) {
  191. total, err = m.GetCountByPlatformCondition(condition, pars)
  192. if err != nil {
  193. return
  194. }
  195. if total > 0 {
  196. items, err = m.GetListByPlatformCondition(``, condition, pars, startSize, pageSize)
  197. }
  198. return
  199. }
  200. func (m *RagEtaReportAbstract) GetListByTagAndPlatformCondition(field, condition string, pars []interface{}, startSize, pageSize int) (items []*RagEtaReportAbstractItem, err error) {
  201. if field == "" {
  202. field = "*"
  203. }
  204. sqlStr := fmt.Sprintf(`SELECT %s FROM %s AS a
  205. JOIN wechat_article AS b ON a.rag_eta_report_id=b.rag_eta_report_id
  206. JOIN wechat_platform AS c ON b.wechat_platform_id=c.wechat_platform_id
  207. JOIN wechat_platform_tag_mapping AS d ON c.wechat_platform_id=d.wechat_platform_id
  208. WHERE 1=1 AND b.is_deleted=0 %s order by a.modify_time DESC,a.rag_eta_report_abstract_id DESC LIMIT ?,?`, field, m.TableName(), condition)
  209. pars = append(pars, startSize, pageSize)
  210. err = global.DbMap[utils.DbNameAI].Raw(sqlStr, pars...).Find(&items).Error
  211. return
  212. }
  213. func (m *RagEtaReportAbstract) GetCountByTagAndPlatformCondition(condition string, pars []interface{}) (total int, err error) {
  214. var intNull sql.NullInt64
  215. sqlStr := fmt.Sprintf(`SELECT COUNT(1) total FROM %s AS a
  216. JOIN wechat_article AS b ON a.rag_eta_report_id=b.rag_eta_report_id
  217. JOIN wechat_platform AS c ON b.wechat_platform_id=c.wechat_platform_id
  218. JOIN wechat_platform_tag_mapping AS d ON c.wechat_platform_id=d.wechat_platform_id
  219. WHERE 1=1 AND b.is_deleted=0 %s`, m.TableName(), condition)
  220. err = global.DbMap[utils.DbNameAI].Raw(sqlStr, pars...).Scan(&intNull).Error
  221. if err == nil && intNull.Valid {
  222. total = int(intNull.Int64)
  223. }
  224. return
  225. }
  226. func (m *RagEtaReportAbstract) GetPageListByTagAndPlatformCondition(condition string, pars []interface{}, startSize, pageSize int) (total int, items []*RagEtaReportAbstractItem, err error) {
  227. total, err = m.GetCountByTagAndPlatformCondition(condition, pars)
  228. if err != nil {
  229. return
  230. }
  231. if total > 0 {
  232. items, err = m.GetListByTagAndPlatformCondition(`a.rag_eta_report_abstract_id,a.rag_eta_report_id,a.content AS abstract,a.version,a.vector_key,a.modify_time,a.create_time,b.title,b.link,d.tag_id`, condition, pars, startSize, pageSize)
  233. }
  234. return
  235. }
  236. // DelVectorKey
  237. // @Description: 批量删除向量库
  238. // @author: Roc
  239. // @receiver m
  240. // @datetime 2025-03-12 16:47:52
  241. // @param ragEtaReportAbstractIdList []int
  242. // @return err error
  243. func (m *RagEtaReportAbstract) DelVectorKey(ragEtaReportAbstractIdList []int) (err error) {
  244. sqlStr := fmt.Sprintf(`UPDATE %s set vector_key = '' WHERE rag_eta_report_abstract_id IN (?)`, m.TableName())
  245. err = global.DbMap[utils.DbNameAI].Exec(sqlStr, ragEtaReportAbstractIdList).Error
  246. return
  247. }