index.go 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490
  1. package ai_predict_model
  2. import (
  3. "encoding/json"
  4. "eta/eta_index_lib/models"
  5. aiPredictModel "eta/eta_index_lib/models/ai_predict_model"
  6. "eta/eta_index_lib/models/ai_predict_model/request"
  7. "eta/eta_index_lib/utils"
  8. "fmt"
  9. "strconv"
  10. "time"
  11. )
  12. // HandleTaskRecordFailByTaskRecord
  13. // @Description: 任务标记失败
  14. // @author: Roc
  15. // @datetime 2025-05-09 16:24:48
  16. // @param taskType string
  17. // @param indexTaskRecordInfo *models.IndexTaskRecord
  18. // @param indexConfigItem *ai_predict_model.AiPredictModelIndexConfig
  19. // @param indexItem *ai_predict_model.AiPredictModelIndex
  20. // @param errMsg string
  21. func HandleTaskRecordFailByTaskRecord(taskType string, indexTaskRecordInfo *models.IndexTaskRecord, indexConfigItem *aiPredictModel.AiPredictModelIndexConfig, indexItem *aiPredictModel.AiPredictModelIndex, errMsg string) {
  22. var err error
  23. defer func() {
  24. if err != nil {
  25. utils.FileLog.Error(fmt.Sprintf(`HandleTaskRecordFailByTaskRecord err:%v`, err))
  26. }
  27. }()
  28. // 修改子任务状态
  29. indexTaskRecordInfo.Status = `处理失败`
  30. indexTaskRecordInfo.Remark = errMsg
  31. indexTaskRecordInfo.ModifyTime = time.Now()
  32. err = indexTaskRecordInfo.Update([]string{"status", "remark", "modify_time"})
  33. if err != nil {
  34. fmt.Println("修改子任务状态失败!")
  35. return
  36. }
  37. // 处理完成后标记任务状态
  38. defer func() {
  39. obj := models.IndexTaskRecord{}
  40. // 修改任务状态
  41. todoCount, tmpErr := obj.GetCountByCondition(fmt.Sprintf(` AND %s = ? AND %s = ? `, models.IndexTaskRecordColumns.IndexTaskID, models.IndexTaskRecordColumns.Status), []interface{}{indexTaskRecordInfo.IndexTaskID, `待处理`})
  42. if tmpErr != nil {
  43. err = fmt.Errorf("查找剩余任务数量失败, err: %s", tmpErr.Error())
  44. return
  45. }
  46. if todoCount <= 0 {
  47. indexTaskObj := models.IndexTask{}
  48. indexTaskInfo, tmpErr := indexTaskObj.GetByID(indexTaskRecordInfo.IndexTaskID)
  49. if tmpErr != nil {
  50. err = fmt.Errorf("查找任务失败, err: %s", tmpErr.Error())
  51. return
  52. }
  53. tmpUpdateCols := []string{`end_time`, "status", "update_time"}
  54. indexTaskInfo.EndTime = time.Now()
  55. indexTaskInfo.Status = `处理成功`
  56. indexTaskInfo.UpdateTime = time.Now()
  57. if indexTaskInfo.StartTime.IsZero() {
  58. indexTaskInfo.StartTime = time.Now()
  59. tmpUpdateCols = append(tmpUpdateCols, "start_time")
  60. }
  61. tmpErr = indexTaskInfo.Update(tmpUpdateCols)
  62. if tmpErr != nil {
  63. utils.FileLog.Error("标记任务状态失败, err: %s", tmpErr.Error())
  64. }
  65. }
  66. return
  67. }()
  68. // 修改模型状态
  69. switch taskType {
  70. case utils.INDEX_TASK_TYPE_AI_MODEL_TRAIN: // 训练模型
  71. // 修改模型状态信息
  72. if indexItem != nil {
  73. indexItem.TrainStatus = aiPredictModel.TrainStatusFailed
  74. indexItem.ModifyTime = time.Now()
  75. tmpErr := indexItem.Update([]string{"train_status", "modify_time"})
  76. if tmpErr != nil {
  77. utils.FileLog.Error("%d,修改模型训练状态失败, err: %s", indexItem.AiPredictModelIndexId, tmpErr.Error())
  78. }
  79. }
  80. // 修改模型配置状态信息
  81. if indexConfigItem != nil {
  82. indexConfigItem.TrainStatus = aiPredictModel.TrainStatusFailed
  83. indexConfigItem.Remark = errMsg
  84. indexConfigItem.ModifyTime = time.Now()
  85. tmpErr := indexConfigItem.Update([]string{"train_status", `remark`, "modify_time"})
  86. if tmpErr != nil {
  87. utils.FileLog.Error("%d,修改模型训练状态失败, err: %s", indexItem.AiPredictModelIndexId, tmpErr.Error())
  88. }
  89. }
  90. case utils.INDEX_TASK_TYPE_AI_MODEL_RUN: // 运行模型
  91. if indexItem != nil {
  92. indexItem.RunStatus = aiPredictModel.RunStatusFailed
  93. indexItem.ModifyTime = time.Now()
  94. tmpErr := indexItem.Update([]string{"run_status", "modify_time"})
  95. if tmpErr != nil {
  96. utils.FileLog.Error("%d,修改模型运行状态失败, err: %s", indexItem.AiPredictModelIndexId, tmpErr.Error())
  97. }
  98. }
  99. default:
  100. return
  101. }
  102. return
  103. }
  104. // 有处理的任务id集合
  105. var hasHandleTaskIdMap = make(map[int]bool)
  106. // HandleTaskRecordProcessingByTaskRecord
  107. // @Description: 任务标记处理中
  108. // @author: Roc
  109. // @datetime 2025-05-09 16:24:38
  110. // @param taskType string
  111. // @param indexTaskRecordInfo *models.IndexTaskRecord
  112. // @param indexConfigItem *ai_predict_model.AiPredictModelIndexConfig
  113. // @param indexItem *ai_predict_model.AiPredictModelIndex
  114. func HandleTaskRecordProcessingByTaskRecord(taskType string, indexTaskRecordInfo *models.IndexTaskRecord, indexConfigItem *aiPredictModel.AiPredictModelIndexConfig, indexItem *aiPredictModel.AiPredictModelIndex) {
  115. var err error
  116. defer func() {
  117. if err != nil {
  118. utils.FileLog.Error(fmt.Sprintf(`HandleTaskRecordFailByTaskRecord err:%v`, err))
  119. }
  120. }()
  121. // 修改子任务状态
  122. indexTaskRecordInfo.Status = `处理中`
  123. indexTaskRecordInfo.ModifyTime = time.Now()
  124. err = indexTaskRecordInfo.Update([]string{"status", "modify_time"})
  125. if err != nil {
  126. fmt.Println("修改子任务状态失败!")
  127. return
  128. }
  129. // 处理完成后标记任务状态
  130. defer func() {
  131. // 如果没有标记处理中的任务ID,那么需要修改任务状态
  132. if _, ok := hasHandleTaskIdMap[indexTaskRecordInfo.IndexTaskID]; ok {
  133. return
  134. }
  135. indexTaskObj := models.IndexTask{}
  136. indexTaskInfo, tmpErr := indexTaskObj.GetByID(indexTaskRecordInfo.IndexTaskID)
  137. if tmpErr != nil {
  138. err = fmt.Errorf("查找任务失败, err: %s", tmpErr.Error())
  139. return
  140. }
  141. tmpUpdateCols := []string{`end_time`, "status", "update_time"}
  142. indexTaskInfo.Status = `处理中`
  143. indexTaskInfo.UpdateTime = time.Now()
  144. if indexTaskInfo.StartTime.IsZero() {
  145. indexTaskInfo.StartTime = time.Now()
  146. tmpUpdateCols = append(tmpUpdateCols, "start_time")
  147. }
  148. tmpErr = indexTaskInfo.Update(tmpUpdateCols)
  149. if tmpErr != nil {
  150. utils.FileLog.Error("标记任务状态失败, err: %s", tmpErr.Error())
  151. }
  152. return
  153. }()
  154. // 修改模型状态
  155. switch taskType {
  156. case utils.INDEX_TASK_TYPE_AI_MODEL_TRAIN: // 训练模型
  157. // 修改模型状态信息
  158. if indexItem != nil {
  159. indexItem.TrainStatus = aiPredictModel.TrainStatusTraining
  160. indexItem.ModifyTime = time.Now()
  161. tmpErr := indexItem.Update([]string{"train_status", "modify_time"})
  162. if tmpErr != nil {
  163. utils.FileLog.Error("%d,修改模型训练状态失败, err: %s", indexItem.AiPredictModelIndexId, tmpErr.Error())
  164. }
  165. }
  166. // 修改模型配置状态信息
  167. if indexConfigItem != nil {
  168. indexConfigItem.TrainStatus = aiPredictModel.TrainStatusTraining
  169. indexConfigItem.ModifyTime = time.Now()
  170. tmpErr := indexConfigItem.Update([]string{"train_status", "modify_time"})
  171. if tmpErr != nil {
  172. utils.FileLog.Error("%d,修改模型训练状态失败, err: %s", indexItem.AiPredictModelIndexId, tmpErr.Error())
  173. }
  174. }
  175. case utils.INDEX_TASK_TYPE_AI_MODEL_RUN: // 运行模型
  176. // 修改模型状态信息
  177. if indexItem != nil {
  178. indexItem.RunStatus = aiPredictModel.RunStatusRunning
  179. indexItem.ModifyTime = time.Now()
  180. tmpErr := indexItem.Update([]string{"run_status", "modify_time"})
  181. if tmpErr != nil {
  182. utils.FileLog.Error("%d,修改模型运行状态失败, err: %s", indexItem.AiPredictModelIndexId, tmpErr.Error())
  183. }
  184. }
  185. default:
  186. return
  187. }
  188. return
  189. }
  190. // HandleTaskRecordSuccessByTaskRecord
  191. // @Description: 标记处理完成
  192. // @author: Roc
  193. // @datetime 2025-05-14 16:00:26
  194. // @param taskType string
  195. // @param indexTaskRecordInfo *models.IndexTaskRecord
  196. // @param aiPredictModelImportData request.AiPredictModelImportData
  197. func HandleTaskRecordSuccessByTaskRecord(taskType string, indexTaskRecordInfo *models.IndexTaskRecord, aiPredictModelImportData request.AiPredictModelImportData) {
  198. var err error
  199. defer func() {
  200. if err != nil {
  201. utils.FileLog.Error(fmt.Sprintf(`HandleTaskRecordFailByTaskRecord err:%v`, err))
  202. }
  203. }()
  204. // 修改子任务状态
  205. indexTaskRecordInfo.Status = `处理成功`
  206. //indexTaskRecordInfo.Remark = errMsg
  207. indexTaskRecordInfo.ModifyTime = time.Now()
  208. err = indexTaskRecordInfo.Update([]string{"status", "modify_time"})
  209. if err != nil {
  210. fmt.Println("修改子任务状态失败!")
  211. return
  212. }
  213. // 处理完成后标记任务状态
  214. defer func() {
  215. obj := models.IndexTaskRecord{}
  216. // 修改任务状态
  217. todoCount, tmpErr := obj.GetCountByCondition(fmt.Sprintf(` AND %s = ? AND %s = ? `, models.IndexTaskRecordColumns.IndexTaskID, models.IndexTaskRecordColumns.Status), []interface{}{indexTaskRecordInfo.IndexTaskID, `待处理`})
  218. if tmpErr != nil {
  219. err = fmt.Errorf("查找剩余任务数量失败, err: %s", tmpErr.Error())
  220. return
  221. }
  222. if todoCount <= 0 {
  223. indexTaskObj := models.IndexTask{}
  224. indexTaskInfo, tmpErr := indexTaskObj.GetByID(indexTaskRecordInfo.IndexTaskID)
  225. if tmpErr != nil {
  226. err = fmt.Errorf("查找任务失败, err: %s", tmpErr.Error())
  227. return
  228. }
  229. tmpUpdateCols := []string{`end_time`, "status", "update_time"}
  230. indexTaskInfo.EndTime = time.Now()
  231. indexTaskInfo.Status = `处理成功`
  232. indexTaskInfo.UpdateTime = time.Now()
  233. if indexTaskInfo.StartTime.IsZero() {
  234. indexTaskInfo.StartTime = time.Now()
  235. tmpUpdateCols = append(tmpUpdateCols, "start_time")
  236. }
  237. tmpErr = indexTaskInfo.Update(tmpUpdateCols)
  238. if tmpErr != nil {
  239. utils.FileLog.Error("标记任务状态失败, err: %s", tmpErr.Error())
  240. }
  241. }
  242. return
  243. }()
  244. indexOb := new(aiPredictModel.AiPredictModelIndex)
  245. // 修改模型状态
  246. switch taskType {
  247. case utils.INDEX_TASK_TYPE_AI_MODEL_TRAIN: // 训练模型
  248. // 训练模型
  249. indexConfigId, tmpErr := strconv.Atoi(indexTaskRecordInfo.Parameters) // 模型配置ID
  250. if tmpErr != nil {
  251. err = fmt.Errorf("模型配置ID转换错误, err: %s", tmpErr.Error())
  252. return
  253. }
  254. indexConfigObj := new(aiPredictModel.AiPredictModelIndexConfig)
  255. // 查找配置
  256. indexConfigItem, tmpErr := indexConfigObj.GetById(indexConfigId)
  257. if tmpErr != nil {
  258. err = fmt.Errorf("获取模型配置失败, err: %s", tmpErr.Error())
  259. return
  260. }
  261. // 查询标的情况
  262. indexItem, tmpErr := indexOb.GetItemById(indexConfigItem.AiPredictModelIndexId)
  263. if err != nil {
  264. err = fmt.Errorf("获取标的失败, err: %s", tmpErr.Error())
  265. return
  266. }
  267. handleTaskRecordSuccessByTrain(aiPredictModelImportData, indexConfigItem, indexItem)
  268. case utils.INDEX_TASK_TYPE_AI_MODEL_RUN: // 运行模型
  269. // 标的id转换
  270. indexId, tmpErr := strconv.Atoi(indexTaskRecordInfo.Parameters)
  271. if err != nil {
  272. err = fmt.Errorf("标的ID转换错误, err: %s", tmpErr.Error())
  273. return
  274. }
  275. // 查询标的情况
  276. indexItem, tmpErr := indexOb.GetItemById(indexId)
  277. if tmpErr != nil {
  278. err = fmt.Errorf("训练失败,查找标的失败, err: %s", tmpErr.Error())
  279. return
  280. }
  281. tmpErr = handleTaskRecordSuccessByRun(aiPredictModelImportData, indexItem)
  282. if tmpErr != nil {
  283. utils.FileLog.Error("%d,修改模型运行状态失败, err: %s", indexItem.AiPredictModelIndexId, tmpErr.Error())
  284. }
  285. default:
  286. return
  287. }
  288. return
  289. }
  290. // handleTaskRecordSuccessByTrain
  291. // @Description: 处理模型训练成功后的操作
  292. // @author: Roc
  293. // @datetime 2025-05-14 18:25:12
  294. // @param aiPredictModelImportData request.AiPredictModelImportData
  295. // @param indexConfigItem *aiPredictModel.AiPredictModelIndexConfig
  296. // @param indexItem *aiPredictModel.AiPredictModelIndex
  297. // @return err error
  298. func handleTaskRecordSuccessByTrain(aiPredictModelImportData request.AiPredictModelImportData, indexConfigItem *aiPredictModel.AiPredictModelIndexConfig, indexItem *aiPredictModel.AiPredictModelIndex) (err error) {
  299. defer func() {
  300. if err != nil {
  301. utils.FileLog.Error(fmt.Sprintf(`handleTaskRecordSuccessByTrain err:%v`, err))
  302. }
  303. }()
  304. // 标的状态修改
  305. updateIndexCols := []string{"train_status", "modify_time"}
  306. indexItem.TrainStatus = aiPredictModel.TrainStatusSuccess
  307. indexItem.ModifyTime = time.Now()
  308. updateIndexConfigCols := []string{"train_status", `remark`, "modify_time", `train_mse`, `train_r2`, `test_mse`, `test_r2`}
  309. // 配置状态修改
  310. {
  311. // 训练参数
  312. trainData := aiPredictModelImportData.TrainData
  313. indexConfigItem.TrainStatus = aiPredictModel.TrainStatusSuccess
  314. indexConfigItem.Remark = `成功`
  315. indexConfigItem.TrainMse = fmt.Sprint(trainData.TrainMse)
  316. indexConfigItem.TrainR2 = fmt.Sprint(trainData.TrainR2)
  317. indexConfigItem.TestMse = fmt.Sprint(trainData.TestMse)
  318. indexConfigItem.TestR2 = fmt.Sprint(trainData.TestR2)
  319. indexConfigItem.ModifyTime = time.Now()
  320. }
  321. indexConfigOb := new(aiPredictModel.AiPredictModelIndexConfig)
  322. dataList := make([]*aiPredictModel.AiPredictModelIndexConfigTrainData, 0)
  323. for _, tmpData := range aiPredictModelImportData.Data {
  324. tmpDate, e := time.ParseInLocation(utils.FormatDate, tmpData.DataTime, time.Local)
  325. if e != nil {
  326. err = fmt.Errorf("数据日期解析失败, %v", e)
  327. return
  328. }
  329. timestamp := tmpDate.UnixNano() / 1e6
  330. dataList = append(dataList, &aiPredictModel.AiPredictModelIndexConfigTrainData{
  331. //AiPredictModelDataId: 0,
  332. AiPredictModelIndexConfigId: indexConfigItem.AiPredictModelIndexConfigId,
  333. AiPredictModelIndexId: indexItem.AiPredictModelIndexId,
  334. IndexCode: indexItem.IndexCode,
  335. DataTime: tmpDate,
  336. Value: tmpData.Value,
  337. PredictValue: tmpData.PredictValue,
  338. Direction: tmpData.Direction,
  339. DeviationRate: tmpData.DeviationRate,
  340. CreateTime: time.Now(),
  341. ModifyTime: time.Now(),
  342. DataTimestamp: timestamp,
  343. Source: tmpData.Source,
  344. })
  345. }
  346. // 更新指标和数据
  347. err = indexConfigOb.UpdateIndexAndData(indexItem, indexConfigItem, dataList, updateIndexCols, updateIndexConfigCols)
  348. if err != nil {
  349. return
  350. }
  351. return
  352. }
  353. // handleTaskRecordSuccessByRun
  354. // @Description: 运行中的数据处理
  355. // @author: Roc
  356. // @datetime 2025-05-14 14:28:11
  357. // @param aiPredictModelImportData request.AiPredictModelImportData
  358. // @param indexItem *aiPredictModel.AiPredictModelIndex
  359. // @return err error
  360. func handleTaskRecordSuccessByRun(aiPredictModelImportData request.AiPredictModelImportData, indexItem *aiPredictModel.AiPredictModelIndex) (err error) {
  361. defer func() {
  362. defer func() {
  363. if err != nil {
  364. utils.FileLog.Error(fmt.Sprintf(`handleTaskRecordSuccessByRun err:%v`, err))
  365. }
  366. }()
  367. }()
  368. indexOb := new(aiPredictModel.AiPredictModelIndex)
  369. updateCols := []string{indexOb.Cols().RunStatus, indexOb.Cols().PredictValue, indexOb.Cols().DirectionAccuracy, indexOb.Cols().AbsoluteDeviation, indexOb.Cols().ExtraConfig, indexOb.Cols().ModifyTime}
  370. // 预测日期,理论上是需要改的,可是产品说不需要改,所以暂时不改
  371. updateCols = append(updateCols, indexOb.Cols().PredictDate)
  372. indexItem.RunStatus = aiPredictModel.RunStatusSuccess
  373. indexItem.PredictValue = aiPredictModelImportData.Index.PredictValue
  374. indexItem.DirectionAccuracy = aiPredictModelImportData.Index.DirectionAccuracy
  375. indexItem.AbsoluteDeviation = aiPredictModelImportData.Index.AbsoluteDeviation
  376. indexItem.ModifyTime = time.Now()
  377. predictDate, e := time.ParseInLocation(utils.FormatDate, aiPredictModelImportData.Index.PredictDate, time.Local)
  378. if e != nil {
  379. err = fmt.Errorf("预测日期解析失败, %v", e)
  380. return
  381. }
  382. indexItem.PredictDate = predictDate
  383. // 图例信息
  384. if indexItem.ExtraConfig != "" && aiPredictModelImportData.Index.ExtraConfig != "" {
  385. var oldConfig, newConfig aiPredictModel.AiPredictModelIndexExtraConfig
  386. if e := json.Unmarshal([]byte(indexItem.ExtraConfig), &oldConfig); e != nil {
  387. err = fmt.Errorf("标的原配置解析失败, Config: %s, Err: %v", indexItem.ExtraConfig, e)
  388. return
  389. }
  390. if e := json.Unmarshal([]byte(aiPredictModelImportData.Index.ExtraConfig), &newConfig); e != nil {
  391. err = fmt.Errorf("标的新配置解析失败, Config: %s, Err: %v", aiPredictModelImportData.Index.ExtraConfig, e)
  392. return
  393. }
  394. oldConfig.DailyChart.PredictLegendName = newConfig.DailyChart.PredictLegendName
  395. b, _ := json.Marshal(oldConfig)
  396. indexItem.ExtraConfig = string(b)
  397. }
  398. dataList := make([]*aiPredictModel.AiPredictModelData, 0)
  399. for _, tmpData := range aiPredictModelImportData.Data {
  400. tmpDate, e := time.ParseInLocation(utils.FormatDate, tmpData.DataTime, time.Local)
  401. if e != nil {
  402. err = fmt.Errorf("数据日期解析失败, %v", e)
  403. return
  404. }
  405. timestamp := tmpDate.UnixNano() / 1e6
  406. dataList = append(dataList, &aiPredictModel.AiPredictModelData{
  407. //AiPredictModelDataId: 0,
  408. AiPredictModelIndexId: indexItem.AiPredictModelIndexId,
  409. IndexCode: indexItem.IndexCode,
  410. DataTime: tmpDate,
  411. Value: tmpData.Value,
  412. PredictValue: tmpData.PredictValue,
  413. Direction: tmpData.Direction,
  414. DeviationRate: tmpData.DeviationRate,
  415. CreateTime: time.Now(),
  416. ModifyTime: time.Now(),
  417. DataTimestamp: timestamp,
  418. Source: tmpData.Source,
  419. })
  420. }
  421. // 更新指标和数据
  422. err = indexOb.UpdateIndexAndData(indexItem, dataList, updateCols)
  423. if err != nil {
  424. return
  425. }
  426. return
  427. }