浏览代码

fix:新增默认提示词、摘要生成等逻辑

Roc 6 天之前
父节点
当前提交
4f9eb26413

+ 39 - 2
controllers/llm/question.go

@@ -304,6 +304,7 @@ func (c *QuestionController) Edit() {
 		c.Data["json"] = br
 		c.ServeJSON()
 	}()
+
 	var req request.EditQuestionReq
 	err := json.Unmarshal(c.Ctx.Input.RequestBody, &req)
 	if err != nil {
@@ -334,6 +335,7 @@ func (c *QuestionController) Edit() {
 		}
 		return
 	}
+
 	item.QuestionTitle = utils.GetFirstNChars(req.Content, 20)
 	item.QuestionContent = req.Content
 	item.ModifyTime = time.Now()
@@ -387,6 +389,13 @@ func (c *QuestionController) Del() {
 		}
 		return
 	}
+
+	// 删除提示词:若删除默认提示词,提示:当前提示词不允许删除;若删除非默认提示词:提示删除成功(项目eta4.0,时间:2025-4-16 17:39:38)
+	if item.IsDefault == 1 {
+		br.Msg = "当前提示词不允许删除!"
+		return
+	}
+
 	err = item.Del()
 	if err != nil {
 		br.Msg = "删除失败"
@@ -501,6 +510,19 @@ func (c *QuestionController) UnSetDefault() {
 		br.IsSendEmail = false
 		return
 	}
+
+	// 如果是取消已经设置成默认的提示词,那么需要判断是否有正在生成摘要的提示词任务,如果存在的话,那么就不允许取消
+	auth, err := services.CheckOpQuestionAuth()
+	if err != nil {
+		br.Msg = "修改失败"
+		br.ErrMsg = "权限校验失败,Err:" + err.Error()
+		return
+	}
+	if !auth {
+		br.Msg = "当前有提示词正在生成摘要,请稍后再修改"
+		return
+	}
+
 	item.IsDefault = 1
 	item.GenerateStatus = `undo`
 	item.ModifyTime = time.Now()
@@ -566,6 +588,18 @@ func (c *QuestionController) GenerateAbstract() {
 		return
 	}
 
+	// 如果是需要对提示词做摘要的生成,那么需要判断是否有正在生成摘要的提示词任务,如果存在的话,那么就不允许生成(暂定,后面可以改成加到任务中去,等上一个批次的任务完成后,继续该任务)
+	auth, err := services.CheckOpQuestionAuth()
+	if err != nil {
+		br.Msg = "修改失败"
+		br.ErrMsg = "权限校验失败,Err:" + err.Error()
+		return
+	}
+	if !auth {
+		br.Msg = "当前有提示词正在生成摘要,请稍后再修改"
+		return
+	}
+
 	// 标记摘要生成状态,避免重复生成
 	item.GenerateStatus = `done`
 	item.ModifyTime = time.Now()
@@ -576,11 +610,14 @@ func (c *QuestionController) GenerateAbstract() {
 		return
 	}
 
-	// todo 对应的提示词生成的摘要库和向量库内容也取消,同时需要加锁,不允许重复操作
+	// 添加任务
+	services.AddGenerateAbstractTask(item, c.SysUser)
+
+	// todo 开始任务
 
 	br.Ret = 200
 	br.Success = true
-	br.Msg = `取消设置成功`
+	br.Msg = `摘要生成中`
 }
 
 //func init() {

+ 1 - 1
controllers/llm/report.go

@@ -68,7 +68,7 @@ func (c *RagEtaReportController) ArticleList() {
 	}
 
 	obj := new(rag.RagEtaReport)
-	tmpTotal, list, err := obj.GetPageListByCondition(condition, pars, startSize, pageSize)
+	tmpTotal, list, err := obj.GetPageListByCondition(``, condition, pars, startSize, pageSize)
 	if err != nil {
 		br.Msg = "获取失败"
 		br.ErrMsg = "获取失败,Err:" + err.Error()

+ 68 - 16
models/rag/ai_task.go

@@ -1,6 +1,7 @@
 package rag
 
 import (
+	"database/sql"
 	"eta/eta_api/global"
 	"eta/eta_api/utils"
 	"fmt"
@@ -9,22 +10,24 @@ import (
 
 // AiTask ai这边的任务表
 type AiTask struct {
-	AiTaskID                int       `gorm:"primaryKey;column:ai_task_id" json:"-"`                           // id
-	TaskName                string    `gorm:"column:task_name" json:"taskName"`                                // 任务名称
-	TaskType                string    `gorm:"column:task_type" json:"taskType"`                                // 任务类型
-	Status                  string    `gorm:"column:status" json:"status"`                                     // 任务状态
-	StartTime               time.Time `gorm:"column:start_time" json:"startTime"`                              // 开始时间
-	EndTime                 time.Time `gorm:"column:end_time" json:"endTime"`                                  // 结束时间
-	CreateTime              time.Time `gorm:"column:create_time" json:"createTime"`                            // 创建时间
-	UpdateTime              time.Time `gorm:"column:update_time" json:"updateTime"`                            // 更新时间
-	Parameters              string    `gorm:"column:parameters" json:"parameters"`                             // 执行参数
-	Logs                    string    `gorm:"column:logs" json:"logs"`                                         // 日志
-	Errormessage            string    `gorm:"column:ErrorMessage" json:"errorMessage"`                         // 错误信息
-	Priority                int       `gorm:"column:priority" json:"priority"`                                 // 优先级
-	RetryCount              int       `gorm:"column:retry_count" json:"retryCount"`                            // 重试次数
-	EstimatedCompletionTime time.Time `gorm:"column:estimated_completion_time" json:"estimatedCompletionTime"` // 预计完成时间
-	ActualCompletitonTime   time.Time `gorm:"column:actual_completiton_time" json:"actualCompletitonTime"`     // 实际完成时间
-	Remark                  string    `gorm:"column:remark" json:"remark"`                                     // 备注
+	AiTaskID                int       `gorm:"primaryKey;column:ai_task_id" description:"-"`
+	TaskName                string    `gorm:"column:task_name" description:"任务名称"`
+	TaskType                string    `gorm:"column:task_type" description:"任务类型"`
+	Status                  string    `gorm:"column:status" description:"任务状态"`
+	StartTime               time.Time `gorm:"column:start_time" description:"开始时间"`
+	EndTime                 time.Time `gorm:"column:end_time" description:"结束时间"`
+	CreateTime              time.Time `gorm:"column:create_time" description:"创建时间"`
+	UpdateTime              time.Time `gorm:"column:update_time" description:"更新时间"`
+	Parameters              string    `gorm:"column:parameters" description:"执行参数"`
+	Logs                    string    `gorm:"column:logs" description:"日志"`
+	Errormessage            string    `gorm:"column:ErrorMessage" description:"错误信息"`
+	Priority                int       `gorm:"column:priority" description:"优先级"`
+	RetryCount              int       `gorm:"column:retry_count" description:"重试次数"`
+	EstimatedCompletionTime time.Time `gorm:"column:estimated_completion_time" description:"预计完成时间"`
+	ActualCompletitonTime   time.Time `gorm:"column:actual_completiton_time" description:"实际完成时间"`
+	Remark                  string    `gorm:"column:remark" description:"备注"`
+	SysUserID               int       `gorm:"column:sys_user_id" description:"任务创建人id"`
+	SysUserRealName         string    `gorm:"column:sys_user_real_name" description:"任务创建人名称"`
 }
 
 // TableName get sql table name.获取数据库表名
@@ -50,6 +53,8 @@ var AiTaskColumns = struct {
 	EstimatedCompletionTime string
 	ActualCompletitonTime   string
 	Remark                  string
+	SysUserID               string
+	SysUserRealName         string
 }{
 	AiTaskID:                "ai_task_id",
 	TaskName:                "task_name",
@@ -67,6 +72,8 @@ var AiTaskColumns = struct {
 	EstimatedCompletionTime: "estimated_completion_time",
 	ActualCompletitonTime:   "actual_completiton_time",
 	Remark:                  "remark",
+	SysUserID:               "sys_user_id",
+	SysUserRealName:         "sys_user_real_name",
 }
 
 func (m *AiTask) Create() (err error) {
@@ -110,3 +117,48 @@ func (m *AiTask) GetListByCondition(field, condition string, pars []interface{},
 
 	return
 }
+
+// GetCountByCondition
+// @Description: Count the rows of the AiTask table that match the given filter
+// @receiver m
+// @param condition string  SQL fragment appended directly after "WHERE 1=1"
+// @param pars []interface{}  values bound to the placeholders in condition
+// @return total int  stays 0 when the query fails or the scanned value is NULL
+// @return err error
+func (m *AiTask) GetCountByCondition(condition string, pars []interface{}) (total int, err error) {
+	query := fmt.Sprintf(`SELECT COUNT(1) total FROM %s WHERE 1=1 %s`, m.TableName(), condition)
+
+	// Scan into a nullable integer so a NULL result does not fail the scan.
+	var count sql.NullInt64
+	if err = global.DbMap[utils.DbNameAI].Raw(query, pars...).Scan(&count).Error; err != nil {
+		return
+	}
+	if count.Valid {
+		total = int(count.Int64)
+	}
+
+	return
+}
+
+// AddAiTask
+// @Description: Insert an AI task and its task records inside one database transaction.
+//  The transaction is rolled back when any step fails; a failed commit is returned to the caller.
+// @author: Roc
+// @datetime 2025-04-16 16:55:36
+// @param aiTask *AiTask  task row to create
+// @param aiRecordList []*AiTaskRecord  records of the task; each record's AiTaskID is overwritten with aiTask.AiTaskID after the Create call
+// @return err error
+func AddAiTask(aiTask *AiTask, aiRecordList []*AiTaskRecord) (err error) {
+	to := global.DbMap[utils.DbNameAI].Begin()
+	// A transaction that could not be started is reported on the returned handle.
+	if err = to.Error; err != nil {
+		return
+	}
+	defer func() {
+		if err != nil {
+			// Best effort: the original failure is what the caller needs to see.
+			_ = to.Rollback()
+			return
+		}
+		// Surface a failed commit instead of reporting success.
+		err = to.Commit().Error
+	}()
+
+	if err = to.Create(aiTask).Error; err != nil {
+		return
+	}
+
+	// Link every record to the task that was just created.
+	for _, aiTaskRecord := range aiRecordList {
+		aiTaskRecord.AiTaskID = aiTask.AiTaskID
+	}
+
+	err = to.CreateInBatches(aiRecordList, utils.MultiAddNum).Error
+
+	return
+}

+ 8 - 0
models/rag/ai_task_record.go

@@ -83,3 +83,11 @@ func (m *AiTaskRecord) GetListByCondition(field, condition string, pars []interf
 
 	return
 }
+
+// QuestionGenerateAbstractParam
+// @Description: Parameters of a single "generate abstract" job; it names one question and one article.
+//  AddGenerateAbstractTask JSON-encodes this struct into the Parameters field of an AiTaskRecord.
+type QuestionGenerateAbstractParam struct {
+	// Id of the question (prompt) the abstract is generated for.
+	QuestionId  int    `json:"questionId"`
+	// Kind of article; AddGenerateAbstractTask writes `wechat_article` or `rag_eta_report`.
+	ArticleType string `json:"articleType"`
+	// Id of the article within its kind.
+	ArticleId   int    `json:"articleId"`
+}

+ 2 - 2
models/rag/rag_eta_report.go

@@ -154,14 +154,14 @@ func (m *RagEtaReport) GetCountByCondition(condition string, pars []interface{})
 	return
 }
 
-func (m *RagEtaReport) GetPageListByCondition(condition string, pars []interface{}, startSize, pageSize int) (total int, items []*RagEtaReport, err error) {
+func (m *RagEtaReport) GetPageListByCondition(field, condition string, pars []interface{}, startSize, pageSize int) (total int, items []*RagEtaReport, err error) {
 
 	total, err = m.GetCountByCondition(condition, pars)
 	if err != nil {
 		return
 	}
 	if total > 0 {
-		items, err = m.GetListByCondition(``, condition, pars, startSize, pageSize)
+		items, err = m.GetListByCondition(field, condition, pars, startSize, pageSize)
 	}
 
 	return

+ 27 - 0
routers/commentsRouter.go

@@ -8719,6 +8719,15 @@ func init() {
             Filters: nil,
             Params: nil})
 
+    beego.GlobalControllerRouter["eta/eta_api/controllers/llm:QuestionController"] = append(beego.GlobalControllerRouter["eta/eta_api/controllers/llm:QuestionController"],
+        beego.ControllerComments{
+            Method: "GenerateAbstract",
+            Router: `/question/abstract/generate`,
+            AllowHTTPMethods: []string{"post"},
+            MethodParams: param.Make(),
+            Filters: nil,
+            Params: nil})
+
     beego.GlobalControllerRouter["eta/eta_api/controllers/llm:QuestionController"] = append(beego.GlobalControllerRouter["eta/eta_api/controllers/llm:QuestionController"],
         beego.ControllerComments{
             Method: "Add",
@@ -8728,6 +8737,24 @@ func init() {
             Filters: nil,
             Params: nil})
 
+    beego.GlobalControllerRouter["eta/eta_api/controllers/llm:QuestionController"] = append(beego.GlobalControllerRouter["eta/eta_api/controllers/llm:QuestionController"],
+        beego.ControllerComments{
+            Method: "SetDefault",
+            Router: `/question/default/set`,
+            AllowHTTPMethods: []string{"post"},
+            MethodParams: param.Make(),
+            Filters: nil,
+            Params: nil})
+
+    beego.GlobalControllerRouter["eta/eta_api/controllers/llm:QuestionController"] = append(beego.GlobalControllerRouter["eta/eta_api/controllers/llm:QuestionController"],
+        beego.ControllerComments{
+            Method: "UnSetDefault",
+            Router: `/question/default/unset`,
+            AllowHTTPMethods: []string{"post"},
+            MethodParams: param.Make(),
+            Filters: nil,
+            Params: nil})
+
     beego.GlobalControllerRouter["eta/eta_api/controllers/llm:QuestionController"] = append(beego.GlobalControllerRouter["eta/eta_api/controllers/llm:QuestionController"],
         beego.ControllerComments{
             Method: "Del",

+ 290 - 0
services/llm.go

@@ -0,0 +1,290 @@
+package services
+
+import (
+	"encoding/json"
+	"eta/eta_api/models/rag"
+	"eta/eta_api/models/system"
+	"eta/eta_api/utils"
+	"fmt"
+	"time"
+)
+
+// AddGenerateAbstractTask
+// @Description: Create one AI task plus one pending task record per WeChat article and per ETA report,
+//  so that abstracts can later be generated for the given question.
+//  Any failure (id lookup, JSON encoding, or the database insert) ends the function silently;
+//  the caller receives no error.
+// @author: Roc
+// @datetime 2025-04-16 17:02:18
+// @param question *rag.Question
+// @param sysUser *system.Admin  stored on the task as its creator
+func AddGenerateAbstractTask(question *rag.Question, sysUser *system.Admin) {
+	// Ids of all WeChat articles
+	wechatIds, err := getAllWechatArticleIdList()
+	if err != nil {
+		return
+	}
+
+	// Ids of all ETA reports
+	reportIds, err := getAllEtaReportIdList()
+	if err != nil {
+		return
+	}
+
+	// Task header; every field not set here keeps its zero value.
+	task := new(rag.AiTask)
+	task.TaskName = fmt.Sprintf("自动生成摘要%s-%s", time.Now().Format(utils.FormatShortDateTimeUnSpace), question.QuestionTitle)
+	task.TaskType = utils.AI_TASK_TYPE_GENERATE_ABSTRACT
+	task.Status = "init"
+	task.CreateTime = time.Now()
+	task.UpdateTime = time.Now()
+	task.Parameters = fmt.Sprint(question.QuestionId)
+	task.SysUserID = sysUser.AdminId
+	task.SysUserRealName = sysUser.RealName
+
+	records := make([]*rag.AiTaskRecord, 0)
+
+	// appendRecords adds one pending record per article id and reports false
+	// when the record parameters cannot be JSON-encoded.
+	appendRecords := func(articleType string, articleIds []int) bool {
+		for _, articleId := range articleIds {
+			encoded, marshalErr := json.Marshal(rag.QuestionGenerateAbstractParam{
+				QuestionId:  question.QuestionId,
+				ArticleType: articleType,
+				ArticleId:   articleId,
+			})
+			if marshalErr != nil {
+				return false
+			}
+			records = append(records, &rag.AiTaskRecord{
+				Parameters: string(encoded),
+				Status:     "待处理",
+				ModifyTime: time.Now(),
+				CreateTime: time.Now(),
+			})
+		}
+		return true
+	}
+
+	// WeChat articles first, then ETA reports.
+	if !appendRecords(`wechat_article`, wechatIds) {
+		return
+	}
+	if !appendRecords(`rag_eta_report`, reportIds) {
+		return
+	}
+
+	// Persist the task and its records for the task scheduler; this function
+	// has no return value, so an insert error cannot be reported.
+	_ = rag.AddAiTask(task, records)
+}
+
+// getAllWechatArticleIdList
+// @Description: 获取所有的微信文章Id列表
+// @author: Roc
+// @datetime 2025-04-16 17:18:31
+// @return wechatArticleIdList []int
+// @return err error
+func getAllWechatArticleIdList() (wechatArticleIdList []int, err error) {
+	wechatArticleIdList = make([]int, 0)
+	pageSize := 10000
+	currentIndex := 1
+
+	// 注意,默认是10000条,如果超过10000条,需要分页查询
+	// 避免死循环
+	for {
+		tmpWechatArticleIdList, tmpErr := getWechatArticleIdList(currentIndex, pageSize)
+		if tmpErr != nil {
+			return
+		}
+		wechatArticleIdList = append(wechatArticleIdList, tmpWechatArticleIdList...)
+		if len(tmpWechatArticleIdList) < pageSize {
+			return
+		}
+		currentIndex++
+
+		// 超过100次,那么也退出,避免死循环
+		if currentIndex > 100 {
+			return
+		}
+	}
+
+}
+
+// getWechatArticleIdList
+// @Description: Read one page of ids of WeChat articles that are not deleted
+// @author: Roc
+// @datetime 2025-04-16 17:18:44
+// @param currentIndex int  page number; values <= 0 are replaced by 1
+// @param pageSize int  page length; values <= 0 are replaced by utils.PageSize20
+// @return wechatArticleIdList []int
+// @return err error
+func getWechatArticleIdList(currentIndex, pageSize int) (wechatArticleIdList []int, err error) {
+	wechatArticleIdList = make([]int, 0)
+	var condition string
+	var pars []interface{}
+
+	var startSize int
+	if pageSize <= 0 {
+		pageSize = utils.PageSize20
+	}
+	if currentIndex <= 0 {
+		currentIndex = 1
+	}
+	startSize = utils.StartIndex(currentIndex, pageSize)
+
+	// The filter has exactly one placeholder, so exactly one value is bound;
+	// a surplus value would no longer line up with the placeholders.
+	condition += fmt.Sprintf(` AND %s = ? `, rag.WechatArticleColumns.IsDeleted)
+	pars = append(pars, 0)
+
+	obj := new(rag.WechatArticle)
+	// Only the id column is selected.
+	list, err := obj.GetListByCondition(` wechat_article_id `, condition, pars, startSize, pageSize)
+	if err != nil {
+		return
+	}
+	for _, item := range list {
+		wechatArticleIdList = append(wechatArticleIdList, item.WechatArticleId)
+	}
+
+	return
+}
+
+// getAllEtaReportIdList
+// @Description: 获取所有的eta报告Id列表
+// @author: Roc
+// @datetime 2025-04-16 17:19:29
+// @return ragEtaReportIdList []int
+// @return err error
+func getAllEtaReportIdList() (ragEtaReportIdList []int, err error) {
+	ragEtaReportIdList = make([]int, 0)
+	pageSize := 10000
+	currentIndex := 1
+
+	// 注意,默认是10000条,如果超过10000条,需要分页查询
+	// 避免死循环
+	for {
+		tmpRagEtaReportIdList, tmpErr := getEtaReportIdList(currentIndex, pageSize)
+		if tmpErr != nil {
+			return
+		}
+		ragEtaReportIdList = append(ragEtaReportIdList, tmpRagEtaReportIdList...)
+		if len(tmpRagEtaReportIdList) < pageSize {
+			return
+		}
+		currentIndex++
+
+		// 超过100次,那么也退出,避免死循环
+		if currentIndex > 100 {
+			return
+		}
+	}
+
+}
+
+// getEtaReportIdList
+// @Description: 分页获取eta报告Id列表
+// @author: Roc
+// @datetime 2025-04-16 17:19:14
+// @param currentIndex int
+// @param pageSize int
+// @return ragEtaReportIdList []int
+// @return err error
+func getEtaReportIdList(currentIndex, pageSize int) (ragEtaReportIdList []int, err error) {
+	ragEtaReportIdList = make([]int, 0)
+	var condition string
+	var pars []interface{}
+
+	var startSize int
+	if pageSize <= 0 {
+		pageSize = utils.PageSize20
+	}
+	if currentIndex <= 0 {
+		currentIndex = 1
+	}
+	startSize = utils.StartIndex(currentIndex, pageSize)
+
+	condition += fmt.Sprintf(` AND %s = ? AND %s = ? `, rag.RagEtaReportColumns.IsDeleted, rag.RagEtaReportColumns.IsPublished)
+	pars = append(pars, 0, 1)
+
+	obj := new(rag.RagEtaReport)
+	list, err := obj.GetListByCondition(` rag_eta_report_id `, condition, pars, startSize, pageSize)
+	if err != nil {
+		return
+	}
+	for _, item := range list {
+		ragEtaReportIdList = append(ragEtaReportIdList, item.RagEtaReportId)
+	}
+
+	return
+}
+
+// CheckOpQuestionAuth
+// @Description: Report whether questions (prompts) may be modified right now.
+//  Modification is allowed only while no "generate abstract" task is unfinished.
+// @author: Roc
+// @datetime 2025-04-16 17:33:01
+// @return auth bool  true when there is no unfinished task
+// @return err error  error of the task count query
+func CheckOpQuestionAuth() (auth bool, err error) {
+	pending, err := getNotFinishGenerateAbstractTaskNum()
+	if err != nil {
+		return false, err
+	}
+
+	// Any unfinished task blocks the operation.
+	return pending <= 0, nil
+}
+
+// getNotFinishGenerateAbstractTaskNum
+// @Description: 获取未完成的生成摘要任务的数量
+// @author: Roc
+// @datetime 2025-04-16 17:31:12
+// @return total int
+// @return err error
+func getNotFinishGenerateAbstractTaskNum() (total int, err error) {
+	obj := rag.AiTask{}
+
+	var condition string
+	var pars []interface{}
+
+	condition += fmt.Sprintf(` AND %s NOT IN (?)  AND %s = ? `, rag.AiTaskColumns.Status, rag.AiTaskColumns.TaskType)
+	pars = append(pars, []string{`done`, `failed`}, utils.AI_TASK_TYPE_GENERATE_ABSTRACT)
+
+	total, err = obj.GetCountByCondition(condition, pars)
+	if err != nil {
+		return
+	}
+
+	return
+}

+ 5 - 0
utils/constants.go

@@ -583,6 +583,7 @@ const (
 const (
 	FICC_ARTICLE_UPDATE_KEY = "FICC_ARTICLE_UPDATE_KEY" //权益报告通知给FICC这边的缓存key
 )
+
 // 图表分类设置精选资源分类
 const (
 	ChartClassifyIsSelected            = 1 // 图表分类设置精选资源分类
@@ -597,3 +598,7 @@ const (
 const (
 	DATA_SOURCE_NAME_RADISH_RESEARCH = "萝卜投研" // 萝卜投研 -> 105
 )
+
+const (
+	AI_TASK_TYPE_GENERATE_ABSTRACT = `question_generate_abstract` // AI task type: batch generation of abstracts
+)