Ver código fonte

refactor(ai_predict_model): 更新指数数据和配置

- 新增 UpdateIndexAndData 方法,用于导入数据并更新指数
- 在 AiPredictModelIndexConfig 中添加模型框架、预测日期等字段
- 实现 SetCurrIndexConfig 方法,用于设置当前指数配置
- 优化数据更新逻辑,支持批量插入和删除
Roc 3 dias atrás
pai
commit
232bef9dba

+ 161 - 0
models/ai_predict_model/ai_predict_model_index.go

@@ -456,6 +456,167 @@ func (m *AiPredictModelIndex) ImportIndexAndData(createIndexes, updateIndexes []
 	return
 }
 
+// UpdateIndexAndData 导入数据
+func (m *AiPredictModelIndex) UpdateIndexAndData(modelIndexItem *AiPredictModelIndex, dataList []*AiPredictModelData, updateCols []string) (err error) {
+	o := global.DbMap[utils.DbNameIndex]
+	tx := o.Begin()
+	defer func() {
+		if err != nil {
+			_ = tx.Rollback()
+			return
+		}
+		_ = tx.Commit()
+	}()
+
+	// 更新指标
+	e := tx.Select(updateCols).Updates(modelIndexItem).Error
+	if e != nil {
+		err = fmt.Errorf("update index err: %v", e)
+		return
+	}
+
+	var existDataList []*AiPredictModelData
+	// 查询标的的所有数据
+	sqlRun := `SELECT * FROM ai_predict_model_data WHERE ai_predict_model_index_id = ? ORDER BY ai_predict_model_data_id DESC`
+	err = tx.Raw(sqlRun, modelIndexItem.AiPredictModelIndexId).Find(&existDataList).Error
+	if err != nil {
+		err = fmt.Errorf("find index data err: %v", e)
+		return
+	}
+	existDailyMap := make(map[string]*AiPredictModelData)
+	existMonthlyMap := make(map[string]*AiPredictModelData)
+	removeDailyDateMap := make(map[string]bool)
+	removeMonthlyDateMap := make(map[string]bool)
+
+	for _, d := range existDataList {
+		tmpDate := d.DataTime.Format(utils.FormatDate)
+		if d.Source == ModelDataSourceDaily {
+			existDailyMap[tmpDate] = d
+			removeDailyDateMap[tmpDate] = true
+		}
+		if d.Source == ModelDataSourceMonthly {
+			existMonthlyMap[tmpDate] = d
+			removeMonthlyDateMap[tmpDate] = true
+		}
+	}
+
+	addDataList := make([]*AiPredictModelData, 0)
+	for _, tmpData := range dataList {
+		tmpData.AiPredictModelIndexId = modelIndexItem.AiPredictModelIndexId
+		tmpData.IndexCode = modelIndexItem.IndexCode
+		tmpData.DataTimestamp = tmpData.DataTime.UnixNano() / 1e6
+
+		// 档期日期
+		tmpDate := tmpData.DataTime.Format(utils.FormatDate)
+
+		if tmpData.Source == ModelDataSourceDaily {
+			delete(removeDailyDateMap, tmpDate)
+			if existData, ok := existDailyMap[tmpDate]; ok {
+				// 修改
+				dataUpdateCols := make([]string, 0)
+				if existData.Value != tmpData.Value {
+					existData.Value = tmpData.Value
+					dataUpdateCols = append(dataUpdateCols, "Value")
+				}
+				if existData.PredictValue != tmpData.PredictValue {
+					existData.PredictValue = tmpData.PredictValue
+					dataUpdateCols = append(dataUpdateCols, "PredictValue")
+				}
+				if existData.Direction != tmpData.Direction {
+					existData.Direction = tmpData.Direction
+					dataUpdateCols = append(dataUpdateCols, "Direction")
+				}
+				if existData.DeviationRate != tmpData.DeviationRate {
+					existData.DeviationRate = tmpData.DeviationRate
+					dataUpdateCols = append(dataUpdateCols, "DeviationRate")
+				}
+
+				if len(dataUpdateCols) > 0 {
+					existData.ModifyTime = time.Now()
+					dataUpdateCols = append(dataUpdateCols, "ModifyTime")
+					tmpErr := tx.Select(dataUpdateCols).Updates(existData).Error
+					if tmpErr != nil {
+						utils.FileLog.Error("update index data err: %v", tmpErr)
+					}
+				}
+
+			} else {
+				addDataList = append(addDataList, tmpData)
+			}
+		}
+
+		if tmpData.Source == ModelDataSourceMonthly {
+			delete(removeMonthlyDateMap, tmpDate)
+			if existData, ok := existMonthlyMap[tmpDate]; ok {
+				// 修改
+				dataUpdateCols := make([]string, 0)
+				if existData.Value != tmpData.Value {
+					existData.Value = tmpData.Value
+					dataUpdateCols = append(dataUpdateCols, "Value")
+				}
+				if existData.PredictValue.Float64 != tmpData.PredictValue.Float64 {
+					existData.PredictValue = tmpData.PredictValue
+					dataUpdateCols = append(dataUpdateCols, "PredictValue")
+				}
+				if existData.Direction != tmpData.Direction {
+					existData.Direction = tmpData.Direction
+					dataUpdateCols = append(dataUpdateCols, "Direction")
+				}
+				if existData.DeviationRate != tmpData.DeviationRate {
+					existData.DeviationRate = tmpData.DeviationRate
+					dataUpdateCols = append(dataUpdateCols, "DeviationRate")
+				}
+
+				if len(dataUpdateCols) > 0 {
+					existData.ModifyTime = time.Now()
+					dataUpdateCols = append(dataUpdateCols, "ModifyTime")
+					tmpErr := tx.Select(dataUpdateCols).Updates(existData).Error
+					if tmpErr != nil {
+						utils.FileLog.Error("update index data err: %v", tmpErr)
+					}
+				}
+			} else {
+				addDataList = append(addDataList, tmpData)
+			}
+		}
+	}
+
+	// 清除不要了的日度指标
+	if len(removeDailyDateMap) > 0 {
+		removeDateList := make([]string, 0)
+		for date := range removeDailyDateMap {
+			removeDateList = append(removeDateList, date)
+		}
+		sql := `DELETE FROM ai_predict_model_data WHERE ai_predict_model_index_id = ? AND source = ? AND data_time IN (?)`
+		e = tx.Exec(sql, modelIndexItem.AiPredictModelIndexId, ModelDataSourceDaily, removeDateList).Error
+		if e != nil {
+			err = fmt.Errorf("clear index daily data err: %v", e)
+			return
+		}
+	}
+
+	// 清除不要了的月度指标
+	if len(removeMonthlyDateMap) > 0 {
+		removeDateList := make([]string, 0)
+		for date := range removeMonthlyDateMap {
+			removeDateList = append(removeDateList, date)
+		}
+		sql := `DELETE FROM ai_predict_model_data WHERE ai_predict_model_index_id = ? AND source = ? AND data_time IN (?)`
+		e = tx.Exec(sql, modelIndexItem.AiPredictModelIndexId, ModelDataSourceMonthly, removeDateList).Error
+		if e != nil {
+			err = fmt.Errorf("clear index monthly data err: %v", e)
+			return
+		}
+	}
+	e = tx.CreateInBatches(addDataList, utils.MultiAddNum).Error
+	if e != nil {
+		err = fmt.Errorf("insert index data err: %v", e)
+		return
+	}
+
+	return
+}
+
 type AiPredictModelDetailResp struct {
 	TableData      []*AiPredictModelDataItem        `description:"表格数据"`
 	ChartView      *data_manage.ChartInfoDetailResp `description:"月度预测数据图表"`

+ 7 - 0
models/ai_predict_model/ai_predict_model_index_config.go

@@ -22,6 +22,13 @@ type AiPredictModelIndexConfig struct {
 	AiPredictModelIndexId       int       `gorm:"column:ai_predict_model_index_id" description:"ai预测模型id"`
 	TrainStatus                 string    `gorm:"column:train_status" description:"训练状态,枚举值:待训练,训练中,训练成功,训练失败"`
 	Params                      string    `gorm:"column:params" description:"训练参数,json字符串存储,便于以后参数扩展"`
+	ModelFramework              string    `description:"模型框架"`
+	PredictDate                 time.Time `description:"预测日期"`
+	PredictValue                float64   `description:"预测值"`
+	PredictFrequency            string    `description:"预测频度"`
+	DirectionAccuracy           string    `description:"方向准确度"`
+	AbsoluteDeviation           string    `description:"绝对偏差"`
+	ExtraConfig                 string    `description:"模型参数"`
 	TrainMse                    string    `gorm:"column:train_mse" description:"训练集mse"`
 	TrainR2                     string    `gorm:"column:train_r2" description:"训练集r2"`
 	TestMse                     string    `gorm:"column:test_mse" description:"测试集mse"`

+ 72 - 0
services/ai_predict_model_index.go

@@ -654,3 +654,75 @@ func GetCurrentRunningAiPredictModelIndexCount() (total int, err error) {
 	total, err = obj.GetCountByCondition(" AND (train_status in (?) OR run_status in (?) ) ", []interface{}{[]string{aiPredictModel.TrainStatusTraining, aiPredictModel.TrainStatusWaiting}, []string{aiPredictModel.RunStatusRunning, aiPredictModel.RunStatusWaiting}})
 	return
 }
+
+func SetCurrIndexConfig(indexItem *aiPredictModel.AiPredictModelIndex, indexConfigItem *aiPredictModel.AiPredictModelIndexConfig) (err error) {
+	indexOb := aiPredictModel.AiPredictModelIndex{}
+
+	updateCols := []string{indexOb.Cols().AiPredictModelIndexConfigId, indexOb.Cols().RunStatus, indexOb.Cols().PredictValue, indexOb.Cols().DirectionAccuracy, indexOb.Cols().AbsoluteDeviation, indexOb.Cols().ExtraConfig, indexOb.Cols().ModifyTime}
+	// 预测日期,理论上是需要改的,可是产品说不需要改,所以暂时不改
+	updateCols = append(updateCols, indexOb.Cols().PredictDate)
+
+	indexItem.AiPredictModelIndexConfigId = indexConfigItem.AiPredictModelIndexConfigId
+	indexItem.RunStatus = aiPredictModel.RunStatusSuccess
+	indexItem.PredictValue = indexConfigItem.PredictValue
+	indexItem.DirectionAccuracy = indexConfigItem.DirectionAccuracy
+	indexItem.AbsoluteDeviation = indexConfigItem.AbsoluteDeviation
+	indexItem.PredictDate = indexConfigItem.PredictDate
+	indexItem.ModifyTime = time.Now()
+
+	// 图例信息
+	if indexItem.ExtraConfig != "" && indexConfigItem.ExtraConfig != "" {
+		var oldConfig, newConfig aiPredictModel.AiPredictModelIndexExtraConfig
+		if e := json.Unmarshal([]byte(indexItem.ExtraConfig), &oldConfig); e != nil {
+			err = fmt.Errorf("标的原配置解析失败, Config: %s, Err: %v", indexItem.ExtraConfig, e)
+			return
+		}
+		if e := json.Unmarshal([]byte(indexConfigItem.ExtraConfig), &newConfig); e != nil {
+			err = fmt.Errorf("标的新配置解析失败, Config: %s, Err: %v", indexConfigItem.ExtraConfig, e)
+			return
+		}
+		oldConfig.DailyChart.PredictLegendName = newConfig.DailyChart.PredictLegendName
+		b, _ := json.Marshal(oldConfig)
+		indexItem.ExtraConfig = string(b)
+	}
+
+	dataList := make([]*aiPredictModel.AiPredictModelData, 0)
+
+	trainDataList := make([]*aiPredictModel.AiPredictModelIndexConfigTrainData, 0)
+	{
+		dataOb := new(aiPredictModel.AiPredictModelIndexConfigTrainData)
+		dataCond := fmt.Sprintf(` AND %s = ?`, aiPredictModel.AiPredictModelIndexConfigTrainDataColumns.AiPredictModelIndexConfigId)
+		dataPars := make([]interface{}, 0)
+		dataPars = append(dataPars, indexConfigItem.AiPredictModelIndexConfigId)
+		trainDataList, err = dataOb.GetAllListByCondition(dataCond, dataPars, []string{}, fmt.Sprintf("%s DESC", aiPredictModel.AiPredictModelIndexConfigTrainDataColumns.DataTime))
+		if err != nil {
+			err = fmt.Errorf("获取训练结果数据失败, Err: %v", err)
+			return
+		}
+	}
+
+	for _, tmpData := range trainDataList {
+		dataList = append(dataList, &aiPredictModel.AiPredictModelData{
+			//AiPredictModelDataId:  0,
+			AiPredictModelIndexId: indexItem.AiPredictModelIndexId,
+			IndexCode:             indexItem.IndexCode,
+			DataTime:              tmpData.DataTime,
+			Value:                 tmpData.Value,
+			PredictValue:          tmpData.PredictValue,
+			Direction:             tmpData.Direction,
+			DeviationRate:         tmpData.DeviationRate,
+			CreateTime:            time.Now(),
+			ModifyTime:            time.Now(),
+			DataTimestamp:         tmpData.DataTimestamp,
+			Source:                int(tmpData.Source),
+		})
+	}
+
+	// 更新指标和数据
+	err = indexOb.UpdateIndexAndData(indexItem, dataList, updateCols)
+	if err != nil {
+		return
+	}
+
+	return
+}