Эх сурвалжийг харах

feat:训练成功后的处理逻辑

Roc 1 өдөр өмнө
parent
commit
fe82f74f90

+ 1 - 1
controllers/ai_predict_model/index.go

@@ -358,7 +358,7 @@ func (this *AiPredictModelIndexController) HandleTaskRecordSuccessByTaskRecord()
 		return
 	}
 
-	// 标记处理任务失败
+	// 标记处理任务成功
 	aiPredictModelLogic.HandleTaskRecordSuccessByTaskRecord(indexTaskInfo.TaskType, indexTaskRecordInfo, req.Data)
 
 	br.Ret = 200

+ 64 - 28
logic/ai_predict_model/index.go

@@ -304,7 +304,7 @@ func HandleTaskRecordSuccessByTaskRecord(taskType string, indexTaskRecordInfo *m
 			return
 		}
 
-		handleTaskRecordSuccessByTrain(indexConfigItem, indexItem)
+		handleTaskRecordSuccessByTrain(aiPredictModelImportData, indexConfigItem, indexItem)
 
 	case utils.INDEX_TASK_TYPE_AI_MODEL_RUN: // 运行模型
 
@@ -335,27 +335,74 @@ func HandleTaskRecordSuccessByTaskRecord(taskType string, indexTaskRecordInfo *m
 	return
 }
 
-func handleTaskRecordSuccessByTrain(indexConfigItem *aiPredictModel.AiPredictModelIndexConfig, indexItem *aiPredictModel.AiPredictModelIndex) {
-	// 修改模型状态信息
-	if indexItem != nil {
-		indexItem.TrainStatus = aiPredictModel.TrainStatusSuccess
-		indexItem.ModifyTime = time.Now()
-		tmpErr := indexItem.Update([]string{"train_status", "modify_time"})
-		if tmpErr != nil {
-			utils.FileLog.Error("%d,修改模型训练状态失败, err: %s", indexItem.AiPredictModelIndexId, tmpErr.Error())
+// handleTaskRecordSuccessByTrain
+// @Description: 处理模型训练成功后的操作
+// @author: Roc
+// @datetime 2025-05-14 18:25:12
+// @param aiPredictModelImportData request.AiPredictModelImportData
+// @param indexConfigItem *aiPredictModel.AiPredictModelIndexConfig
+// @param indexItem *aiPredictModel.AiPredictModelIndex
+// @return err error
+func handleTaskRecordSuccessByTrain(aiPredictModelImportData request.AiPredictModelImportData, indexConfigItem *aiPredictModel.AiPredictModelIndexConfig, indexItem *aiPredictModel.AiPredictModelIndex) (err error) {
+	defer func() {
+		if err != nil {
+			utils.FileLog.Error(fmt.Sprintf(`handleTaskRecordSuccessByTrain err:%v`, err))
 		}
-	}
+	}()
 
-	// 修改模型配置状态信息
-	if indexConfigItem != nil {
+	// 标的状态修改
+	updateIndexCols := []string{"train_status", "modify_time"}
+	indexItem.TrainStatus = aiPredictModel.TrainStatusSuccess
+	indexItem.ModifyTime = time.Now()
+
+	updateIndexConfigCols := []string{"train_status", `remark`, "modify_time", `train_mse`, `train_r2`, `test_mse`, `test_r2`}
+	// 配置状态修改
+	{
+		// 训练参数
+		trainData := aiPredictModelImportData.TrainData
 		indexConfigItem.TrainStatus = aiPredictModel.TrainStatusSuccess
 		indexConfigItem.Remark = `成功`
+		indexConfigItem.TrainMse = fmt.Sprint(trainData.TrainMse)
+		indexConfigItem.TrainR2 = fmt.Sprint(trainData.TrainR2)
+		indexConfigItem.TestMse = fmt.Sprint(trainData.TestMse)
+		indexConfigItem.TestR2 = fmt.Sprint(trainData.TestR2)
 		indexConfigItem.ModifyTime = time.Now()
-		tmpErr := indexConfigItem.Update([]string{"train_status", `remark`, "modify_time"})
-		if tmpErr != nil {
-			utils.FileLog.Error("%d,修改模型训练状态失败, err: %s", indexItem.AiPredictModelIndexId, tmpErr.Error())
+	}
+
+	indexConfigOb := new(aiPredictModel.AiPredictModelIndexConfig)
+
+	dataList := make([]*aiPredictModel.AiPredictModelIndexConfigTrainData, 0)
+	for _, tmpData := range aiPredictModelImportData.Data {
+		tmpDate, e := time.ParseInLocation(utils.FormatDate, tmpData.DataTime, time.Local)
+		if e != nil {
+			err = fmt.Errorf("数据日期解析失败, %v", e)
+			return
 		}
+
+		dataList = append(dataList, &aiPredictModel.AiPredictModelIndexConfigTrainData{
+			//AiPredictModelDataId:  0,
+			AiPredictModelIndexConfigId: indexConfigItem.AiPredictModelIndexConfigId,
+			AiPredictModelIndexId:       indexItem.AiPredictModelIndexId,
+			IndexCode:                   indexItem.IndexCode,
+			DataTime:                    tmpDate,
+			Value:                       tmpData.Value,
+			PredictValue:                tmpData.PredictValue,
+			Direction:                   tmpData.Direction,
+			DeviationRate:               tmpData.DeviationRate,
+			CreateTime:                  time.Now(),
+			ModifyTime:                  time.Now(),
+			DataTimestamp:               tmpData.DataTimestamp,
+			Source:                      tmpData.Source,
+		})
+	}
+
+	// 更新指标和数据
+	err = indexConfigOb.UpdateIndexAndData(indexItem, indexConfigItem, dataList, updateIndexCols, updateIndexConfigCols)
+	if err != nil {
+		return
 	}
+
+	return
 }
 
 // handleTaskRecordSuccessByRun
@@ -373,19 +420,8 @@ func handleTaskRecordSuccessByRun(aiPredictModelImportData request.AiPredictMode
 			}
 		}()
 	}()
-	// 查询已存在的标的
+
 	indexOb := new(aiPredictModel.AiPredictModelIndex)
-	indexNameItem := make(map[string]*aiPredictModel.AiPredictModelIndex)
-	{
-		list, e := indexOb.GetItemsByCondition("", make([]interface{}, 0), []string{}, "")
-		if e != nil {
-			err = fmt.Errorf("获取标的失败, %v", e)
-			return
-		}
-		for _, v := range list {
-			indexNameItem[v.IndexName] = v
-		}
-	}
 
 	updateCols := []string{indexOb.Cols().RunStatus, indexOb.Cols().PredictValue, indexOb.Cols().DirectionAccuracy, indexOb.Cols().AbsoluteDeviation, indexOb.Cols().ExtraConfig, indexOb.Cols().ModifyTime}
 
@@ -444,7 +480,7 @@ func handleTaskRecordSuccessByRun(aiPredictModelImportData request.AiPredictMode
 		})
 	}
 
-	// 更新指标
+	// 更新指标和数据
 	err = indexOb.UpdateIndexAndData(indexItem, dataList, updateCols)
 	if err != nil {
 		return

+ 36 - 0
models/ai_predict_model/ai_predict_model_index_config.go

@@ -177,3 +177,39 @@ func (m *AiPredictModelIndexConfig) GetPageListByCondition(condition string, par
 
 	return
 }
+
+// UpdateIndexAndData 导入数据
+func (m *AiPredictModelIndexConfig) UpdateIndexAndData(modelIndexItem *AiPredictModelIndex, modelIndexConfigItem *AiPredictModelIndexConfig, dataList []*AiPredictModelIndexConfigTrainData, updateIndexCols, updateIndexConfigCols []string) (err error) {
+	o := global.DbMap[utils.DbNameIndex]
+	tx := o.Begin()
+	defer func() {
+		if err != nil {
+			_ = tx.Rollback()
+			return
+		}
+		_ = tx.Commit()
+	}()
+
+	// 更新标的
+	e := tx.Select(updateIndexCols).Updates(modelIndexItem).Error
+	if e != nil {
+		err = fmt.Errorf("update index err: %v", e)
+		return
+	}
+
+	// 更新模型配置
+	e = tx.Select(updateIndexConfigCols).Updates(modelIndexConfigItem).Error
+	if e != nil {
+		err = fmt.Errorf("update index err: %v", e)
+		return
+	}
+
+	// 添加训练数据
+	e = tx.CreateInBatches(dataList, utils.MultiAddNum).Error
+	if e != nil {
+		err = fmt.Errorf("insert index data err: %v", e)
+		return
+	}
+
+	return
+}

+ 1 - 1
models/ai_predict_model/ai_predict_model_index_config_train_data.go

@@ -23,7 +23,7 @@ type AiPredictModelIndexConfigTrainData struct {
 	CreateTime                           time.Time       `gorm:"column:create_time" description:"创建时间"`
 	ModifyTime                           time.Time       `gorm:"column:modify_time" description:"修改时间"`
 	DataTimestamp                        int64           `gorm:"column:data_timestamp" description:"数据日期时间戳"`
-	Source                               uint8           `gorm:"column:source" description:"来源:1-月度预测;2-日度预测"`
+	Source                               int             `gorm:"column:source" description:"来源:1-月度预测;2-日度预测"`
 }
 
 // TableName get sql table name.获取数据库表名