Roc 1 天之前
父节点
当前提交
4c35bba065

+ 183 - 12
controllers/ai_predict_model/index.go

@@ -7,6 +7,7 @@ import (
 	aiPredictModelLogic "eta/eta_index_lib/logic/ai_predict_model"
 	"eta/eta_index_lib/models"
 	"eta/eta_index_lib/models/ai_predict_model"
+	"eta/eta_index_lib/models/ai_predict_model/request"
 	"eta/eta_index_lib/models/ai_predict_model/response"
 	"eta/eta_index_lib/utils"
 	"fmt"
@@ -24,17 +25,12 @@ type IndexTaskRecordOp struct {
 	TaskType          string
 }
 
-// List
-// @Title 标的列表
-// @Description 标的列表
-// @Param   PageSize   query   int   true   "每页数据条数"
-// @Param   CurrentIndex   query   int   true   "当前页页码,从1开始"
-// @Param   ClassifyId   query   int   false   "分类id"
-// @Param   IndexId   query   int   false   "模型标的ID"
-// @Param   Keyword   query   string   false   "搜索关键词"
-// @Success 200 {object} data_manage.ChartListResp
+// OpToDo
+// @Title 获取待操作的标的
+// @Description 获取待操作的标的
+// @Success 200 {object} response.AiPredictModelIndexConfigResp
 // @router /op_todo [post]
-func (this *AiPredictModelIndexController) List() {
+func (this *AiPredictModelIndexController) OpToDo() {
 	br := new(models.BaseResponse).Init()
 	defer func() {
 		this.Data["json"] = br
@@ -72,8 +68,8 @@ func (this *AiPredictModelIndexController) List() {
 		return
 	}
 	if indexTaskRecordInfo.Status != `待处理` {
-		fmt.Println("任务状态不是待运行!")
-		br.Msg = "任务状态不是待运行"
+		fmt.Println("任务状态不是待处理!")
+		br.Msg = "任务状态不是待处理"
 		return
 	}
 
@@ -188,9 +184,184 @@ func (this *AiPredictModelIndexController) List() {
 	resp.AiPredictModelIndexId = indexConfigItem.AiPredictModelIndexId
 	resp.AiPredictModelIndexConfigId = indexConfigItem.AiPredictModelIndexConfigId
 	resp.ConfigParams = configParams
+	resp.ExecType = indexTaskRecordOp.TaskType
+	resp.ScriptPath = indexItem.ScriptPath
 
 	br.Data = resp
 	br.Ret = 200
 	br.Success = true
 	br.Msg = "获取成功"
 }
+
+// HandleTaskRecordFailByTaskRecord
+// @Title 任务操作失败
+// @Description 任务操作失败
+// @Success 200 {object} response.HandleTaskRecordFailReq
+// @router /handle/fail [post]
+func (this *AiPredictModelIndexController) HandleTaskRecordFailByTaskRecord() {
+	br := new(models.BaseResponse).Init()
+	defer func() {
+		this.Data["json"] = br
+		this.ServeJSON()
+	}()
+
+	var req request.HandleTaskRecordFailReq
+	err := json.Unmarshal(this.Ctx.Input.RequestBody, &req)
+	if err != nil {
+		br.Msg = "参数解析异常!"
+		br.ErrMsg = "参数解析失败,Err:" + err.Error()
+		return
+	}
+
+	indexTaskRecordObj := new(models.IndexTaskRecord)
+	indexTaskRecordInfo, err := indexTaskRecordObj.GetByID(req.IndexTaskRecordId)
+	if err != nil {
+		fmt.Println("get index task record info wrong!")
+		br.Msg = "获取失败"
+		return
+	}
+	if indexTaskRecordInfo.Status != `处理中` {
+		fmt.Println("任务状态不是处理中!")
+		br.Msg = "任务状态不是处理中"
+		return
+	}
+
+	indexTaskObj := models.IndexTask{}
+	indexTaskInfo, tmpErr := indexTaskObj.GetByID(indexTaskRecordInfo.IndexTaskID)
+	if tmpErr != nil {
+		err = fmt.Errorf("查找任务失败, err: %s", tmpErr.Error())
+		return
+	}
+
+	var indexConfigItem *ai_predict_model.AiPredictModelIndexConfig
+	var indexItem *ai_predict_model.AiPredictModelIndex
+	indexConfigObj := new(ai_predict_model.AiPredictModelIndexConfig)
+	indexOb := new(ai_predict_model.AiPredictModelIndex)
+
+	switch indexTaskInfo.TaskType {
+	case utils.INDEX_TASK_TYPE_AI_MODEL_TRAIN:
+		// 训练模型
+		indexConfigId, err := strconv.Atoi(indexTaskRecordInfo.Parameters) // 模型配置ID
+		if err != nil {
+			fmt.Println("模型配置ID转换错误!")
+			br.Msg = "模型配置ID转换错误"
+			br.ErrMsg = "模型配置ID转换错误,err:" + err.Error()
+			return
+		}
+
+		// 查找配置
+		indexConfigItem, err = indexConfigObj.GetById(indexConfigId)
+		if err != nil {
+			br.Msg = "获取模型配置失败"
+			br.ErrMsg = "获取模型配置失败,查找配置失败,Err:" + err.Error()
+			if utils.IsErrNoRow(err) {
+				br.Msg = "配置不存在"
+				br.IsSendEmail = false
+			}
+			return
+		}
+
+		// 查询标的情况
+		indexItem, err = indexOb.GetItemById(indexConfigItem.AiPredictModelIndexId)
+		if err != nil {
+			br.Msg = "训练失败,查找标的失败"
+			br.ErrMsg = fmt.Sprintf("训练失败,查找标的失败, %v", err)
+			if utils.IsErrNoRow(err) {
+				br.Msg = "标的不存在"
+				br.IsSendEmail = false
+			}
+			return
+		}
+
+	case utils.INDEX_TASK_TYPE_AI_MODEL_RUN:
+		// 运行模型
+
+		// 标的id转换
+		indexId, err := strconv.Atoi(indexTaskRecordInfo.Parameters)
+		if err != nil {
+			fmt.Println("标的ID转换错误!")
+			br.Msg = "标的ID转换错误"
+			br.ErrMsg = "标的ID转换错误,err:" + err.Error()
+			return
+		}
+
+		// 查询标的情况
+		indexItem, err = indexOb.GetItemById(indexId)
+		if err != nil {
+			br.Msg = "训练失败,查找标的失败"
+			br.ErrMsg = fmt.Sprintf("训练失败,查找标的失败, %v", err)
+			if utils.IsErrNoRow(err) {
+				br.Msg = "标的不存在"
+				br.IsSendEmail = false
+			}
+			return
+		}
+
+		// 查找配置
+		indexConfigItem, err = indexConfigObj.GetById(indexItem.AiPredictModelIndexConfigId)
+		if err != nil {
+			br.Msg = "获取模型配置失败"
+			br.ErrMsg = "获取模型配置失败,查找配置失败,Err:" + err.Error()
+			if utils.IsErrNoRow(err) {
+				br.Msg = "配置不存在"
+				br.IsSendEmail = false
+			}
+			return
+		}
+	}
+
+	// 标记处理任务失败
+	aiPredictModelLogic.HandleTaskRecordFailByTaskRecord(indexTaskInfo.TaskType, indexTaskRecordInfo, indexConfigItem, indexItem, req.FailMsg)
+
+	br.Ret = 200
+	br.Success = true
+	br.Msg = "处理成功"
+}
+
+// HandleTaskRecordSuccessByTaskRecord
+// @Title 任务操作成功
+// @Description 任务操作成功
+// @Success 200 {object} response.HandleTaskRecordFailReq
+// @router /handle/success [post]
+func (this *AiPredictModelIndexController) HandleTaskRecordSuccessByTaskRecord() {
+	br := new(models.BaseResponse).Init()
+	defer func() {
+		this.Data["json"] = br
+		this.ServeJSON()
+	}()
+
+	var req request.HandleTaskRecordSuccessReq
+	err := json.Unmarshal(this.Ctx.Input.RequestBody, &req)
+	if err != nil {
+		br.Msg = "参数解析异常!"
+		br.ErrMsg = "参数解析失败,Err:" + err.Error()
+		return
+	}
+
+	indexTaskRecordObj := new(models.IndexTaskRecord)
+	indexTaskRecordInfo, err := indexTaskRecordObj.GetByID(req.IndexTaskRecordId)
+	if err != nil {
+		fmt.Println("get index task record info wrong!")
+		br.Msg = "获取失败"
+		return
+	}
+	if indexTaskRecordInfo.Status != `处理中` {
+		fmt.Println("任务状态不是处理中!")
+		br.Msg = "任务状态不是处理中"
+		return
+	}
+
+	indexTaskObj := models.IndexTask{}
+	indexTaskInfo, tmpErr := indexTaskObj.GetByID(indexTaskRecordInfo.IndexTaskID)
+	if tmpErr != nil {
+		err = fmt.Errorf("查找任务失败, err: %s", tmpErr.Error())
+		return
+	}
+
+	// 标记处理任务失败
+	aiPredictModelLogic.HandleTaskRecordSuccessByTaskRecord(indexTaskInfo.TaskType, indexTaskRecordInfo, req.Data)
+
+	br.Ret = 200
+	br.Success = true
+	br.Msg = "处理成功"
+}

+ 248 - 9
logic/ai_predict_model/index.go

@@ -1,10 +1,13 @@
 package ai_predict_model
 
 import (
+	"encoding/json"
 	"eta/eta_index_lib/models"
-	"eta/eta_index_lib/models/ai_predict_model"
+	aiPredictModel "eta/eta_index_lib/models/ai_predict_model"
+	"eta/eta_index_lib/models/ai_predict_model/request"
 	"eta/eta_index_lib/utils"
 	"fmt"
+	"strconv"
 	"time"
 )
 
@@ -17,7 +20,7 @@ import (
 // @param indexConfigItem *ai_predict_model.AiPredictModelIndexConfig
 // @param indexItem *ai_predict_model.AiPredictModelIndex
 // @param errMsg string
-func HandleTaskRecordFailByTaskRecord(taskType string, indexTaskRecordInfo *models.IndexTaskRecord, indexConfigItem *ai_predict_model.AiPredictModelIndexConfig, indexItem *ai_predict_model.AiPredictModelIndex, errMsg string) {
+func HandleTaskRecordFailByTaskRecord(taskType string, indexTaskRecordInfo *models.IndexTaskRecord, indexConfigItem *aiPredictModel.AiPredictModelIndexConfig, indexItem *aiPredictModel.AiPredictModelIndex, errMsg string) {
 	var err error
 	defer func() {
 		if err != nil {
@@ -75,7 +78,7 @@ func HandleTaskRecordFailByTaskRecord(taskType string, indexTaskRecordInfo *mode
 	case utils.INDEX_TASK_TYPE_AI_MODEL_TRAIN: // 训练模型
 		// 修改模型状态信息
 		if indexItem != nil {
-			indexItem.TrainStatus = ai_predict_model.TrainStatusFailed
+			indexItem.TrainStatus = aiPredictModel.TrainStatusFailed
 			indexItem.ModifyTime = time.Now()
 			tmpErr := indexItem.Update([]string{"train_status", "modify_time"})
 			if tmpErr != nil {
@@ -85,7 +88,7 @@ func HandleTaskRecordFailByTaskRecord(taskType string, indexTaskRecordInfo *mode
 
 		// 修改模型配置状态信息
 		if indexConfigItem != nil {
-			indexConfigItem.TrainStatus = ai_predict_model.TrainStatusFailed
+			indexConfigItem.TrainStatus = aiPredictModel.TrainStatusFailed
 			indexConfigItem.Remark = errMsg
 			indexConfigItem.ModifyTime = time.Now()
 			tmpErr := indexConfigItem.Update([]string{"train_status", `remark`, "modify_time"})
@@ -96,7 +99,7 @@ func HandleTaskRecordFailByTaskRecord(taskType string, indexTaskRecordInfo *mode
 
 	case utils.INDEX_TASK_TYPE_AI_MODEL_RUN: // 运行模型
 		if indexItem != nil {
-			indexItem.RunStatus = ai_predict_model.RunStatusFailed
+			indexItem.RunStatus = aiPredictModel.RunStatusFailed
 			indexItem.ModifyTime = time.Now()
 			tmpErr := indexItem.Update([]string{"run_status", "modify_time"})
 			if tmpErr != nil {
@@ -120,7 +123,7 @@ func HandleTaskRecordFailByTaskRecord(taskType string, indexTaskRecordInfo *mode
 // @param indexTaskRecordInfo *models.IndexTaskRecord
 // @param indexConfigItem *ai_predict_model.AiPredictModelIndexConfig
 // @param indexItem *ai_predict_model.AiPredictModelIndex
-func HandleTaskRecordProcessingByTaskRecord(taskType string, indexTaskRecordInfo *models.IndexTaskRecord, indexConfigItem *ai_predict_model.AiPredictModelIndexConfig, indexItem *ai_predict_model.AiPredictModelIndex) {
+func HandleTaskRecordProcessingByTaskRecord(taskType string, indexTaskRecordInfo *models.IndexTaskRecord, indexConfigItem *aiPredictModel.AiPredictModelIndexConfig, indexItem *aiPredictModel.AiPredictModelIndex) {
 	var err error
 	defer func() {
 		if err != nil {
@@ -177,7 +180,7 @@ func HandleTaskRecordProcessingByTaskRecord(taskType string, indexTaskRecordInfo
 	case utils.INDEX_TASK_TYPE_AI_MODEL_TRAIN: // 训练模型
 		// 修改模型状态信息
 		if indexItem != nil {
-			indexItem.TrainStatus = ai_predict_model.TrainStatusTraining
+			indexItem.TrainStatus = aiPredictModel.TrainStatusTraining
 			indexItem.ModifyTime = time.Now()
 			tmpErr := indexItem.Update([]string{"train_status", "modify_time"})
 			if tmpErr != nil {
@@ -187,7 +190,7 @@ func HandleTaskRecordProcessingByTaskRecord(taskType string, indexTaskRecordInfo
 
 		// 修改模型配置状态信息
 		if indexConfigItem != nil {
-			indexConfigItem.TrainStatus = ai_predict_model.TrainStatusTraining
+			indexConfigItem.TrainStatus = aiPredictModel.TrainStatusTraining
 			indexConfigItem.ModifyTime = time.Now()
 			tmpErr := indexConfigItem.Update([]string{"train_status", "modify_time"})
 			if tmpErr != nil {
@@ -198,7 +201,7 @@ func HandleTaskRecordProcessingByTaskRecord(taskType string, indexTaskRecordInfo
 	case utils.INDEX_TASK_TYPE_AI_MODEL_RUN: // 运行模型
 		// 修改模型状态信息
 		if indexItem != nil {
-			indexItem.RunStatus = ai_predict_model.RunStatusRunning
+			indexItem.RunStatus = aiPredictModel.RunStatusRunning
 			indexItem.ModifyTime = time.Now()
 			tmpErr := indexItem.Update([]string{"run_status", "modify_time"})
 			if tmpErr != nil {
@@ -213,3 +216,239 @@ func HandleTaskRecordProcessingByTaskRecord(taskType string, indexTaskRecordInfo
 
 	return
 }
+
+// HandleTaskRecordSuccessByTaskRecord
+// @Description: 标记处理完成
+// @author: Roc
+// @datetime 2025-05-14 16:00:26
+// @param taskType string
+// @param indexTaskRecordInfo *models.IndexTaskRecord
+// @param aiPredictModelImportData request.AiPredictModelImportData
+func HandleTaskRecordSuccessByTaskRecord(taskType string, indexTaskRecordInfo *models.IndexTaskRecord, aiPredictModelImportData request.AiPredictModelImportData) {
+	var err error
+	defer func() {
+		if err != nil {
+			utils.FileLog.Error(fmt.Sprintf(`HandleTaskRecordFailByTaskRecord err:%v`, err))
+		}
+	}()
+
+	// 修改子任务状态
+	indexTaskRecordInfo.Status = `处理成功`
+	//indexTaskRecordInfo.Remark = errMsg
+	indexTaskRecordInfo.ModifyTime = time.Now()
+	err = indexTaskRecordInfo.Update([]string{"status", "modify_time"})
+	if err != nil {
+		fmt.Println("修改子任务状态失败!")
+		return
+	}
+
+	// 处理完成后标记任务状态
+	defer func() {
+		obj := models.IndexTaskRecord{}
+		// 修改任务状态
+		todoCount, tmpErr := obj.GetCountByCondition(fmt.Sprintf(` AND %s = ? AND %s = ? `, models.IndexTaskRecordColumns.IndexTaskID, models.IndexTaskRecordColumns.Status), []interface{}{indexTaskRecordInfo.IndexTaskID, `待处理`})
+		if tmpErr != nil {
+			err = fmt.Errorf("查找剩余任务数量失败, err: %s", tmpErr.Error())
+			return
+		}
+		if todoCount <= 0 {
+			indexTaskObj := models.IndexTask{}
+			indexTaskInfo, tmpErr := indexTaskObj.GetByID(indexTaskRecordInfo.IndexTaskID)
+			if tmpErr != nil {
+				err = fmt.Errorf("查找任务失败, err: %s", tmpErr.Error())
+				return
+			}
+			tmpUpdateCols := []string{`end_time`, "status", "update_time"}
+			indexTaskInfo.EndTime = time.Now()
+			indexTaskInfo.Status = `处理成功`
+			indexTaskInfo.UpdateTime = time.Now()
+
+			if indexTaskInfo.StartTime.IsZero() {
+				indexTaskInfo.StartTime = time.Now()
+				tmpUpdateCols = append(tmpUpdateCols, "start_time")
+			}
+
+			tmpErr = indexTaskInfo.Update(tmpUpdateCols)
+			if tmpErr != nil {
+				utils.FileLog.Error("标记任务状态失败, err: %s", tmpErr.Error())
+			}
+		}
+
+		return
+	}()
+
+	indexOb := new(aiPredictModel.AiPredictModelIndex)
+
+	// 修改模型状态
+	switch taskType {
+	case utils.INDEX_TASK_TYPE_AI_MODEL_TRAIN: // 训练模型
+		// 训练模型
+		indexConfigId, tmpErr := strconv.Atoi(indexTaskRecordInfo.Parameters) // 模型配置ID
+		if tmpErr != nil {
+			err = fmt.Errorf("模型配置ID转换错误, err: %s", tmpErr.Error())
+			return
+		}
+
+		indexConfigObj := new(aiPredictModel.AiPredictModelIndexConfig)
+		// 查找配置
+		indexConfigItem, tmpErr := indexConfigObj.GetById(indexConfigId)
+		if tmpErr != nil {
+			err = fmt.Errorf("获取模型配置失败, err: %s", tmpErr.Error())
+			return
+		}
+
+		// 查询标的情况
+		indexItem, tmpErr := indexOb.GetItemById(indexConfigItem.AiPredictModelIndexId)
+		if err != nil {
+			err = fmt.Errorf("获取标的失败, err: %s", tmpErr.Error())
+			return
+		}
+
+		handleTaskRecordSuccessByTrain(indexConfigItem, indexItem)
+
+	case utils.INDEX_TASK_TYPE_AI_MODEL_RUN: // 运行模型
+
+		// 标的id转换
+		indexId, tmpErr := strconv.Atoi(indexTaskRecordInfo.Parameters)
+		if err != nil {
+			err = fmt.Errorf("标的ID转换错误, err: %s", tmpErr.Error())
+			return
+		}
+
+		// 查询标的情况
+		indexItem, tmpErr := indexOb.GetItemById(indexId)
+		if tmpErr != nil {
+			err = fmt.Errorf("训练失败,查找标的失败, err: %s", tmpErr.Error())
+			return
+		}
+
+		tmpErr = handleTaskRecordSuccessByRun(aiPredictModelImportData, indexItem)
+		if tmpErr != nil {
+			utils.FileLog.Error("%d,修改模型运行状态失败, err: %s", indexItem.AiPredictModelIndexId, tmpErr.Error())
+		}
+
+	default:
+
+		return
+	}
+
+	return
+}
+
+func handleTaskRecordSuccessByTrain(indexConfigItem *aiPredictModel.AiPredictModelIndexConfig, indexItem *aiPredictModel.AiPredictModelIndex) {
+	// 修改模型状态信息
+	if indexItem != nil {
+		indexItem.TrainStatus = aiPredictModel.TrainStatusSuccess
+		indexItem.ModifyTime = time.Now()
+		tmpErr := indexItem.Update([]string{"train_status", "modify_time"})
+		if tmpErr != nil {
+			utils.FileLog.Error("%d,修改模型训练状态失败, err: %s", indexItem.AiPredictModelIndexId, tmpErr.Error())
+		}
+	}
+
+	// 修改模型配置状态信息
+	if indexConfigItem != nil {
+		indexConfigItem.TrainStatus = aiPredictModel.TrainStatusSuccess
+		indexConfigItem.Remark = `成功`
+		indexConfigItem.ModifyTime = time.Now()
+		tmpErr := indexConfigItem.Update([]string{"train_status", `remark`, "modify_time"})
+		if tmpErr != nil {
+			utils.FileLog.Error("%d,修改模型训练状态失败, err: %s", indexItem.AiPredictModelIndexId, tmpErr.Error())
+		}
+	}
+}
+
+// handleTaskRecordSuccessByRun
+// @Description: 运行中的数据处理
+// @author: Roc
+// @datetime 2025-05-14 14:28:11
+// @param aiPredictModelImportData request.AiPredictModelImportData
+// @param indexItem *aiPredictModel.AiPredictModelIndex
+// @return err error
+func handleTaskRecordSuccessByRun(aiPredictModelImportData request.AiPredictModelImportData, indexItem *aiPredictModel.AiPredictModelIndex) (err error) {
+	defer func() {
+		defer func() {
+			if err != nil {
+				utils.FileLog.Error(fmt.Sprintf(`handleTaskRecordSuccessByRun err:%v`, err))
+			}
+		}()
+	}()
+	// 查询已存在的标的
+	indexOb := new(aiPredictModel.AiPredictModelIndex)
+	indexNameItem := make(map[string]*aiPredictModel.AiPredictModelIndex)
+	{
+		list, e := indexOb.GetItemsByCondition("", make([]interface{}, 0), []string{}, "")
+		if e != nil {
+			err = fmt.Errorf("获取标的失败, %v", e)
+			return
+		}
+		for _, v := range list {
+			indexNameItem[v.IndexName] = v
+		}
+	}
+
+	updateCols := []string{indexOb.Cols().RunStatus, indexOb.Cols().PredictValue, indexOb.Cols().DirectionAccuracy, indexOb.Cols().AbsoluteDeviation, indexOb.Cols().ExtraConfig, indexOb.Cols().ModifyTime}
+
+	// 预测日期,理论上是需要改的,可是产品说不需要改,所以暂时不改
+	updateCols = append(updateCols, indexOb.Cols().PredictDate)
+	indexItem.RunStatus = aiPredictModel.RunStatusSuccess
+	indexItem.PredictValue = aiPredictModelImportData.Index.PredictValue
+	indexItem.DirectionAccuracy = aiPredictModelImportData.Index.DirectionAccuracy
+	indexItem.AbsoluteDeviation = aiPredictModelImportData.Index.AbsoluteDeviation
+	indexItem.ModifyTime = time.Now()
+
+	predictDate, e := time.ParseInLocation(utils.FormatDate, aiPredictModelImportData.Index.PredictDate, time.Local)
+	if e != nil {
+		err = fmt.Errorf("预测日期解析失败, %v", e)
+		return
+	}
+	indexItem.PredictDate = predictDate
+
+	// 图例信息
+	if indexItem.ExtraConfig != "" && aiPredictModelImportData.Index.ExtraConfig != "" {
+		var oldConfig, newConfig aiPredictModel.AiPredictModelIndexExtraConfig
+		if e := json.Unmarshal([]byte(indexItem.ExtraConfig), &oldConfig); e != nil {
+			err = fmt.Errorf("标的原配置解析失败, Config: %s, Err: %v", indexItem.ExtraConfig, e)
+			return
+		}
+		if e := json.Unmarshal([]byte(aiPredictModelImportData.Index.ExtraConfig), &newConfig); e != nil {
+			err = fmt.Errorf("标的新配置解析失败, Config: %s, Err: %v", aiPredictModelImportData.Index.ExtraConfig, e)
+			return
+		}
+		oldConfig.DailyChart.PredictLegendName = newConfig.DailyChart.PredictLegendName
+		b, _ := json.Marshal(oldConfig)
+		indexItem.ExtraConfig = string(b)
+	}
+
+	dataList := make([]*aiPredictModel.AiPredictModelData, 0)
+	for _, tmpData := range aiPredictModelImportData.Data {
+		tmpDate, e := time.ParseInLocation(utils.FormatDate, tmpData.DataTime, time.Local)
+		if e != nil {
+			err = fmt.Errorf("数据日期解析失败, %v", e)
+			return
+		}
+
+		dataList = append(dataList, &aiPredictModel.AiPredictModelData{
+			//AiPredictModelDataId:  0,
+			AiPredictModelIndexId: indexItem.AiPredictModelIndexId,
+			IndexCode:             indexItem.IndexCode,
+			DataTime:              tmpDate,
+			Value:                 tmpData.Value,
+			PredictValue:          tmpData.PredictValue,
+			Direction:             tmpData.Direction,
+			DeviationRate:         tmpData.DeviationRate,
+			CreateTime:            time.Now(),
+			ModifyTime:            time.Now(),
+			DataTimestamp:         tmpData.DataTimestamp,
+			Source:                tmpData.Source,
+		})
+	}
+
+	// 更新指标
+	err = indexOb.UpdateIndexAndData(indexItem, dataList, updateCols)
+	if err != nil {
+		return
+	}
+
+	return
+}

+ 7 - 0
models/ai_predict_model/ai_predict_model_data.go

@@ -121,6 +121,13 @@ func (m *AiPredictModelData) GetItemById(id int) (item *AiPredictModelData, err
 	return
 }
 
+func (m *AiPredictModelData) GetItemByModelIndexId(aiPredictModelIndexId int) (items []*AiPredictModelData, err error) {
+	o := global.DbMap[utils.DbNameIndex]
+	sqlRun := fmt.Sprintf(`SELECT * FROM %s WHERE %s = ? LIMIT 1`, m.TableName(), m.Cols().AiPredictModelIndexId)
+	err = o.Raw(sqlRun, aiPredictModelIndexId).Find(&items).Error
+	return
+}
+
 func (m *AiPredictModelData) GetItemByCondition(condition string, pars []interface{}, orderRule string) (item *AiPredictModelData, err error) {
 	o := global.DbMap[utils.DbNameIndex]
 	order := ``

+ 161 - 0
models/ai_predict_model/ai_predict_model_index.go

@@ -521,3 +521,164 @@ func (m *AiPredictModelIndex) UpdateRunStatusByIdList(indexIdList []int) (err er
 
 	return
 }
+
+// UpdateIndexAndData 导入数据
+func (m *AiPredictModelIndex) UpdateIndexAndData(modelIndexItem *AiPredictModelIndex, dataList []*AiPredictModelData, updateCols []string) (err error) {
+	o := global.DbMap[utils.DbNameIndex]
+	tx := o.Begin()
+	defer func() {
+		if err != nil {
+			_ = tx.Rollback()
+			return
+		}
+		_ = tx.Commit()
+	}()
+
+	// 更新指标
+	e := tx.Select(updateCols).Updates(modelIndexItem).Error
+	if e != nil {
+		err = fmt.Errorf("update index err: %v", e)
+		return
+	}
+
+	var existDataList []*AiPredictModelData
+	// 查询标的的所有数据
+	sqlRun := `SELECT * FROM ai_predict_model_data WHERE ai_predict_model_index_id = ? ORDER BY ai_predict_model_data_id DESC`
+	err = tx.Raw(sqlRun, modelIndexItem.AiPredictModelIndexId).Find(&existDataList).Error
+	if err != nil {
+		err = fmt.Errorf("find index data err: %v", e)
+		return
+	}
+	existDailyMap := make(map[string]*AiPredictModelData)
+	existMonthlyMap := make(map[string]*AiPredictModelData)
+	removeDailyDateMap := make(map[string]bool)
+	removeMonthlyDateMap := make(map[string]bool)
+
+	for _, d := range existDataList {
+		tmpDate := d.DataTime.Format(utils.FormatDate)
+		if d.Source == ModelDataSourceDaily {
+			existDailyMap[tmpDate] = d
+			removeDailyDateMap[tmpDate] = true
+		}
+		if d.Source == ModelDataSourceMonthly {
+			existMonthlyMap[tmpDate] = d
+			removeMonthlyDateMap[tmpDate] = true
+		}
+	}
+
+	addDataList := make([]*AiPredictModelData, 0)
+	for _, tmpData := range dataList {
+		tmpData.AiPredictModelIndexId = modelIndexItem.AiPredictModelIndexId
+		tmpData.IndexCode = modelIndexItem.IndexCode
+		tmpData.DataTimestamp = tmpData.DataTime.UnixNano() / 1e6
+
+		// 档期日期
+		tmpDate := tmpData.DataTime.Format(utils.FormatDate)
+
+		if tmpData.Source == ModelDataSourceDaily {
+			delete(removeDailyDateMap, tmpDate)
+			if existData, ok := existDailyMap[tmpDate]; ok {
+				// 修改
+				dataUpdateCols := make([]string, 0)
+				if existData.Value != tmpData.Value {
+					existData.Value = tmpData.Value
+					dataUpdateCols = append(dataUpdateCols, "Value")
+				}
+				if existData.PredictValue != tmpData.PredictValue {
+					existData.PredictValue = tmpData.PredictValue
+					dataUpdateCols = append(dataUpdateCols, "PredictValue")
+				}
+				if existData.Direction != tmpData.Direction {
+					existData.Direction = tmpData.Direction
+					dataUpdateCols = append(dataUpdateCols, "Direction")
+				}
+				if existData.DeviationRate != tmpData.DeviationRate {
+					existData.DeviationRate = tmpData.DeviationRate
+					dataUpdateCols = append(dataUpdateCols, "DeviationRate")
+				}
+
+				if len(dataUpdateCols) > 0 {
+					existData.ModifyTime = time.Now()
+					dataUpdateCols = append(dataUpdateCols, "ModifyTime")
+					tmpErr := tx.Select(dataUpdateCols).Updates(existData).Error
+					if tmpErr != nil {
+						utils.FileLog.Error("update index data err: %v", tmpErr)
+					}
+				}
+
+			} else {
+				addDataList = append(addDataList, tmpData)
+			}
+		}
+
+		if tmpData.Source == ModelDataSourceMonthly {
+			delete(removeMonthlyDateMap, tmpDate)
+			if existData, ok := existMonthlyMap[tmpDate]; ok {
+				// 修改
+				dataUpdateCols := make([]string, 0)
+				if existData.Value != tmpData.Value {
+					existData.Value = tmpData.Value
+					dataUpdateCols = append(dataUpdateCols, "Value")
+				}
+				if existData.PredictValue.Float64 != tmpData.PredictValue.Float64 {
+					existData.PredictValue = tmpData.PredictValue
+					dataUpdateCols = append(dataUpdateCols, "PredictValue")
+				}
+				if existData.Direction != tmpData.Direction {
+					existData.Direction = tmpData.Direction
+					dataUpdateCols = append(dataUpdateCols, "Direction")
+				}
+				if existData.DeviationRate != tmpData.DeviationRate {
+					existData.DeviationRate = tmpData.DeviationRate
+					dataUpdateCols = append(dataUpdateCols, "DeviationRate")
+				}
+
+				if len(dataUpdateCols) > 0 {
+					existData.ModifyTime = time.Now()
+					dataUpdateCols = append(dataUpdateCols, "ModifyTime")
+					tmpErr := tx.Select(dataUpdateCols).Updates(existData).Error
+					if tmpErr != nil {
+						utils.FileLog.Error("update index data err: %v", tmpErr)
+					}
+				}
+			} else {
+				addDataList = append(addDataList, tmpData)
+			}
+		}
+	}
+
+	// 清除不要了的日度指标
+	if len(removeDailyDateMap) > 0 {
+		removeDateList := make([]string, 0)
+		for date := range removeDailyDateMap {
+			removeDateList = append(removeDateList, date)
+		}
+		sql := `DELETE FROM ai_predict_model_data WHERE ai_predict_model_index_id = ? AND source = ? AND data_time IN (?)`
+		e = tx.Exec(sql, modelIndexItem.AiPredictModelIndexId, ModelDataSourceDaily, removeDateList).Error
+		if e != nil {
+			err = fmt.Errorf("clear index daily data err: %v", e)
+			return
+		}
+	}
+
+	// 清除不要了的月度指标
+	if len(removeMonthlyDateMap) > 0 {
+		removeDateList := make([]string, 0)
+		for date := range removeMonthlyDateMap {
+			removeDateList = append(removeDateList, date)
+		}
+		sql := `DELETE FROM ai_predict_model_data WHERE ai_predict_model_index_id = ? AND source = ? AND data_time IN (?)`
+		e = tx.Exec(sql, modelIndexItem.AiPredictModelIndexId, ModelDataSourceMonthly, removeDateList).Error
+		if e != nil {
+			err = fmt.Errorf("clear index monthly data err: %v", e)
+			return
+		}
+	}
+	e = tx.CreateInBatches(addDataList, utils.MultiAddNum).Error
+	if e != nil {
+		err = fmt.Errorf("insert index data err: %v", e)
+		return
+	}
+
+	return
+}

+ 86 - 0
models/ai_predict_model/request/index.go

@@ -0,0 +1,86 @@
+package request
+
+import (
+	"database/sql"
+)
+
+// HandleTaskRecordFailReq
+// @Description: 任务处理失败
+type HandleTaskRecordFailReq struct {
+	IndexTaskRecordId int    `description:"子任务id"`
+	FailMsg           string `description:"失败原因"`
+}
+
+// HandleTaskRecordSuccessReq
+// @Description: 任务处理成功
+type HandleTaskRecordSuccessReq struct {
+	IndexTaskRecordId int                      `gorm:"primaryKey;column:ai_predict_model_index_config_id" description:"子任务记录id"`
+	Data              AiPredictModelImportData `description:"导入的指标"`
+}
+
+// AiPredictModelIndex AI预测模型标的
+type AiPredictModelIndex struct {
+	AiPredictModelIndexId       int     `orm:"column(ai_predict_model_index_id);pk" gorm:"primaryKey"`
+	IndexName                   string  `description:"标的名称"`
+	IndexCode                   string  `description:"自生成的指标编码"`
+	ClassifyId                  int     `description:"分类ID"`
+	ModelFramework              string  `description:"模型框架"`
+	PredictDate                 string  `description:"预测日期"`
+	PredictValue                float64 `description:"预测值"`
+	PredictFrequency            string  `description:"预测频度"`
+	DirectionAccuracy           string  `description:"方向准确度"`
+	AbsoluteDeviation           string  `description:"绝对偏差"`
+	ExtraConfig                 string  `description:"模型参数"`
+	Sort                        int     `description:"排序"`
+	SysUserId                   int     `description:"创建人ID"`
+	SysUserRealName             string  `description:"创建人姓名"`
+	LeftMin                     string  `description:"图表左侧最小值"`
+	LeftMax                     string  `description:"图表左侧最大值"`
+	AiPredictModelIndexConfigId int     `gorm:"column:ai_predict_model_index_config_id" description:"标的当前的配置id"`
+	ScriptPath                  string  `gorm:"column:script_path" description:"脚本的路径"`
+	TrainStatus                 string  `gorm:"column:train_status" description:"训练状态,枚举值:待训练,训练中,训练成功,训练失败"`
+	RunStatus                   string  `gorm:"column:run_status" description:"运行状态,枚举值:待运行,运行中,运行成功,运行失败"`
+}
+
+// AiPredictModelData AI预测模型标的数据
+type AiPredictModelData struct {
+	DataTime      string          `description:"数据日期"`
+	Value         sql.NullFloat64 `description:"实际值"`
+	PredictValue  sql.NullFloat64 `description:"预测值"`
+	Direction     string          `description:"方向"`
+	DeviationRate string          `description:"偏差率"`
+	DataTimestamp int64           `description:"数据日期时间戳"`
+	Source        int             `description:"来源:1-月度预测(默认);2-日度预测"`
+}
+
+type AiPredictModelImportData struct {
+	Index     *AiPredictModelIndex
+	Data      []*AiPredictModelData
+	TrainData TrainData
+}
+
+// TrainData
+// @Description: 训练结果
+type TrainData struct {
+	TrainMse float64
+	TrainR2  float64
+	TestMse  float64
+	TestR2   float64
+}
+
+type AiPredictModelIndexExtraConfig struct {
+	MonthlyChart MonthlyChartConfig
+	DailyChart   DailyChartConfig
+}
+type MonthlyChartConfig struct {
+	LeftMin string `description:"图表左侧最小值"`
+	LeftMax string `description:"图表左侧最大值"`
+	Unit    string `description:"单位"`
+}
+
+type DailyChartConfig struct {
+	LeftMin           string `description:"图表左侧最小值"`
+	LeftMax           string `description:"图表左侧最大值"`
+	Unit              string `description:"单位"`
+	PredictLegendName string `description:"预测图例的名称(通常为Predicted)"`
+}

+ 2 - 0
models/ai_predict_model/response/index.go

@@ -5,6 +5,8 @@ type AiPredictModelIndexConfigResp struct {
 	AiPredictModelIndexId       int          `gorm:"column:ai_predict_model_index_id" description:"ai预测模型id"`
 	AiPredictModelIndexConfigId int          `gorm:"primaryKey;column:ai_predict_model_index_config_id" description:"ai预测模型配置ID"`
 	ConfigParams                ConfigParams `gorm:"column:train_status" description:"运行/训练参数"`
+	ExecType                    string       `description:"枚举值:ai_predict_model_train(训练)、ai_predict_model_run(运行)"`
+	ScriptPath                  string       `description:"脚本路径"`
 }
 
 // ConfigParams

+ 8 - 8
models/index_task.go

@@ -68,32 +68,32 @@ var IndexTaskColumns = struct {
 }
 
 func (m *IndexTask) Create() (err error) {
-	err = global.DbMap[utils.DbNameAI].Create(&m).Error
+	err = global.DEFAULT_DB.Create(&m).Error
 
 	return
 }
 
 func (m *IndexTask) Update(updateCols []string) (err error) {
-	err = global.DbMap[utils.DbNameAI].Select(updateCols).Updates(&m).Error
+	err = global.DEFAULT_DB.Select(updateCols).Updates(&m).Error
 
 	return
 }
 
 func (m *IndexTask) Del() (err error) {
-	err = global.DbMap[utils.DbNameAI].Delete(&m).Error
+	err = global.DEFAULT_DB.Delete(&m).Error
 
 	return
 }
 
 func (m *IndexTask) GetByID(id int) (item *IndexTask, err error) {
-	err = global.DbMap[utils.DbNameAI].Where(fmt.Sprintf("%s = ?", IndexTaskColumns.IndexTaskID), id).First(&item).Error
+	err = global.DEFAULT_DB.Where(fmt.Sprintf("%s = ?", IndexTaskColumns.IndexTaskID), id).First(&item).Error
 
 	return
 }
 
 func (m *IndexTask) GetByCondition(condition string, pars []interface{}) (item *IndexTask, err error) {
 	sqlStr := fmt.Sprintf(`SELECT * FROM %s WHERE 1=1 %s`, m.TableName(), condition)
-	err = global.DbMap[utils.DbNameAI].Raw(sqlStr, pars...).First(&item).Error
+	err = global.DEFAULT_DB.Raw(sqlStr, pars...).First(&item).Error
 
 	return
 }
@@ -104,7 +104,7 @@ func (m *IndexTask) GetListByCondition(field, condition string, pars []interface
 	}
 	sqlStr := fmt.Sprintf(`SELECT %s FROM %s WHERE 1=1 %s order by index_task_id desc LIMIT ?,?`, field, m.TableName(), condition)
 	pars = append(pars, startSize, pageSize)
-	err = global.DbMap[utils.DbNameAI].Raw(sqlStr, pars...).Find(&items).Error
+	err = global.DEFAULT_DB.Raw(sqlStr, pars...).Find(&items).Error
 
 	return
 }
@@ -112,7 +112,7 @@ func (m *IndexTask) GetListByCondition(field, condition string, pars []interface
 func (m *IndexTask) GetCountByCondition(condition string, pars []interface{}) (total int, err error) {
 	var intNull sql.NullInt64
 	sqlStr := fmt.Sprintf(`SELECT COUNT(1) total FROM %s WHERE 1=1 %s`, m.TableName(), condition)
-	err = global.DbMap[utils.DbNameAI].Raw(sqlStr, pars...).Scan(&intNull).Error
+	err = global.DEFAULT_DB.Raw(sqlStr, pars...).Scan(&intNull).Error
 	if err == nil && intNull.Valid {
 		total = int(intNull.Int64)
 	}
@@ -128,7 +128,7 @@ func (m *IndexTask) GetCountByCondition(condition string, pars []interface{}) (t
 // @param indexRecordList []*IndexTaskRecord
 // @return err error
 func AddIndexTask(indexTask *IndexTask, indexRecordList []*IndexTaskRecord) (err error) {
-	to := global.DbMap[utils.DbNameAI].Begin()
+	to := global.DEFAULT_DB.Begin()
 	defer func() {
 		if err != nil {
 			_ = to.Rollback()

+ 8 - 9
models/index_task_record.go

@@ -3,7 +3,6 @@ package models
 import (
 	"database/sql"
 	"eta/eta_index_lib/global"
-	"eta/eta_index_lib/utils"
 	"fmt"
 	"time"
 )
@@ -44,32 +43,32 @@ var IndexTaskRecordColumns = struct {
 }
 
 func (m *IndexTaskRecord) Create() (err error) {
-	err = global.DbMap[utils.DbNameAI].Create(&m).Error
+	err = global.DEFAULT_DB.Create(&m).Error
 
 	return
 }
 
 func (m *IndexTaskRecord) Update(updateCols []string) (err error) {
-	err = global.DbMap[utils.DbNameAI].Select(updateCols).Updates(&m).Error
+	err = global.DEFAULT_DB.Select(updateCols).Updates(&m).Error
 
 	return
 }
 
 func (m *IndexTaskRecord) Del() (err error) {
-	err = global.DbMap[utils.DbNameAI].Delete(&m).Error
+	err = global.DEFAULT_DB.Delete(&m).Error
 
 	return
 }
 
 func (m *IndexTaskRecord) GetByID(id int) (item *IndexTaskRecord, err error) {
-	err = global.DbMap[utils.DbNameAI].Where(fmt.Sprintf("%s = ?", IndexTaskRecordColumns.IndexTaskRecordID), id).First(&item).Error
+	err = global.DEFAULT_DB.Where(fmt.Sprintf("%s = ?", IndexTaskRecordColumns.IndexTaskRecordID), id).First(&item).Error
 
 	return
 }
 
 func (m *IndexTaskRecord) GetByCondition(condition string, pars []interface{}) (item *IndexTaskRecord, err error) {
 	sqlStr := fmt.Sprintf(`SELECT * FROM %s WHERE 1=1 %s`, m.TableName(), condition)
-	err = global.DbMap[utils.DbNameAI].Raw(sqlStr, pars...).First(&item).Error
+	err = global.DEFAULT_DB.Raw(sqlStr, pars...).First(&item).Error
 
 	return
 }
@@ -79,7 +78,7 @@ func (m *IndexTaskRecord) GetAllListByCondition(field, condition string, pars []
 		field = "*"
 	}
 	sqlStr := fmt.Sprintf(`SELECT %s FROM %s WHERE 1=1 %s order by index_task_record_id desc `, field, m.TableName(), condition)
-	err = global.DbMap[utils.DbNameAI].Raw(sqlStr, pars...).Find(&items).Error
+	err = global.DEFAULT_DB.Raw(sqlStr, pars...).Find(&items).Error
 
 	return
 }
@@ -90,7 +89,7 @@ func (m *IndexTaskRecord) GetListByCondition(field, condition string, pars []int
 	}
 	sqlStr := fmt.Sprintf(`SELECT %s FROM %s WHERE 1=1 %s order by index_task_record_id desc LIMIT ?,?`, field, m.TableName(), condition)
 	pars = append(pars, startSize, pageSize)
-	err = global.DbMap[utils.DbNameAI].Raw(sqlStr, pars...).Find(&items).Error
+	err = global.DEFAULT_DB.Raw(sqlStr, pars...).Find(&items).Error
 
 	return
 }
@@ -98,7 +97,7 @@ func (m *IndexTaskRecord) GetListByCondition(field, condition string, pars []int
 func (m *IndexTaskRecord) GetCountByCondition(condition string, pars []interface{}) (total int, err error) {
 	var intNull sql.NullInt64
 	sqlStr := fmt.Sprintf(`SELECT COUNT(1) total FROM %s WHERE 1=1 %s`, m.TableName(), condition)
-	err = global.DbMap[utils.DbNameAI].Raw(sqlStr, pars...).Scan(&intNull).Error
+	err = global.DEFAULT_DB.Raw(sqlStr, pars...).Scan(&intNull).Error
 	if err == nil && intNull.Valid {
 		total = int(intNull.Int64)
 	}

+ 19 - 1
routers/commentsRouter.go

@@ -9,7 +9,25 @@ func init() {
 
     beego.GlobalControllerRouter["eta/eta_index_lib/controllers/ai_predict_model:AiPredictModelIndexController"] = append(beego.GlobalControllerRouter["eta/eta_index_lib/controllers/ai_predict_model:AiPredictModelIndexController"],
         beego.ControllerComments{
-            Method: "List",
+            Method: "HandleTaskRecordFailByTaskRecord",
+            Router: `/handle/fail`,
+            AllowHTTPMethods: []string{"post"},
+            MethodParams: param.Make(),
+            Filters: nil,
+            Params: nil})
+
+    beego.GlobalControllerRouter["eta/eta_index_lib/controllers/ai_predict_model:AiPredictModelIndexController"] = append(beego.GlobalControllerRouter["eta/eta_index_lib/controllers/ai_predict_model:AiPredictModelIndexController"],
+        beego.ControllerComments{
+            Method: "HandleTaskRecordSuccessByTaskRecord",
+            Router: `/handle/success`,
+            AllowHTTPMethods: []string{"post"},
+            MethodParams: param.Make(),
+            Filters: nil,
+            Params: nil})
+
+    beego.GlobalControllerRouter["eta/eta_index_lib/controllers/ai_predict_model:AiPredictModelIndexController"] = append(beego.GlobalControllerRouter["eta/eta_index_lib/controllers/ai_predict_model:AiPredictModelIndexController"],
+        beego.ControllerComments{
+            Method: "OpToDo",
             Router: `/op_todo`,
             AllowHTTPMethods: []string{"post"},
             MethodParams: param.Make(),