Просмотр исходного кода

feat:模型训练、标的运行新增任务调度

Roc 6 дней назад
Родитель
Сommit
05955afd1d

+ 8 - 5
cache/index_task.go

@@ -5,8 +5,9 @@ import (
 	"fmt"
 )
 
-type AiTaskRecordOp struct {
-	AiTaskRecordId int
+type IndexTaskRecordOp struct {
+	IndexTaskRecordId int
+	TaskType          string
 }
 
 // AddIndexTaskRecordOpToCache
@@ -15,9 +16,11 @@ type AiTaskRecordOp struct {
 // @datetime 2025-04-24 09:41:11
 // @param aiTaskRecordId int
 // @return bool
-func AddIndexTaskRecordOpToCache(aiTaskRecordId int) bool {
-	record := new(AiTaskRecordOp)
-	record.AiTaskRecordId = aiTaskRecordId
+func AddIndexTaskRecordOpToCache(aiTaskRecordId int, taskType string) bool {
+	record := new(IndexTaskRecordOp)
+	record.IndexTaskRecordId = aiTaskRecordId
+	record.TaskType = taskType
+
 	if utils.Re == nil {
 		err := utils.Rc.LPush(utils.CACHE_INDEX_TASK, record)
 

+ 16 - 0
controllers/data_manage/ai_predict_model/index.go

@@ -1380,6 +1380,22 @@ func (this *AiPredictModelIndexController) Run() {
 		indexIdList = append(indexIdList, v.AiPredictModelIndexId)
 	}
 
+	if len(indexIdList) <= 0 {
+		br.Msg = "没有找到可以运行的标的"
+		br.IsSendEmail = false
+		return
+	}
+
+	err = indexOb.UpdateRunStatusByIdList(indexIdList)
+	if err != nil {
+		br.Msg = "运行失败"
+		br.ErrMsg = fmt.Sprintf("运行失败, %v", e)
+		return
+	}
+
+	// 加入模型运行任务中
+	go services.AddAiModelRunTask(indexIdList, this.SysUser)
+
 	br.Data = indexIdList
 	br.Ret = 200
 	br.Success = true

+ 36 - 28
controllers/data_manage/ai_predict_model/index_config.go

@@ -4,7 +4,7 @@ import (
 	"encoding/json"
 	"eta/eta_api/controllers"
 	"eta/eta_api/models"
-	data_manage "eta/eta_api/models/ai_predict_model"
+	aiPredictModel "eta/eta_api/models/ai_predict_model"
 	"eta/eta_api/models/ai_predict_model/request"
 	"eta/eta_api/models/ai_predict_model/response"
 	"eta/eta_api/services"
@@ -25,7 +25,7 @@ type AiPredictModelIndexConfigController struct {
 // @Param   PageSize   query   int  true       "每页数据条数"
 // @Param   CurrentIndex   query   int  true       "当前页页码,从1开始"
 // @Param   IndexId   query   int  true       "标的id"
-// @Success 200 {object} []*data_manage.AiPredictModelIndexConfigView
+// @Success 200 {object} []*response.AiPredictModelIndexConfigListResp
 // @router /index_config/list [get]
 func (c *AiPredictModelIndexConfigController) List() {
 	br := new(models.BaseResponse).Init()
@@ -60,15 +60,15 @@ func (c *AiPredictModelIndexConfigController) List() {
 	startSize = utils.StartIndex(currentIndex, pageSize)
 
 	var total int
-	viewList := make([]data_manage.AiPredictModelIndexConfigView, 0)
+	viewList := make([]aiPredictModel.AiPredictModelIndexConfigView, 0)
 
 	var condition string
 	var pars []interface{}
 
-	condition += fmt.Sprintf(` AND %s = ? `, data_manage.AiPredictModelIndexConfigColumns.AiPredictModelIndexId)
+	condition += fmt.Sprintf(` AND %s = ? `, aiPredictModel.AiPredictModelIndexConfigColumns.AiPredictModelIndexId)
 	pars = append(pars, indexId)
 
-	obj := new(data_manage.AiPredictModelIndexConfig)
+	obj := new(aiPredictModel.AiPredictModelIndexConfig)
 	tmpTotal, list, err := obj.GetPageListByCondition(condition, pars, startSize, pageSize)
 	if err != nil {
 		br.Msg = "获取失败"
@@ -120,7 +120,7 @@ func (c *AiPredictModelIndexConfigController) CurrVersion() {
 	}
 
 	// 查询标的情况
-	indexOb := new(data_manage.AiPredictModelIndex)
+	indexOb := new(aiPredictModel.AiPredictModelIndex)
 	indexItem, e := indexOb.GetItemById(indexId)
 	if e != nil {
 		if utils.IsErrNoRow(e) {
@@ -138,7 +138,7 @@ func (c *AiPredictModelIndexConfigController) CurrVersion() {
 		return
 	}
 
-	obj := new(data_manage.AiPredictModelIndexConfig)
+	obj := new(aiPredictModel.AiPredictModelIndexConfig)
 	configItem, err := obj.GetById(indexItem.AiPredictModelIndexConfigId)
 	if err != nil {
 		br.Msg = "获取失败"
@@ -176,7 +176,7 @@ func (c *AiPredictModelIndexConfigController) SetCurr() {
 	var req request.DelConfigReq
 	err := json.Unmarshal(c.Ctx.Input.RequestBody, &req)
 	// 查找配置
-	obj := new(data_manage.AiPredictModelIndexConfig)
+	obj := new(aiPredictModel.AiPredictModelIndexConfig)
 	configItem, err := obj.GetById(req.AiPredictModelIndexConfigId)
 	if err != nil {
 		br.Msg = "修改失败"
@@ -189,7 +189,7 @@ func (c *AiPredictModelIndexConfigController) SetCurr() {
 	}
 
 	// 查询标的情况
-	indexOb := new(data_manage.AiPredictModelIndex)
+	indexOb := new(aiPredictModel.AiPredictModelIndex)
 	indexItem, e := indexOb.GetItemById(configItem.AiPredictModelIndexId)
 	if e != nil {
 		br.Msg = "操作失败"
@@ -237,7 +237,7 @@ func (c *AiPredictModelIndexConfigController) Del() {
 	}
 
 	// 查找配置
-	obj := new(data_manage.AiPredictModelIndexConfig)
+	obj := new(aiPredictModel.AiPredictModelIndexConfig)
 	item, err := obj.GetById(req.AiPredictModelIndexConfigId)
 	if err != nil {
 		br.Msg = "修改失败"
@@ -252,7 +252,7 @@ func (c *AiPredictModelIndexConfigController) Del() {
 	// 查找是否被标的引用为默认模型
 	{
 		// 查询标的情况
-		indexOb := new(data_manage.AiPredictModelIndex)
+		indexOb := new(aiPredictModel.AiPredictModelIndex)
 		count, e := indexOb.GetCountByCondition(fmt.Sprintf(` AND %s = ? `, indexOb.Cols().AiPredictModelIndexConfigId), []interface{}{item.AiPredictModelIndexConfigId})
 		if e != nil {
 			br.Msg = "删除失败"
@@ -267,7 +267,7 @@ func (c *AiPredictModelIndexConfigController) Del() {
 		}
 	}
 
-	if !utils.InArrayByStr([]string{data_manage.TrainStatusSuccess, data_manage.TrainStatusFailed}, item.TrainStatus) {
+	if !utils.InArrayByStr([]string{aiPredictModel.TrainStatusSuccess, aiPredictModel.TrainStatusFailed}, item.TrainStatus) {
 		br.Msg = "删除失败,该版本配置正在训练中"
 		br.IsSendEmail = false
 		return
@@ -291,7 +291,7 @@ func (c *AiPredictModelIndexConfigController) Del() {
 // @Title 获取当前版本的图表信息
 // @Description 获取当前版本的图表信息
 // @Param   AiPredictModelIndexConfigId   query   int   true   "标的配置ID"
-// @Success 200 {object} []*data_manage.AiPredictModelIndexConfigView
+// @Success 200 {object} []*response.AiPredictModelDetailResp
 // @router /index_config/chart/detail [get]
 func (c *AiPredictModelIndexConfigController) ChartDetail() {
 	br := new(models.BaseResponse).Init()
@@ -318,7 +318,7 @@ func (c *AiPredictModelIndexConfigController) ChartDetail() {
 	// TODO 后面加上数据缓存
 
 	// 查找配置
-	obj := new(data_manage.AiPredictModelIndexConfig)
+	obj := new(aiPredictModel.AiPredictModelIndexConfig)
 	configItem, err := obj.GetById(indexConfigId)
 	if err != nil {
 		br.Msg = "修改失败"
@@ -332,7 +332,7 @@ func (c *AiPredictModelIndexConfigController) ChartDetail() {
 
 	// 查找是否被标的引用为默认模型
 	// 查询标的情况
-	indexOb := new(data_manage.AiPredictModelIndex)
+	indexOb := new(aiPredictModel.AiPredictModelIndex)
 	indexItem, e := indexOb.GetItemByConfigId(configItem.AiPredictModelIndexConfigId)
 	if e != nil {
 		br.Msg = "获取失败"
@@ -341,13 +341,13 @@ func (c *AiPredictModelIndexConfigController) ChartDetail() {
 	}
 
 	// 获取标的数据
-	dailyData := make([]*data_manage.AiPredictModelIndexConfigTrainData, 0)
+	dailyData := make([]*aiPredictModel.AiPredictModelIndexConfigTrainData, 0)
 	{
-		dataOb := new(data_manage.AiPredictModelIndexConfigTrainData)
-		dataCond := fmt.Sprintf(` AND %s = ?`, data_manage.AiPredictModelIndexConfigTrainDataColumns.AiPredictModelIndexConfigId)
+		dataOb := new(aiPredictModel.AiPredictModelIndexConfigTrainData)
+		dataCond := fmt.Sprintf(` AND %s = ?`, aiPredictModel.AiPredictModelIndexConfigTrainDataColumns.AiPredictModelIndexConfigId)
 		dataPars := make([]interface{}, 0)
 		dataPars = append(dataPars, configItem.AiPredictModelIndexConfigId)
-		list, e := dataOb.GetAllListByCondition(dataCond, dataPars, []string{}, fmt.Sprintf("%s DESC", data_manage.AiPredictModelIndexConfigTrainDataColumns.DataTime))
+		list, e := dataOb.GetAllListByCondition(dataCond, dataPars, []string{}, fmt.Sprintf("%s DESC", aiPredictModel.AiPredictModelIndexConfigTrainDataColumns.DataTime))
 		if e != nil {
 			br.Msg = "获取失败"
 			br.ErrMsg = fmt.Sprintf("获取标的数据失败, %v", e)
@@ -356,7 +356,7 @@ func (c *AiPredictModelIndexConfigController) ChartDetail() {
 
 		for _, v := range list {
 			// 日度数据
-			if v.Source == data_manage.ModelDataSourceDaily {
+			if v.Source == aiPredictModel.ModelDataSourceDaily {
 				dailyData = append(dailyData, v)
 				continue
 			}
@@ -365,7 +365,7 @@ func (c *AiPredictModelIndexConfigController) ChartDetail() {
 
 	// 日度图表
 	if len(dailyData) > 0 {
-		dailyChartDetail, e := services.GetAiPredictConfigChartDetailByData(indexItem.IndexName, configItem, dailyData, data_manage.ModelDataSourceDaily)
+		dailyChartDetail, e := services.GetAiPredictConfigChartDetailByData(indexItem.IndexName, configItem, dailyData, aiPredictModel.ModelDataSourceDaily)
 		if e != nil {
 			br.Msg = "获取失败"
 			br.ErrMsg = fmt.Sprintf("获取日度图表失败, %v", e)
@@ -413,7 +413,7 @@ func (c *AiPredictModelIndexConfigController) Train() {
 	}
 
 	// 查询标的情况
-	indexOb := new(data_manage.AiPredictModelIndex)
+	indexOb := new(aiPredictModel.AiPredictModelIndex)
 	indexItem, err := indexOb.GetItemById(req.IndexId)
 	if err != nil {
 		br.Msg = "训练失败,查找标的失败"
@@ -430,7 +430,7 @@ func (c *AiPredictModelIndexConfigController) Train() {
 		return
 	}
 
-	obj := new(data_manage.AiPredictModelIndexConfig)
+	obj := new(aiPredictModel.AiPredictModelIndexConfig)
 
 	// 查找当前标的是否存在待训练/训练中的模型
 	count, err := services.GetCurrentRunningAiPredictModelIndexCount()
@@ -445,6 +445,8 @@ func (c *AiPredictModelIndexConfigController) Train() {
 		return
 	}
 
+	var indexConfig *aiPredictModel.AiPredictModelIndexConfig
+
 	if req.AiPredictModelIndexConfigId > 0 {
 		// 查找配置
 		item, err := obj.GetById(req.AiPredictModelIndexConfigId)
@@ -464,7 +466,7 @@ func (c *AiPredictModelIndexConfigController) Train() {
 			return
 		}
 
-		if item.TrainStatus != data_manage.TrainStatusFailed {
+		if item.TrainStatus != aiPredictModel.TrainStatusFailed {
 			br.Msg = "该模型训练状态异常,不允许重新训练"
 			br.ErrMsg = "该模型训练状态异常,不允许重新训练,当前状态:" + item.TrainStatus
 			br.IsSendEmail = false
@@ -479,12 +481,14 @@ func (c *AiPredictModelIndexConfigController) Train() {
 			return
 		}
 
+		indexConfig = item
+
 	} else {
 		// 新增训练模型
-		item := &data_manage.AiPredictModelIndexConfig{
+		item := &aiPredictModel.AiPredictModelIndexConfig{
 			AiPredictModelIndexConfigId: 0,
 			AiPredictModelIndexId:       indexItem.AiPredictModelIndexId,
-			TrainStatus:                 data_manage.TrainStatusWaiting,
+			TrainStatus:                 aiPredictModel.TrainStatusWaiting,
 			Params:                      string(paramsStrByte),
 			TrainMse:                    "",
 			TrainR2:                     "",
@@ -503,9 +507,11 @@ func (c *AiPredictModelIndexConfigController) Train() {
 			br.ErrMsg = "训练失败,Err:" + err.Error()
 			return
 		}
+
+		indexConfig = item
 	}
 
-	indexItem.TrainStatus = data_manage.TrainStatusWaiting
+	indexItem.TrainStatus = aiPredictModel.TrainStatusWaiting
 	indexItem.ModifyTime = time.Now()
 	err = indexItem.Update([]string{"train_status", "modify_time"})
 	if err != nil {
@@ -513,7 +519,9 @@ func (c *AiPredictModelIndexConfigController) Train() {
 		br.ErrMsg = "训练失败,Err:" + err.Error()
 		return
 	}
-	// TODO 加入训练任务中
+
+	// 加入模型训练任务中
+	go services.AddAiModelTrainTask(indexItem, indexConfig, c.SysUser)
 
 	br.Ret = 200
 	br.Success = true

+ 17 - 0
models/ai_predict_model/ai_predict_model_index.go

@@ -510,3 +510,20 @@ func (m *AiPredictModelIndex) GetSortMax() (sort int, err error) {
 	}
 	return
 }
+
+// UpdateRunStatusByIdList
+// @Description: 通过标的ID列表更新运行状态
+// @author: Roc
+// @receiver m
+// @datetime 2025-05-08 14:44:15
+// @param indexIdList []int
+// @return err error
+func (m *AiPredictModelIndex) UpdateRunStatusByIdList(indexIdList []int) (err error) {
+	if len(indexIdList) <= 0 {
+		return
+	}
+	sql := ` UPDATE ai_predict_model_index SET run_status = ? WHERE ai_predict_model_index_id in (?)`
+	err = global.DbMap[utils.DbNameIndex].Exec(sql, RunStatusWaiting, indexIdList).Error
+
+	return
+}

+ 1 - 1
models/data_manage/index_task.go

@@ -13,7 +13,7 @@ type IndexTask struct {
 	IndexTaskID     int       `gorm:"primaryKey;column:index_task_id" description:"-"`
 	TaskName        string    `gorm:"column:task_name" description:"任务名称"`
 	TaskType        string    `gorm:"column:task_type" description:"任务类型"`
-	Status          string    `gorm:"column:status" description:"任务状态"`
+	Status          string    `gorm:"column:status" description:"任务状态,枚举值:待处理,处理成功,处理失败,暂停处理"`
 	StartTime       time.Time `gorm:"column:start_time" description:"开始时间"`
 	EndTime         time.Time `gorm:"column:end_time" description:"结束时间"`
 	CreateTime      time.Time `gorm:"column:create_time" description:"创建时间"`

+ 153 - 0
services/ai_predict_model_index_config.go

@@ -0,0 +1,153 @@
+package services
+
+import (
+	"eta/eta_api/cache"
+	predictModel "eta/eta_api/models/ai_predict_model"
+	"eta/eta_api/models/data_manage"
+	"eta/eta_api/models/system"
+	"eta/eta_api/utils"
+	"fmt"
+	"time"
+)
+
+// AddAiModelTrainTask
+// @Description: 添加模型训练任务
+// @author: Roc
+// @datetime 2025-05-08 14:28:13
+// @param aiIndex *predictModel.AiPredictModelIndex
+// @param indexConfig *predictModel.AiPredictModelIndexConfig
+// @param sysUser *system.Admin
+func AddAiModelTrainTask(aiIndex *predictModel.AiPredictModelIndex, indexConfig *predictModel.AiPredictModelIndexConfig, sysUser *system.Admin) {
+	var err error
+	defer func() {
+		if err != nil {
+			utils.FileLog.Error(fmt.Sprintf("AddAiModelTrainTask error: %s", err.Error()))
+		}
+	}()
+	taskName := fmt.Sprintf("《%s》模型训练-%s", aiIndex.IndexName, time.Now().Format(utils.FormatShortDateTimeUnSpace))
+
+	indexTask := &data_manage.IndexTask{
+		IndexTaskID: 0,
+		TaskName:    taskName,
+		TaskType:    utils.INDEX_TASK_TYPE_AI_MODEL_TRAIN,
+		Status:      "待处理",
+		//StartTime:               time.Time{},
+		//EndTime:                 time.Time{},
+		CreateTime:      time.Now(),
+		UpdateTime:      time.Now(),
+		Logs:            "",
+		Errormessage:    "",
+		Priority:        0,
+		RetryCount:      0,
+		Remark:          "",
+		SysUserID:       sysUser.AdminId,
+		SysUserRealName: sysUser.RealName,
+	}
+
+	taskRecordList := make([]*data_manage.IndexTaskRecord, 0)
+
+	taskRecord := &data_manage.IndexTaskRecord{
+		IndexTaskRecordID: 0,
+		IndexTaskID:       0,
+		Parameters:        fmt.Sprint(indexConfig.AiPredictModelIndexConfigId),
+		Status:            "待处理",
+		Remark:            "",
+		ModifyTime:        time.Now(),
+		CreateTime:        time.Now(),
+	}
+	taskRecordList = append(taskRecordList, taskRecord)
+
+	// 创建AI模块的任务,用于后面的任务调度去生成摘要
+	err = data_manage.AddIndexTask(indexTask, taskRecordList)
+	if err != nil {
+		return
+	}
+
+	// 添加到缓存队列中
+	go addIndexTaskToCache(indexTask.IndexTaskID, indexTask.TaskType)
+
+	return
+}
+
+// AddAiModelRunTask
+// @Description: 添加模型运行任务
+// @author: Roc
+// @datetime 2025-05-08 14:33:38
+// @param aiModelIndexIdList []int 标的模型ID列表
+// @param sysUser *system.Admin
+func AddAiModelRunTask(aiModelIndexIdList []int, sysUser *system.Admin) {
+	var err error
+	defer func() {
+		if err != nil {
+			utils.FileLog.Error(fmt.Sprintf("AddAiModelTrainTask error: %s", err.Error()))
+		}
+	}()
+	taskName := fmt.Sprintf("模型运行-%s", time.Now().Format(utils.FormatShortDateTimeUnSpace))
+
+	indexTask := &data_manage.IndexTask{
+		IndexTaskID: 0,
+		TaskName:    taskName,
+		TaskType:    utils.INDEX_TASK_TYPE_AI_MODEL_RUN,
+		Status:      "待处理",
+		//StartTime:               time.Time{},
+		//EndTime:                 time.Time{},
+		CreateTime:      time.Now(),
+		UpdateTime:      time.Now(),
+		Logs:            "",
+		Errormessage:    "",
+		Priority:        0,
+		RetryCount:      0,
+		Remark:          "",
+		SysUserID:       sysUser.AdminId,
+		SysUserRealName: sysUser.RealName,
+	}
+
+	taskRecordList := make([]*data_manage.IndexTaskRecord, 0)
+
+	for _, aiModelIndexId := range aiModelIndexIdList {
+		taskRecord := &data_manage.IndexTaskRecord{
+			IndexTaskRecordID: 0,
+			IndexTaskID:       0,
+			Parameters:        fmt.Sprint(aiModelIndexId),
+			Status:            "待处理",
+			Remark:            "",
+			ModifyTime:        time.Now(),
+			CreateTime:        time.Now(),
+		}
+		taskRecordList = append(taskRecordList, taskRecord)
+	}
+
+	// 创建AI模块的任务,用于后面的任务调度去生成摘要
+	err = data_manage.AddIndexTask(indexTask, taskRecordList)
+	if err != nil {
+		return
+	}
+
+	// 添加到缓存队列中
+	go addIndexTaskToCache(indexTask.IndexTaskID, indexTask.TaskType)
+
+	return
+}
+
+// addIndexTaskToCache
+// @Description: 根据任务ID将待处理的任务丢入到list中
+// @author: Roc
+// @datetime 2025-05-08 14:30:05
+// @param indexTaskId int
+// @param taskType string
+func addIndexTaskToCache(indexTaskId int, taskType string) {
+	var err error
+	defer func() {
+		if err != nil {
+			utils.FileLog.Error("addTaskToCache error: %v", err)
+		}
+	}()
+	obj := data_manage.IndexTaskRecord{}
+	list, err := obj.GetAllListByCondition(" index_task_record_id ", ` AND index_task_id = ? AND status = ? `, []interface{}{indexTaskId, `待处理`})
+	if err != nil {
+		return
+	}
+	for _, item := range list {
+		cache.AddIndexTaskRecordOpToCache(item.IndexTaskRecordID, taskType)
+	}
+}

+ 5 - 0
utils/constants.go

@@ -604,3 +604,8 @@ const (
 const (
 	DATA_SOURCE_NAME_RADISH_RESEARCH = "萝卜投研" // 萝卜投研 -> 105
 )
+
+const (
+	INDEX_TASK_TYPE_AI_MODEL_TRAIN = `ai_predict_model_train`
+	INDEX_TASK_TYPE_AI_MODEL_RUN   = `ai_predict_model_run`
+)