llm.go 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313
  1. package services
  2. import (
  3. "encoding/json"
  4. "eta/eta_api/models/rag"
  5. "eta/eta_api/models/system"
  6. "eta/eta_api/utils"
  7. "fmt"
  8. "time"
  9. )
  10. // AddGenerateAbstractTask
  11. // @Description: 添加全部报告(微信文章/ETA报告)生成摘要任务
  12. // @author: Roc
  13. // @datetime 2025-04-16 17:02:18
  14. // @param question *rag.Question
  15. // @param sysUser *system.Admin
  16. func AddGenerateAbstractTask(question *rag.Question, sysUser *system.Admin) {
  17. // 找出所有公众号文章Id
  18. wechatArticleIdList, err := getAllWechatArticleIdList()
  19. if err != nil {
  20. return
  21. }
  22. // 找出所有Eta报告
  23. ragEtaReportIdList, err := getAllEtaReportIdList()
  24. if err != nil {
  25. return
  26. }
  27. taskName := fmt.Sprintf("自动生成摘要%s-%s", time.Now().Format(utils.FormatShortDateTimeUnSpace), question.QuestionTitle)
  28. aiTask := &rag.AiTask{
  29. AiTaskID: 0,
  30. TaskName: taskName,
  31. TaskType: utils.AI_TASK_TYPE_GENERATE_ABSTRACT,
  32. Status: "init",
  33. //StartTime: time.Time{},
  34. //EndTime: time.Time{},
  35. CreateTime: time.Now(),
  36. UpdateTime: time.Now(),
  37. Parameters: fmt.Sprint(question.QuestionId),
  38. Logs: "",
  39. Errormessage: "",
  40. Priority: 0,
  41. RetryCount: 0,
  42. //EstimatedCompletionTime: time.Time{},
  43. //ActualCompletitonTime: time.Time{},
  44. Remark: "",
  45. SysUserID: sysUser.AdminId,
  46. SysUserRealName: sysUser.RealName,
  47. }
  48. taskRecordList := make([]*rag.AiTaskRecord, 0)
  49. // 微信文章
  50. for _, wechatArticleId := range wechatArticleIdList {
  51. param := rag.QuestionGenerateAbstractParam{
  52. QuestionId: question.QuestionId,
  53. ArticleType: `wechat_article`,
  54. ArticleId: wechatArticleId,
  55. }
  56. paramByte, tmpErr := json.Marshal(param)
  57. if tmpErr != nil {
  58. return
  59. }
  60. taskRecord := &rag.AiTaskRecord{
  61. AiTaskRecordID: 0,
  62. AiTaskID: 0,
  63. Parameters: string(paramByte),
  64. Status: "待处理",
  65. Remark: "",
  66. ModifyTime: time.Now(),
  67. CreateTime: time.Now(),
  68. }
  69. taskRecordList = append(taskRecordList, taskRecord)
  70. }
  71. // eta报告
  72. for _, ragEtaReportId := range ragEtaReportIdList {
  73. param := rag.QuestionGenerateAbstractParam{
  74. QuestionId: question.QuestionId,
  75. ArticleType: `rag_eta_report`,
  76. ArticleId: ragEtaReportId,
  77. }
  78. paramByte, tmpErr := json.Marshal(param)
  79. if tmpErr != nil {
  80. return
  81. }
  82. taskRecord := &rag.AiTaskRecord{
  83. AiTaskRecordID: 0,
  84. AiTaskID: 0,
  85. Parameters: string(paramByte),
  86. Status: "待处理",
  87. Remark: "",
  88. ModifyTime: time.Now(),
  89. CreateTime: time.Now(),
  90. }
  91. taskRecordList = append(taskRecordList, taskRecord)
  92. }
  93. // 创建AI模块的任务,用于后面的任务调度去生成摘要
  94. err = rag.AddAiTask(aiTask, taskRecordList)
  95. if err != nil {
  96. return
  97. }
  98. return
  99. }
  100. // getAllWechatArticleIdList
  101. // @Description: 获取所有的微信文章Id列表
  102. // @author: Roc
  103. // @datetime 2025-04-16 17:18:31
  104. // @return wechatArticleIdList []int
  105. // @return err error
  106. func getAllWechatArticleIdList() (wechatArticleIdList []int, err error) {
  107. wechatArticleIdList = make([]int, 0)
  108. pageSize := 10000
  109. currentIndex := 1
  110. // 注意,默认是10000条,如果超过10000条,需要分页查询
  111. // 避免死循环
  112. for {
  113. tmpWechatArticleIdList, tmpErr := getWechatArticleIdList(currentIndex, pageSize)
  114. if tmpErr != nil {
  115. return
  116. }
  117. wechatArticleIdList = append(wechatArticleIdList, tmpWechatArticleIdList...)
  118. if len(tmpWechatArticleIdList) < pageSize {
  119. return
  120. }
  121. currentIndex++
  122. // 超过100次,那么也退出,避免死循环
  123. if currentIndex > 100 {
  124. return
  125. }
  126. }
  127. }
  128. // getWechatArticleIdList
  129. // @Description: 分页获取微信文章Id列表
  130. // @author: Roc
  131. // @datetime 2025-04-16 17:18:44
  132. // @param currentIndex int
  133. // @param pageSize int
  134. // @return wechatArticleIdList []int
  135. // @return err error
  136. func getWechatArticleIdList(currentIndex, pageSize int) (wechatArticleIdList []int, err error) {
  137. wechatArticleIdList = make([]int, 0)
  138. var condition string
  139. var pars []interface{}
  140. var startSize int
  141. if pageSize <= 0 {
  142. pageSize = utils.PageSize20
  143. }
  144. if currentIndex <= 0 {
  145. currentIndex = 1
  146. }
  147. startSize = utils.StartIndex(currentIndex, pageSize)
  148. condition += fmt.Sprintf(` AND %s = ? `, rag.WechatArticleColumns.IsDeleted)
  149. pars = append(pars, 0, 1)
  150. obj := new(rag.WechatArticle)
  151. list, err := obj.GetListByCondition(` wechat_article_id `, condition, pars, startSize, pageSize)
  152. if err != nil {
  153. return
  154. }
  155. for _, item := range list {
  156. wechatArticleIdList = append(wechatArticleIdList, item.WechatArticleId)
  157. }
  158. return
  159. }
  160. // getAllEtaReportIdList
  161. // @Description: 获取所有的eta报告Id列表
  162. // @author: Roc
  163. // @datetime 2025-04-16 17:19:29
  164. // @return ragEtaReportIdList []int
  165. // @return err error
  166. func getAllEtaReportIdList() (ragEtaReportIdList []int, err error) {
  167. ragEtaReportIdList = make([]int, 0)
  168. pageSize := 10000
  169. currentIndex := 1
  170. // 注意,默认是10000条,如果超过10000条,需要分页查询
  171. // 避免死循环
  172. for {
  173. tmpRagEtaReportIdList, tmpErr := getEtaReportIdList(currentIndex, pageSize)
  174. if tmpErr != nil {
  175. return
  176. }
  177. ragEtaReportIdList = append(ragEtaReportIdList, tmpRagEtaReportIdList...)
  178. if len(tmpRagEtaReportIdList) < pageSize {
  179. return
  180. }
  181. currentIndex++
  182. // 超过100次,那么也退出,避免死循环
  183. if currentIndex > 100 {
  184. return
  185. }
  186. }
  187. }
  188. // getEtaReportIdList
  189. // @Description: 分页获取eta报告Id列表
  190. // @author: Roc
  191. // @datetime 2025-04-16 17:19:14
  192. // @param currentIndex int
  193. // @param pageSize int
  194. // @return ragEtaReportIdList []int
  195. // @return err error
  196. func getEtaReportIdList(currentIndex, pageSize int) (ragEtaReportIdList []int, err error) {
  197. ragEtaReportIdList = make([]int, 0)
  198. var condition string
  199. var pars []interface{}
  200. var startSize int
  201. if pageSize <= 0 {
  202. pageSize = utils.PageSize20
  203. }
  204. if currentIndex <= 0 {
  205. currentIndex = 1
  206. }
  207. startSize = utils.StartIndex(currentIndex, pageSize)
  208. condition += fmt.Sprintf(` AND %s = ? AND %s = ? `, rag.RagEtaReportColumns.IsDeleted, rag.RagEtaReportColumns.IsPublished)
  209. pars = append(pars, 0, 1)
  210. obj := new(rag.RagEtaReport)
  211. list, err := obj.GetListByCondition(` rag_eta_report_id `, condition, pars, startSize, pageSize)
  212. if err != nil {
  213. return
  214. }
  215. for _, item := range list {
  216. ragEtaReportIdList = append(ragEtaReportIdList, item.RagEtaReportId)
  217. }
  218. return
  219. }
  220. // CheckOpQuestionAuth
  221. // @Description: 校验是否有权限操作提示词
  222. // @author: Roc
  223. // @datetime 2025-04-16 17:33:01
  224. // @return auth bool
  225. // @return err error
  226. func CheckOpQuestionAuth() (auth bool, err error) {
  227. total, err := getNotFinishGenerateAbstractTaskNum()
  228. if err != nil {
  229. return
  230. }
  231. // 存在未完成的任务,则无权限
  232. if total > 0 {
  233. return
  234. }
  235. auth = true
  236. return
  237. }
  238. // getNotFinishGenerateAbstractTaskNum
  239. // @Description: 获取未完成的生成摘要任务的数量
  240. // @author: Roc
  241. // @datetime 2025-04-16 17:31:12
  242. // @return total int
  243. // @return err error
  244. func getNotFinishGenerateAbstractTaskNum() (total int, err error) {
  245. obj := rag.AiTask{}
  246. var condition string
  247. var pars []interface{}
  248. condition += fmt.Sprintf(` AND %s NOT IN (?) AND %s = ? `, rag.AiTaskColumns.Status, rag.AiTaskColumns.TaskType)
  249. pars = append(pars, []string{`done`, `failed`}, utils.AI_TASK_TYPE_GENERATE_ABSTRACT)
  250. total, err = obj.GetCountByCondition(condition, pars)
  251. if err != nil {
  252. return
  253. }
  254. return
  255. }
  256. // GetNotFinishGenerateAbstractTaskNumByQuestionId
  257. // @Description: 根据提示词ID获取未完成的生成摘要任务的数量
  258. // @author: Roc
  259. // @datetime 2025-04-16 17:31:12
  260. // @return total int
  261. // @return err error
  262. func GetNotFinishGenerateAbstractTaskNumByQuestionId(questionId int) (total int, err error) {
  263. obj := rag.AiTask{}
  264. var condition string
  265. var pars []interface{}
  266. condition += fmt.Sprintf(` AND %s NOT IN (?) AND %s = ? AND %s = ? `, rag.AiTaskColumns.Status, rag.AiTaskColumns.TaskType, rag.AiTaskColumns.Parameters)
  267. pars = append(pars, []string{`done`, `failed`}, utils.AI_TASK_TYPE_GENERATE_ABSTRACT, fmt.Sprint(questionId))
  268. total, err = obj.GetCountByCondition(condition, pars)
  269. if err != nil {
  270. return
  271. }
  272. return
  273. }