Roc 1 周之前
父節點
當前提交
1b524df36d
共有 4 個文件被更改,包括 413 次插入20 次删除
  1. 201 20
      controllers/llm/question.go
  2. 112 0
      models/rag/ai_task.go
  3. 85 0
      models/rag/ai_task_record.go
  4. 15 0
      models/rag/question.go

+ 201 - 20
controllers/llm/question.go

@@ -17,7 +17,7 @@ import (
 )
 
 // QuestionController
-// @Description: 问题库管理
+// @Description: 提示词库管理
 type QuestionController struct {
 	controllers.BaseAuthController
 }
@@ -197,7 +197,7 @@ func (c *QuestionController) TitleList() {
 // Detail
 // @Title 列表
 // @Description 列表
-// @Param   QuestionId   query   int  true       "问题id"
+// @Param   QuestionId   query   int  true       "提示词id"
 // @Success 200 {object} []*rag.QuestionListListResp
 // @router /question/detail [get]
 func (c *QuestionController) Detail() {
@@ -215,8 +215,8 @@ func (c *QuestionController) Detail() {
 	}
 	questionId, _ := c.GetInt("QuestionId")
 	if questionId <= 0 {
-		br.Msg = "问题id不能为空"
-		br.ErrMsg = "问题id不能为空"
+		br.Msg = "提示词id不能为空"
+		br.ErrMsg = "提示词id不能为空"
 		return
 	}
 
@@ -236,8 +236,8 @@ func (c *QuestionController) Detail() {
 }
 
 // Add
-// @Title 新增问题
-// @Description 新增问题
+// @Title 新增提示词
+// @Description 新增提示词
 // @Param	request	body request.AddQuestionReq true "type json string"
 // @Success 200 Ret=200 新增成功
 // @router /question/add [post]
@@ -256,14 +256,14 @@ func (c *QuestionController) Add() {
 	}
 	req.Content = strings.TrimSpace(req.Content)
 	if req.Content == "" {
-		br.Msg = "请输入问题"
+		br.Msg = "请输入提示词"
 		br.IsSendEmail = false
 		return
 	}
 	//obj := rag.Question{}
 	//_, err = obj.GetByCondition(` AND question_content = ? `, []interface{}{req.Content})
 	//if err == nil {
-	//	br.Msg = "问题已入库,请不要重复添加"
+	//	br.Msg = "提示词已入库,请不要重复添加"
 	//	br.IsSendEmail = false
 	//	return
 	//}
@@ -293,8 +293,8 @@ func (c *QuestionController) Add() {
 }
 
 // Edit
-// @Title 编辑问题
-// @Description 编辑问题
+// @Title 编辑提示词
+// @Description 编辑提示词
 // @Param	request	body request.EditQuestionReq true "type json string"
 // @Success 200 Ret=200 新增成功
 // @router /question/edit [post]
@@ -312,13 +312,13 @@ func (c *QuestionController) Edit() {
 		return
 	}
 	if req.QuestionId <= 0 {
-		br.Msg = "问题id不能为空"
+		br.Msg = "提示词id不能为空"
 		br.IsSendEmail = false
 		return
 	}
 	req.Content = strings.TrimSpace(req.Content)
 	if req.Content == "" {
-		br.Msg = "请输入问题"
+		br.Msg = "请输入提示词"
 		br.IsSendEmail = false
 		return
 	}
@@ -327,9 +327,9 @@ func (c *QuestionController) Edit() {
 	item, err := obj.GetByID(req.QuestionId)
 	if err != nil {
 		br.Msg = "修改失败"
-		br.ErrMsg = "修改失败,查找问题失败,Err:" + err.Error()
+		br.ErrMsg = "修改失败,查找提示词失败,Err:" + err.Error()
 		if utils.IsErrNoRow(err) {
-			br.Msg = "问题不存在"
+			br.Msg = "提示词不存在"
 			br.IsSendEmail = false
 		}
 		return
@@ -352,8 +352,8 @@ func (c *QuestionController) Edit() {
 }
 
 // Del
-// @Title 删除问题
-// @Description 删除问题
+// @Title 删除提示词
+// @Description 删除提示词
 // @Param	request	body request.EditQuestionReq true "type json string"
 // @Success 200 Ret=200 新增成功
 // @router /question/del [post]
@@ -371,7 +371,7 @@ func (c *QuestionController) Del() {
 		return
 	}
 	if req.QuestionId <= 0 {
-		br.Msg = "问题id不能为空"
+		br.Msg = "提示词id不能为空"
 		br.IsSendEmail = false
 		return
 	}
@@ -380,9 +380,9 @@ func (c *QuestionController) Del() {
 	item, err := obj.GetByID(req.QuestionId)
 	if err != nil {
 		br.Msg = "修改失败"
-		br.ErrMsg = "修改失败,查找问题失败,Err:" + err.Error()
+		br.ErrMsg = "修改失败,查找提示词失败,Err:" + err.Error()
 		if utils.IsErrNoRow(err) {
-			br.Msg = "问题不存在"
+			br.Msg = "提示词不存在"
 			br.IsSendEmail = false
 		}
 		return
@@ -402,8 +402,189 @@ func (c *QuestionController) Del() {
 	br.Msg = `删除成功`
 }
 
+// SetDefault
+// @Title 设置默认提示词
+// @Description 设置默认提示词
+// @Param	request	body request.EditQuestionReq true "type json string"
+// @Success 200 Ret=200 设置成功
+// @router /question/default/set [post]
+func (c *QuestionController) SetDefault() {
+	br := new(models.BaseResponse).Init()
+	defer func() {
+		c.Data["json"] = br
+		c.ServeJSON()
+	}()
+	var req request.EditQuestionReq
+	err := json.Unmarshal(c.Ctx.Input.RequestBody, &req)
+	if err != nil {
+		br.Msg = "参数解析异常!"
+		br.ErrMsg = "参数解析失败,Err:" + err.Error()
+		return
+	}
+	if req.QuestionId <= 0 {
+		br.Msg = "提示词id不能为空"
+		br.IsSendEmail = false
+		return
+	}
+
+	obj := rag.Question{}
+	item, err := obj.GetByID(req.QuestionId)
+	if err != nil {
+		br.Msg = "修改失败"
+		br.ErrMsg = "修改失败,查找提示词失败,Err:" + err.Error()
+		if utils.IsErrNoRow(err) {
+			br.Msg = "提示词不存在"
+			br.IsSendEmail = false
+		}
+		return
+	}
+
+	if item.IsDefault == 1 {
+		br.Msg = "该提示词已经是默认提示词,无需设置"
+		br.IsSendEmail = false
+		return
+	}
+	item.IsDefault = 1
+	item.GenerateStatus = `undo`
+	item.ModifyTime = time.Now()
+	err = item.Update([]string{"is_default", "generate_status", "modify_time"})
+	if err != nil {
+		br.Msg = "设置失败"
+		br.ErrMsg = "设置失败,Err:" + err.Error()
+		return
+	}
+
+	br.Ret = 200
+	br.Success = true
+	br.Msg = `设置成功`
+}
+
+// UnSetDefault
+// @Title 取消设置默认提示词
+// @Description 取消设置默认提示词
+// @Param	request	body request.EditQuestionReq true "type json string"
+// @Success 200 Ret=200 设置成功
+// @router /question/default/unset [post]
+func (c *QuestionController) UnSetDefault() {
+	br := new(models.BaseResponse).Init()
+	defer func() {
+		c.Data["json"] = br
+		c.ServeJSON()
+	}()
+	var req request.EditQuestionReq
+	err := json.Unmarshal(c.Ctx.Input.RequestBody, &req)
+	if err != nil {
+		br.Msg = "参数解析异常!"
+		br.ErrMsg = "参数解析失败,Err:" + err.Error()
+		return
+	}
+	if req.QuestionId <= 0 {
+		br.Msg = "提示词id不能为空"
+		br.IsSendEmail = false
+		return
+	}
+
+	obj := rag.Question{}
+	item, err := obj.GetByID(req.QuestionId)
+	if err != nil {
+		br.Msg = "修改失败"
+		br.ErrMsg = "修改失败,查找提示词失败,Err:" + err.Error()
+		if utils.IsErrNoRow(err) {
+			br.Msg = "提示词不存在"
+			br.IsSendEmail = false
+		}
+		return
+	}
+
+	if item.IsDefault == 0 {
+		br.Msg = "该提示词不是默认提示词,无需取消"
+		br.IsSendEmail = false
+		return
+	}
+	item.IsDefault = 1
+	item.GenerateStatus = `undo`
+	item.ModifyTime = time.Now()
+	err = item.Update([]string{"is_default", "generate_status", "modify_time"})
+	if err != nil {
+		br.Msg = "取消设置失败"
+		br.ErrMsg = "取消设置失败,Err:" + err.Error()
+		return
+	}
+
+	// todo 对应的提示词生成的摘要库和向量库内容也取消,同时需要加锁,不允许重复操作
+
+	br.Ret = 200
+	br.Success = true
+	br.Msg = `取消设置成功`
+}
+
+// GenerateAbstract
+// @Title 生成摘要
+// @Description 生成摘要
+// @Param	request	body request.EditQuestionReq true "type json string"
+// @Success 200 Ret=200 设置成功
+// @router /question/abstract/generate [post]
+func (c *QuestionController) GenerateAbstract() {
+	br := new(models.BaseResponse).Init()
+	defer func() {
+		c.Data["json"] = br
+		c.ServeJSON()
+	}()
+	var req request.EditQuestionReq
+	err := json.Unmarshal(c.Ctx.Input.RequestBody, &req)
+	if err != nil {
+		br.Msg = "参数解析异常!"
+		br.ErrMsg = "参数解析失败,Err:" + err.Error()
+		return
+	}
+	if req.QuestionId <= 0 {
+		br.Msg = "提示词id不能为空"
+		br.IsSendEmail = false
+		return
+	}
+
+	obj := rag.Question{}
+	item, err := obj.GetByID(req.QuestionId)
+	if err != nil {
+		br.Msg = "修改失败"
+		br.ErrMsg = "修改失败,查找提示词失败,Err:" + err.Error()
+		if utils.IsErrNoRow(err) {
+			br.Msg = "提示词不存在"
+			br.IsSendEmail = false
+		}
+		return
+	}
+
+	if item.IsDefault != 1 {
+		br.Msg = "该提示词不是默认提示词,不允许生成"
+		br.IsSendEmail = false
+		return
+	}
+	if item.GenerateStatus != `undo` {
+		br.Msg = "该提示词已经生成过摘要,不允许重复生成"
+		br.IsSendEmail = false
+		return
+	}
+
+	// 标记摘要生成状态,避免重复生成
+	item.GenerateStatus = `done`
+	item.ModifyTime = time.Now()
+	err = item.Update([]string{"generate_status", "modify_time"})
+	if err != nil {
+		br.Msg = "取消设置失败"
+		br.ErrMsg = "取消设置失败,Err:" + err.Error()
+		return
+	}
+
+	// todo 对应的提示词生成的摘要库和向量库内容也取消,同时需要加锁,不允许重复操作
+
+	br.Ret = 200
+	br.Success = true
+	br.Msg = `取消设置成功`
+}
+
 //func init() {
-//	// 问题加到es
+//	// 提示词加到es
 //	{
 //		obj := rag.Question{}
 //		list, _ := obj.GetListByCondition(``, ` `, []interface{}{}, 0, 10000)

+ 112 - 0
models/rag/ai_task.go

@@ -0,0 +1,112 @@
+package rag
+
+import (
+	"eta/eta_api/global"
+	"eta/eta_api/utils"
+	"fmt"
+	"time"
+)
+
+// AiTask ai这边的任务表
+type AiTask struct {
+	AiTaskID                int       `gorm:"primaryKey;column:ai_task_id" json:"-"`                           // id
+	TaskName                string    `gorm:"column:task_name" json:"taskName"`                                // 任务名称
+	TaskType                string    `gorm:"column:task_type" json:"taskType"`                                // 任务类型
+	Status                  string    `gorm:"column:status" json:"status"`                                     // 任务状态
+	StartTime               time.Time `gorm:"column:start_time" json:"startTime"`                              // 开始时间
+	EndTime                 time.Time `gorm:"column:end_time" json:"endTime"`                                  // 结束时间
+	CreateTime              time.Time `gorm:"column:create_time" json:"createTime"`                            // 创建时间
+	UpdateTime              time.Time `gorm:"column:update_time" json:"updateTime"`                            // 更新时间
+	Parameters              string    `gorm:"column:parameters" json:"parameters"`                             // 执行参数
+	Logs                    string    `gorm:"column:logs" json:"logs"`                                         // 日志
+	Errormessage            string    `gorm:"column:ErrorMessage" json:"errorMessage"`                         // 错误信息
+	Priority                int       `gorm:"column:priority" json:"priority"`                                 // 优先级
+	RetryCount              int       `gorm:"column:retry_count" json:"retryCount"`                            // 重试次数
+	EstimatedCompletionTime time.Time `gorm:"column:estimated_completion_time" json:"estimatedCompletionTime"` // 预计完成时间
+	ActualCompletitonTime   time.Time `gorm:"column:actual_completiton_time" json:"actualCompletitonTime"`     // 实际完成时间
+	Remark                  string    `gorm:"column:remark" json:"remark"`                                     // 备注
+}
+
+// TableName get sql table name.获取数据库表名
+func (m *AiTask) TableName() string {
+	return "ai_task"
+}
+
+// AiTaskColumns get sql column name.获取数据库列名
+var AiTaskColumns = struct {
+	AiTaskID                string
+	TaskName                string
+	TaskType                string
+	Status                  string
+	StartTime               string
+	EndTime                 string
+	CreateTime              string
+	UpdateTime              string
+	Parameters              string
+	Logs                    string
+	Errormessage            string
+	Priority                string
+	RetryCount              string
+	EstimatedCompletionTime string
+	ActualCompletitonTime   string
+	Remark                  string
+}{
+	AiTaskID:                "ai_task_id",
+	TaskName:                "task_name",
+	TaskType:                "task_type",
+	Status:                  "status",
+	StartTime:               "start_time",
+	EndTime:                 "end_time",
+	CreateTime:              "create_time",
+	UpdateTime:              "update_time",
+	Parameters:              "parameters",
+	Logs:                    "logs",
+	Errormessage:            "ErrorMessage",
+	Priority:                "priority",
+	RetryCount:              "retry_count",
+	EstimatedCompletionTime: "estimated_completion_time",
+	ActualCompletitonTime:   "actual_completiton_time",
+	Remark:                  "remark",
+}
+
+func (m *AiTask) Create() (err error) {
+	err = global.DbMap[utils.DbNameAI].Create(&m).Error
+
+	return
+}
+
+func (m *AiTask) Update(updateCols []string) (err error) {
+	err = global.DbMap[utils.DbNameAI].Select(updateCols).Updates(&m).Error
+
+	return
+}
+
+func (m *AiTask) Del() (err error) {
+	err = global.DbMap[utils.DbNameAI].Delete(&m).Error
+
+	return
+}
+
+func (m *AiTask) GetByID(id int) (item *AiTask, err error) {
+	err = global.DbMap[utils.DbNameAI].Where(fmt.Sprintf("%s = ?", AiTaskColumns.AiTaskID), id).First(&item).Error
+
+	return
+}
+
+func (m *AiTask) GetByCondition(condition string, pars []interface{}) (item *AiTask, err error) {
+	sqlStr := fmt.Sprintf(`SELECT * FROM %s WHERE 1=1 %s`, m.TableName(), condition)
+	err = global.DbMap[utils.DbNameAI].Raw(sqlStr, pars...).First(&item).Error
+
+	return
+}
+
+func (m *AiTask) GetListByCondition(field, condition string, pars []interface{}, startSize, pageSize int) (items []*AiTask, err error) {
+	if field == "" {
+		field = "*"
+	}
+	sqlStr := fmt.Sprintf(`SELECT %s FROM %s WHERE 1=1 %s order by AiTask_id desc LIMIT ?,?`, field, m.TableName(), condition)
+	pars = append(pars, startSize, pageSize)
+	err = global.DbMap[utils.DbNameAI].Raw(sqlStr, pars...).Find(&items).Error
+
+	return
+}

+ 85 - 0
models/rag/ai_task_record.go

@@ -0,0 +1,85 @@
+package rag
+
+import (
+	"eta/eta_api/global"
+	"eta/eta_api/utils"
+	"fmt"
+	"time"
+)
+
+// AiTaskRecord AI任务的子记录
+type AiTaskRecord struct {
+	AiTaskRecordID int       `gorm:"primaryKey;column:ai_task_record_id" json:"-"` // 任务记录id
+	AiTaskID       int       `gorm:"column:ai_task_id" json:"aiTaskId"`            // 任务id
+	Parameters     string    `gorm:"column:parameters" json:"parameters"`          // 子任务参数
+	Status         string    `gorm:"column:status" json:"status"`                  // 状态
+	Remark         string    `gorm:"column:remark" json:"remark"`                  // 备注
+	ModifyTime     time.Time `gorm:"column:modify_time" json:"modifyTime"`         // 最后一次修改时间
+	CreateTime     time.Time `gorm:"column:create_time" json:"createTime"`         // 任务创建时间
+}
+
+// TableName get sql table name.获取数据库表名
+func (m *AiTaskRecord) TableName() string {
+	return "ai_task_record"
+}
+
+// AiTaskRecordColumns get sql column name.获取数据库列名
+var AiTaskRecordColumns = struct {
+	AiTaskRecordID string
+	AiTaskID       string
+	Parameters     string
+	Status         string
+	Remark         string
+	ModifyTime     string
+	CreateTime     string
+}{
+	AiTaskRecordID: "ai_task_record_id",
+	AiTaskID:       "ai_task_id",
+	Parameters:     "parameters",
+	Status:         "status",
+	Remark:         "remark",
+	ModifyTime:     "modify_time",
+	CreateTime:     "create_time",
+}
+
+func (m *AiTaskRecord) Create() (err error) {
+	err = global.DbMap[utils.DbNameAI].Create(&m).Error
+
+	return
+}
+
+func (m *AiTaskRecord) Update(updateCols []string) (err error) {
+	err = global.DbMap[utils.DbNameAI].Select(updateCols).Updates(&m).Error
+
+	return
+}
+
+func (m *AiTaskRecord) Del() (err error) {
+	err = global.DbMap[utils.DbNameAI].Delete(&m).Error
+
+	return
+}
+
+func (m *AiTaskRecord) GetByID(id int) (item *AiTaskRecord, err error) {
+	err = global.DbMap[utils.DbNameAI].Where(fmt.Sprintf("%s = ?", AiTaskRecordColumns.AiTaskRecordID), id).First(&item).Error
+
+	return
+}
+
+func (m *AiTaskRecord) GetByCondition(condition string, pars []interface{}) (item *AiTaskRecord, err error) {
+	sqlStr := fmt.Sprintf(`SELECT * FROM %s WHERE 1=1 %s`, m.TableName(), condition)
+	err = global.DbMap[utils.DbNameAI].Raw(sqlStr, pars...).First(&item).Error
+
+	return
+}
+
+func (m *AiTaskRecord) GetListByCondition(field, condition string, pars []interface{}, startSize, pageSize int) (items []*AiTaskRecord, err error) {
+	if field == "" {
+		field = "*"
+	}
+	sqlStr := fmt.Sprintf(`SELECT %s FROM %s WHERE 1=1 %s order by AiTaskRecord_id desc LIMIT ?,?`, field, m.TableName(), condition)
+	pars = append(pars, startSize, pageSize)
+	err = global.DbMap[utils.DbNameAI].Raw(sqlStr, pars...).Find(&items).Error
+
+	return
+}

+ 15 - 0
models/rag/question.go

@@ -13,6 +13,9 @@ type Question struct {
 	QuestionTitle   string    `gorm:"column:question_title;type:varchar(255);comment:问题标题;" description:"问题标题"`
 	QuestionContent string    `gorm:"column:question_content;type:varchar(255);comment:问题内容;" description:"问题内容"`
 	Sort            int       `gorm:"column:sort;type:int(11);comment:排序;default:0;" description:"排序"`
+	Version         string    `gorm:"column:version;type:varchar(4);comment:问题版本;default:NULL;" description:"问题版本"`
+	GenerateStatus  string    `gorm:"column:generate_status;type:enum('undo', 'done');comment:生成摘要状态;default:NULL;" description:"生成摘要状态"`
+	IsDefault       int       `gorm:"column:is_default;type:int(1);comment:是否默认提示词;default:NULL;" description:"是否默认提示词"`
 	SysUserId       int       `gorm:"column:sys_user_id;type:int(11);comment:添加人id;default:0;" description:"添加人id"`
 	SysUserRealName string    `gorm:"column:sys_user_real_name;type:varchar(255);comment:添加人真实名称;" description:"添加人真实名称"`
 	ModifyTime      time.Time `gorm:"column:modify_time;type:datetime;default:NULL;" description:"modify_time"`
@@ -30,6 +33,9 @@ var QuestionColumns = struct {
 	QuestionTitle   string
 	QuestionContent string
 	Sort            string
+	Version         string
+	GenerateStatus  string
+	IsDefault       string
 	ModifyTime      string
 	CreateTime      string
 }{
@@ -37,6 +43,9 @@ var QuestionColumns = struct {
 	QuestionTitle:   "question_title",
 	QuestionContent: "question_content",
 	Sort:            "sort",
+	Version:         "version",
+	GenerateStatus:  "generate_status",
+	IsDefault:       "is_default",
 	ModifyTime:      "modify_time",
 	CreateTime:      "create_time",
 }
@@ -64,6 +73,9 @@ type QuestionView struct {
 	QuestionTitle   string `gorm:"column:question_title;type:varchar(255);comment:问题标题;" description:"问题标题"`
 	QuestionContent string `gorm:"column:question_content;type:varchar(255);comment:问题内容;" description:"问题内容"`
 	Sort            int    `gorm:"column:sort;type:int(11);comment:排序;default:0;" description:"排序"`
+	Version         string `gorm:"column:version;type:varchar(4);comment:问题版本;default:NULL;" description:"问题版本"`
+	GenerateStatus  string `gorm:"column:generate_status;type:enum('undo', 'done');comment:生成摘要状态;default:NULL;" description:"生成摘要状态"`
+	IsDefault       int    `gorm:"column:is_default;type:int(1);comment:是否默认提示词;default:NULL;" description:"是否默认提示词"`
 	SysUserId       int    `gorm:"column:sys_user_id;type:int(11);comment:添加人id;default:0;" description:"添加人id"`
 	SysUserRealName string `gorm:"column:sys_user_real_name;type:varchar(255);comment:添加人真实名称;" description:"添加人真实名称"`
 	ModifyTime      string `gorm:"column:modify_time;type:datetime;default:NULL;" description:"modify_time"`
@@ -84,6 +96,9 @@ func (m *Question) ToView() QuestionView {
 		QuestionTitle:   m.QuestionTitle,
 		QuestionContent: m.QuestionContent,
 		Sort:            m.Sort,
+		Version:         m.Version,
+		GenerateStatus:  m.GenerateStatus,
+		IsDefault:       m.IsDefault,
 		SysUserId:       m.SysUserId,
 		SysUserRealName: m.SysUserRealName,
 		ModifyTime:      modifyTime,