llm.go 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334
  1. package services
  2. import (
  3. "encoding/json"
  4. "eta/eta_api/cache"
  5. "eta/eta_api/models/rag"
  6. "eta/eta_api/models/system"
  7. "eta/eta_api/utils"
  8. "fmt"
  9. "time"
  10. )
  11. // AddGenerateAbstractTask
  12. // @Description: 添加全部报告(微信文章/ETA报告)生成摘要任务
  13. // @author: Roc
  14. // @datetime 2025-04-16 17:02:18
  15. // @param question *rag.Question
  16. // @param sysUser *system.Admin
  17. func AddGenerateAbstractTask(question *rag.Question, sysUser *system.Admin) {
  18. // 找出所有公众号文章Id
  19. wechatArticleIdList, err := getAllWechatArticleIdList()
  20. if err != nil {
  21. return
  22. }
  23. // 找出所有Eta报告
  24. ragEtaReportIdList, err := getAllEtaReportIdList()
  25. if err != nil {
  26. return
  27. }
  28. taskName := fmt.Sprintf("自动生成摘要%s-%s", time.Now().Format(utils.FormatShortDateTimeUnSpace), question.QuestionTitle)
  29. aiTask := &rag.AiTask{
  30. AiTaskID: 0,
  31. TaskName: taskName,
  32. TaskType: utils.AI_TASK_TYPE_GENERATE_ABSTRACT,
  33. Status: "init",
  34. //StartTime: time.Time{},
  35. //EndTime: time.Time{},
  36. CreateTime: time.Now(),
  37. UpdateTime: time.Now(),
  38. Parameters: fmt.Sprint(question.QuestionId),
  39. Logs: "",
  40. Errormessage: "",
  41. Priority: 0,
  42. RetryCount: 0,
  43. //EstimatedCompletionTime: time.Time{},
  44. //ActualCompletitonTime: time.Time{},
  45. Remark: "",
  46. SysUserID: sysUser.AdminId,
  47. SysUserRealName: sysUser.RealName,
  48. }
  49. taskRecordList := make([]*rag.AiTaskRecord, 0)
  50. // 微信文章
  51. for _, wechatArticleId := range wechatArticleIdList {
  52. param := rag.QuestionGenerateAbstractParam{
  53. QuestionId: question.QuestionId,
  54. ArticleType: `wechat_article`,
  55. ArticleId: wechatArticleId,
  56. }
  57. paramByte, tmpErr := json.Marshal(param)
  58. if tmpErr != nil {
  59. return
  60. }
  61. taskRecord := &rag.AiTaskRecord{
  62. AiTaskRecordID: 0,
  63. AiTaskID: 0,
  64. Parameters: string(paramByte),
  65. Status: "待处理",
  66. Remark: "",
  67. ModifyTime: time.Now(),
  68. CreateTime: time.Now(),
  69. }
  70. taskRecordList = append(taskRecordList, taskRecord)
  71. }
  72. // eta报告
  73. for _, ragEtaReportId := range ragEtaReportIdList {
  74. param := rag.QuestionGenerateAbstractParam{
  75. QuestionId: question.QuestionId,
  76. ArticleType: `rag_eta_report`,
  77. ArticleId: ragEtaReportId,
  78. }
  79. paramByte, tmpErr := json.Marshal(param)
  80. if tmpErr != nil {
  81. return
  82. }
  83. taskRecord := &rag.AiTaskRecord{
  84. AiTaskRecordID: 0,
  85. AiTaskID: 0,
  86. Parameters: string(paramByte),
  87. Status: "待处理",
  88. Remark: "",
  89. ModifyTime: time.Now(),
  90. CreateTime: time.Now(),
  91. }
  92. taskRecordList = append(taskRecordList, taskRecord)
  93. }
  94. // 创建AI模块的任务,用于后面的任务调度去生成摘要
  95. err = rag.AddAiTask(aiTask, taskRecordList)
  96. if err != nil {
  97. return
  98. }
  99. // 添加到缓存队列中
  100. go addTaskToCache(aiTask.AiTaskID)
  101. return
  102. }
  103. func addTaskToCache(aiTaskId int) {
  104. var err error
  105. defer func() {
  106. if err != nil {
  107. utils.FileLog.Error("addTaskToCache error: %v", err)
  108. }
  109. }()
  110. obj := rag.AiTaskRecord{}
  111. list, err := obj.GetAllListByCondition("*", ` AND ai_task_id = ? `, []interface{}{aiTaskId})
  112. if err != nil {
  113. return
  114. }
  115. for _, item := range list {
  116. cache.AddAiTaskRecordOpToCache(item.AiTaskRecordID)
  117. }
  118. }
  119. // getAllWechatArticleIdList
  120. // @Description: 获取所有的微信文章Id列表
  121. // @author: Roc
  122. // @datetime 2025-04-16 17:18:31
  123. // @return wechatArticleIdList []int
  124. // @return err error
  125. func getAllWechatArticleIdList() (wechatArticleIdList []int, err error) {
  126. wechatArticleIdList = make([]int, 0)
  127. pageSize := 10000
  128. currentIndex := 1
  129. // 注意,默认是10000条,如果超过10000条,需要分页查询
  130. // 避免死循环
  131. for {
  132. tmpWechatArticleIdList, tmpErr := getWechatArticleIdList(currentIndex, pageSize)
  133. if tmpErr != nil {
  134. return
  135. }
  136. wechatArticleIdList = append(wechatArticleIdList, tmpWechatArticleIdList...)
  137. if len(tmpWechatArticleIdList) < pageSize {
  138. return
  139. }
  140. currentIndex++
  141. // 超过100次,那么也退出,避免死循环
  142. if currentIndex > 100 {
  143. return
  144. }
  145. }
  146. }
  147. // getWechatArticleIdList
  148. // @Description: 分页获取微信文章Id列表
  149. // @author: Roc
  150. // @datetime 2025-04-16 17:18:44
  151. // @param currentIndex int
  152. // @param pageSize int
  153. // @return wechatArticleIdList []int
  154. // @return err error
  155. func getWechatArticleIdList(currentIndex, pageSize int) (wechatArticleIdList []int, err error) {
  156. wechatArticleIdList = make([]int, 0)
  157. var condition string
  158. var pars []interface{}
  159. var startSize int
  160. if pageSize <= 0 {
  161. pageSize = utils.PageSize20
  162. }
  163. if currentIndex <= 0 {
  164. currentIndex = 1
  165. }
  166. startSize = utils.StartIndex(currentIndex, pageSize)
  167. condition += fmt.Sprintf(` AND %s = ? `, rag.WechatArticleColumns.IsDeleted)
  168. pars = append(pars, 0)
  169. obj := new(rag.WechatArticle)
  170. list, err := obj.GetListByCondition(` wechat_article_id `, condition, pars, startSize, pageSize)
  171. if err != nil {
  172. return
  173. }
  174. for _, item := range list {
  175. wechatArticleIdList = append(wechatArticleIdList, item.WechatArticleId)
  176. }
  177. return
  178. }
  179. // getAllEtaReportIdList
  180. // @Description: 获取所有的eta报告Id列表
  181. // @author: Roc
  182. // @datetime 2025-04-16 17:19:29
  183. // @return ragEtaReportIdList []int
  184. // @return err error
  185. func getAllEtaReportIdList() (ragEtaReportIdList []int, err error) {
  186. ragEtaReportIdList = make([]int, 0)
  187. pageSize := 10000
  188. currentIndex := 1
  189. // 注意,默认是10000条,如果超过10000条,需要分页查询
  190. // 避免死循环
  191. for {
  192. tmpRagEtaReportIdList, tmpErr := getEtaReportIdList(currentIndex, pageSize)
  193. if tmpErr != nil {
  194. return
  195. }
  196. ragEtaReportIdList = append(ragEtaReportIdList, tmpRagEtaReportIdList...)
  197. if len(tmpRagEtaReportIdList) < pageSize {
  198. return
  199. }
  200. currentIndex++
  201. // 超过100次,那么也退出,避免死循环
  202. if currentIndex > 100 {
  203. return
  204. }
  205. }
  206. }
  207. // getEtaReportIdList
  208. // @Description: 分页获取eta报告Id列表
  209. // @author: Roc
  210. // @datetime 2025-04-16 17:19:14
  211. // @param currentIndex int
  212. // @param pageSize int
  213. // @return ragEtaReportIdList []int
  214. // @return err error
  215. func getEtaReportIdList(currentIndex, pageSize int) (ragEtaReportIdList []int, err error) {
  216. ragEtaReportIdList = make([]int, 0)
  217. var condition string
  218. var pars []interface{}
  219. var startSize int
  220. if pageSize <= 0 {
  221. pageSize = utils.PageSize20
  222. }
  223. if currentIndex <= 0 {
  224. currentIndex = 1
  225. }
  226. startSize = utils.StartIndex(currentIndex, pageSize)
  227. condition += fmt.Sprintf(` AND %s = ? AND %s = ? `, rag.RagEtaReportColumns.IsDeleted, rag.RagEtaReportColumns.IsPublished)
  228. pars = append(pars, 0, 1)
  229. obj := new(rag.RagEtaReport)
  230. list, err := obj.GetListByCondition(` rag_eta_report_id `, condition, pars, startSize, pageSize)
  231. if err != nil {
  232. return
  233. }
  234. for _, item := range list {
  235. ragEtaReportIdList = append(ragEtaReportIdList, item.RagEtaReportId)
  236. }
  237. return
  238. }
  239. // CheckOpQuestionAuth
  240. // @Description: 校验是否有权限操作提示词
  241. // @author: Roc
  242. // @datetime 2025-04-16 17:33:01
  243. // @return auth bool
  244. // @return err error
  245. func CheckOpQuestionAuth() (auth bool, err error) {
  246. total, err := getNotFinishGenerateAbstractTaskNum()
  247. if err != nil {
  248. return
  249. }
  250. // 存在未完成的任务,则无权限
  251. if total > 0 {
  252. return
  253. }
  254. auth = true
  255. return
  256. }
  257. // getNotFinishGenerateAbstractTaskNum
  258. // @Description: 获取未完成的生成摘要任务的数量
  259. // @author: Roc
  260. // @datetime 2025-04-16 17:31:12
  261. // @return total int
  262. // @return err error
  263. func getNotFinishGenerateAbstractTaskNum() (total int, err error) {
  264. obj := rag.AiTask{}
  265. var condition string
  266. var pars []interface{}
  267. condition += fmt.Sprintf(` AND %s NOT IN (?) AND %s = ? `, rag.AiTaskColumns.Status, rag.AiTaskColumns.TaskType)
  268. pars = append(pars, []string{`done`, `failed`}, utils.AI_TASK_TYPE_GENERATE_ABSTRACT)
  269. total, err = obj.GetCountByCondition(condition, pars)
  270. if err != nil {
  271. return
  272. }
  273. return
  274. }
  275. // GetNotFinishGenerateAbstractTaskNumByQuestionId
  276. // @Description: 根据提示词ID获取未完成的生成摘要任务的数量
  277. // @author: Roc
  278. // @datetime 2025-04-16 17:31:12
  279. // @return total int
  280. // @return err error
  281. func GetNotFinishGenerateAbstractTaskNumByQuestionId(questionId int) (total int, err error) {
  282. obj := rag.AiTask{}
  283. var condition string
  284. var pars []interface{}
  285. condition += fmt.Sprintf(` AND %s NOT IN (?) AND %s = ? AND %s = ? `, rag.AiTaskColumns.Status, rag.AiTaskColumns.TaskType, rag.AiTaskColumns.Parameters)
  286. pars = append(pars, []string{`done`, `failed`}, utils.AI_TASK_TYPE_GENERATE_ABSTRACT, fmt.Sprint(questionId))
  287. total, err = obj.GetCountByCondition(condition, pars)
  288. if err != nil {
  289. return
  290. }
  291. return
  292. }