ai_predict_model_index.go 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512
  1. package data_manage
  2. import (
  3. "eta/eta_api/global"
  4. "eta/eta_api/models/data_manage"
  5. "eta/eta_api/utils"
  6. "fmt"
  7. "github.com/rdlucklib/rdluck_tools/paging"
  8. "strconv"
  9. "strings"
  10. "time"
  11. )
  12. // 训练状态
  13. const (
  14. RunStatusWaiting = "待运行"
  15. RunStatusRunning = "运行中"
  16. RunStatusSuccess = "运行成功"
  17. RunStatusFailed = "运行失败"
  18. )
// AiPredictModelIndex is an AI prediction model target ("标的"), one row of
// the ai_predict_model_index table (see TableName).
type AiPredictModelIndex struct {
	// Primary key. The orm tag looks like a leftover from a previous ORM; GORM reads the gorm tag.
	AiPredictModelIndexId int               `orm:"column(ai_predict_model_index_id);pk" gorm:"primaryKey"`
	IndexName             string            `description:"标的名称"`
	IndexCode             string            `description:"自生成的指标编码"`
	ClassifyId            int               `description:"分类ID"`
	ModelFramework        string            `description:"模型框架"`
	PredictDate           time.Time         `description:"预测日期"`
	PredictValue          float64           `description:"预测值"`
	PredictFrequency      string            `description:"预测频度"`
	DirectionAccuracy     string            `description:"方向准确度"`
	AbsoluteDeviation     string            `description:"绝对偏差"`
	// Model parameters; presumably a serialized AiPredictModelIndexExtraConfig — confirm with callers.
	ExtraConfig     string    `description:"模型参数"`
	Sort            int       `description:"排序"`
	SysUserId       int       `description:"创建人ID"`
	SysUserRealName string    `description:"创建人姓名"`
	LeftMin         string    `description:"图表左侧最小值"`
	LeftMax         string    `description:"图表左侧最大值"`
	CreateTime      time.Time `description:"创建时间"`
	ModifyTime      time.Time `description:"修改时间"`
	// ID of the configuration currently attached to this target (see GetItemByConfigId).
	AiPredictModelIndexConfigId int    `gorm:"column:ai_predict_model_index_config_id" description:"标的当前的配置id"`
	ScriptPath                  string `gorm:"column:script_path" description:"脚本的路径"`
	// Training status: waiting / training / succeeded / failed (values listed in the tag).
	TrainStatus string `gorm:"column:train_status" description:"训练状态,枚举值:待训练,训练中,训练成功,训练失败"`
	// Run status; the values match the RunStatus* constants in this file.
	RunStatus string `gorm:"column:run_status" description:"运行状态,枚举值:待运行,运行中,运行成功,运行失败"`
}
  44. func (m *AiPredictModelIndex) TableName() string {
  45. return "ai_predict_model_index"
  46. }
  47. type AiPredictModelIndexCols struct {
  48. PrimaryId string
  49. IndexName string
  50. IndexCode string
  51. ClassifyId string
  52. ModelFramework string
  53. PredictDate string
  54. PredictValue string
  55. DirectionAccuracy string
  56. AbsoluteDeviation string
  57. ExtraConfig string
  58. Sort string
  59. SysUserId string
  60. SysUserRealName string
  61. LeftMin string
  62. LeftMax string
  63. CreateTime string
  64. ModifyTime string
  65. AiPredictModelIndexConfigId string
  66. ScriptPath string
  67. TrainStatus string
  68. RunStatus string
  69. }
  70. func (m *AiPredictModelIndex) Cols() AiPredictModelIndexCols {
  71. return AiPredictModelIndexCols{
  72. PrimaryId: "ai_predict_model_index_id",
  73. IndexName: "index_name",
  74. IndexCode: "index_code",
  75. ClassifyId: "classify_id",
  76. ModelFramework: "model_framework",
  77. PredictDate: "predict_date",
  78. PredictValue: "predict_value",
  79. DirectionAccuracy: "direction_accuracy",
  80. AbsoluteDeviation: "absolute_deviation",
  81. ExtraConfig: "extra_config",
  82. Sort: "sort",
  83. SysUserId: "sys_user_id",
  84. SysUserRealName: "sys_user_real_name",
  85. LeftMin: "left_min",
  86. LeftMax: "left_max",
  87. CreateTime: "create_time",
  88. ModifyTime: "modify_time",
  89. AiPredictModelIndexConfigId: "ai_predict_model_index_config_id",
  90. ScriptPath: "script_path",
  91. TrainStatus: "train_status",
  92. RunStatus: "run_status",
  93. }
  94. }
  95. func (m *AiPredictModelIndex) Create() (err error) {
  96. o := global.DbMap[utils.DbNameIndex]
  97. err = o.Create(m).Error
  98. return
  99. }
  100. func (m *AiPredictModelIndex) CreateMulti(items []*AiPredictModelIndex) (err error) {
  101. if len(items) == 0 {
  102. return
  103. }
  104. o := global.DbMap[utils.DbNameIndex]
  105. err = o.CreateInBatches(items, utils.MultiAddNum).Error
  106. return
  107. }
  108. func (m *AiPredictModelIndex) Update(cols []string) (err error) {
  109. o := global.DbMap[utils.DbNameIndex]
  110. err = o.Select(cols).Updates(m).Error
  111. return
  112. }
  113. func (m *AiPredictModelIndex) Remove() (err error) {
  114. o := global.DbMap[utils.DbNameIndex]
  115. sql := fmt.Sprintf(`DELETE FROM %s WHERE %s = ? LIMIT 1`, m.TableName(), m.Cols().PrimaryId)
  116. err = o.Exec(sql, m.AiPredictModelIndexId).Error
  117. return
  118. }
  119. func (m *AiPredictModelIndex) MultiRemove(ids []int) (err error) {
  120. if len(ids) == 0 {
  121. return
  122. }
  123. o := global.DbMap[utils.DbNameIndex]
  124. sql := fmt.Sprintf(`DELETE FROM %s WHERE %s IN (%s)`, m.TableName(), m.Cols().PrimaryId, utils.GetOrmInReplace(len(ids)))
  125. err = o.Exec(sql, ids).Error
  126. return
  127. }
  128. func (m *AiPredictModelIndex) RemoveByCondition(condition string, pars []interface{}) (err error) {
  129. if condition == "" {
  130. return
  131. }
  132. o := global.DbMap[utils.DbNameIndex]
  133. sql := fmt.Sprintf(`DELETE FROM %s WHERE %s`, m.TableName(), condition)
  134. err = o.Exec(sql, pars...).Error
  135. return
  136. }
  137. func (m *AiPredictModelIndex) GetItemById(id int) (item *AiPredictModelIndex, err error) {
  138. o := global.DbMap[utils.DbNameIndex]
  139. sql := fmt.Sprintf(`SELECT * FROM %s WHERE %s = ? LIMIT 1`, m.TableName(), m.Cols().PrimaryId)
  140. err = o.Raw(sql, id).First(&item).Error
  141. return
  142. }
  143. // GetItemByConfigId
  144. // @Description: 根据配置id获取标的信息
  145. // @author: Roc
  146. // @receiver m
  147. // @datetime 2025-05-06 13:31:24
  148. // @param configId int
  149. // @return item *AiPredictModelIndex
  150. // @return err error
  151. func (m *AiPredictModelIndex) GetItemByConfigId(configId int) (item *AiPredictModelIndex, err error) {
  152. o := global.DbMap[utils.DbNameIndex]
  153. sql := fmt.Sprintf(`SELECT * FROM %s WHERE %s = ? LIMIT 1`, m.TableName(), m.Cols().AiPredictModelIndexConfigId)
  154. err = o.Raw(sql, configId).First(&item).Error
  155. return
  156. }
  157. func (m *AiPredictModelIndex) GetItemByCondition(condition string, pars []interface{}, orderRule string) (item *AiPredictModelIndex, err error) {
  158. o := global.DbMap[utils.DbNameIndex]
  159. order := ``
  160. if orderRule != "" {
  161. order = ` ORDER BY ` + orderRule
  162. }
  163. sql := fmt.Sprintf(`SELECT * FROM %s WHERE 1=1 %s %s LIMIT 1`, m.TableName(), condition, order)
  164. err = o.Raw(sql, pars...).First(&item).Error
  165. return
  166. }
  167. func (m *AiPredictModelIndex) GetCountByCondition(condition string, pars []interface{}) (count int, err error) {
  168. o := global.DbMap[utils.DbNameIndex]
  169. sql := fmt.Sprintf(`SELECT COUNT(1) FROM %s WHERE 1=1 %s`, m.TableName(), condition)
  170. err = o.Raw(sql, pars...).Scan(&count).Error
  171. return
  172. }
  173. func (m *AiPredictModelIndex) GetItemsByCondition(condition string, pars []interface{}, fieldArr []string, orderRule string) (items []*AiPredictModelIndex, err error) {
  174. o := global.DbMap[utils.DbNameIndex]
  175. fields := strings.Join(fieldArr, ",")
  176. if len(fieldArr) == 0 {
  177. fields = `*`
  178. }
  179. order := fmt.Sprintf(`ORDER BY %s DESC`, m.Cols().CreateTime)
  180. if orderRule != "" {
  181. order = ` ORDER BY ` + orderRule
  182. }
  183. sql := fmt.Sprintf(`SELECT %s FROM %s WHERE 1=1 %s %s`, fields, m.TableName(), condition, order)
  184. err = o.Raw(sql, pars...).Find(&items).Error
  185. return
  186. }
  187. func (m *AiPredictModelIndex) GetPageItemsByCondition(condition string, pars []interface{}, fieldArr []string, orderRule string, startSize, pageSize int) (items []*AiPredictModelIndex, err error) {
  188. o := global.DbMap[utils.DbNameIndex]
  189. fields := strings.Join(fieldArr, ",")
  190. if len(fieldArr) == 0 {
  191. fields = `*`
  192. }
  193. order := fmt.Sprintf(`ORDER BY %s DESC`, m.Cols().CreateTime)
  194. if orderRule != "" {
  195. order = ` ORDER BY ` + orderRule
  196. }
  197. sql := fmt.Sprintf(`SELECT %s FROM %s WHERE 1=1 %s %s LIMIT ?,?`, fields, m.TableName(), condition, order)
  198. pars = append(pars, startSize, pageSize)
  199. err = o.Raw(sql, pars...).Find(&items).Error
  200. return
  201. }
// AiPredictModelIndexItem is the presentation form of an AI prediction model
// target. Format2Item fills it from AiPredictModelIndex, with dates rendered
// as strings.
type AiPredictModelIndexItem struct {
	// Same value as AiPredictModelIndex.AiPredictModelIndexId.
	IndexId           int     `description:"标的ID"`
	IndexName         string  `description:"标的名称"`
	IndexCode         string  `description:"自生成的指标编码"`
	ClassifyId        int     `description:"分类ID"`
	ClassifyName      string  `description:"分类名称"`
	ModelFramework    string  `description:"模型框架"`
	PredictDate       string  `description:"预测日期"`
	PredictValue      float64 `description:"预测值"`
	PredictFrequency  string  `description:"预测频度"`
	DirectionAccuracy string  `description:"方向准确度"`
	AbsoluteDeviation string  `description:"绝对偏差"`
	ExtraConfig       string  `description:"模型参数"`
	SysUserId         int     `description:"创建人ID"`
	SysUserRealName   string  `description:"创建人姓名"`
	CreateTime        string  `description:"创建时间"`
	ModifyTime        string  `description:"修改时间"`
	// Search result text, possibly with highlight markup; not set by Format2Item.
	SearchText                  string `description:"搜索结果(含高亮)"`
	AiPredictModelIndexConfigId int    `gorm:"column:ai_predict_model_index_config_id" description:"标的当前的配置id"`
	ScriptPath                  string `gorm:"column:script_path" description:"脚本的路径"`
	TrainStatus                 string `gorm:"column:train_status" description:"训练状态,枚举值:待训练,训练中,训练成功,训练失败"`
	RunStatus                   string `gorm:"column:run_status" description:"运行状态,枚举值:待运行,运行中,运行成功,运行失败"`
}
  226. func (m *AiPredictModelIndex) Format2Item() (item *AiPredictModelIndexItem) {
  227. item = new(AiPredictModelIndexItem)
  228. item.IndexId = m.AiPredictModelIndexId
  229. item.IndexName = m.IndexName
  230. item.IndexCode = m.IndexCode
  231. item.ClassifyId = m.ClassifyId
  232. item.ModelFramework = m.ModelFramework
  233. item.PredictDate = utils.TimeTransferString(utils.FormatDate, m.PredictDate)
  234. item.PredictValue = m.PredictValue
  235. item.PredictFrequency = m.PredictFrequency
  236. item.DirectionAccuracy = m.DirectionAccuracy
  237. item.AbsoluteDeviation = m.AbsoluteDeviation
  238. item.ExtraConfig = m.ExtraConfig
  239. item.SysUserId = m.SysUserId
  240. item.SysUserRealName = m.SysUserRealName
  241. item.AiPredictModelIndexConfigId = m.AiPredictModelIndexConfigId
  242. item.ScriptPath = m.ScriptPath
  243. item.TrainStatus = m.TrainStatus
  244. item.RunStatus = m.RunStatus
  245. item.CreateTime = utils.TimeTransferString(utils.FormatDateTime, m.CreateTime)
  246. item.ModifyTime = utils.TimeTransferString(utils.FormatDateTime, m.ModifyTime)
  247. return
  248. }
// AiPredictModelIndexPageListResp is a paged list of AI prediction model targets.
type AiPredictModelIndexPageListResp struct {
	// Paging metadata for the current page.
	Paging *paging.PagingItem
	// Targets on this page.
	List []*AiPredictModelIndexItem `description:"列表"`
}
  253. // RemoveIndexAndData 删除标的及数据
  254. func (m *AiPredictModelIndex) RemoveIndexAndData(indexId int, chartIds []int) (err error) {
  255. o := global.DbMap[utils.DbNameIndex]
  256. tx := o.Begin()
  257. defer func() {
  258. if err != nil {
  259. _ = tx.Rollback()
  260. return
  261. }
  262. _ = tx.Commit()
  263. }()
  264. sql := `DELETE FROM ai_predict_model_index WHERE ai_predict_model_index_id = ? LIMIT 1`
  265. e := tx.Exec(sql, indexId).Error
  266. if e != nil {
  267. err = fmt.Errorf("remove index err: %v", e)
  268. return
  269. }
  270. sql = ` DELETE FROM ai_predict_model_data WHERE ai_predict_model_index_id = ?`
  271. e = tx.Exec(sql, indexId).Error
  272. if e != nil {
  273. err = fmt.Errorf("remove index data err: %v", e)
  274. return
  275. }
  276. // 删除图表
  277. if len(chartIds) == 0 {
  278. return
  279. }
  280. sql = ` DELETE FROM chart_info WHERE chart_info_id IN ?`
  281. if e = tx.Exec(sql, chartIds).Error; e != nil {
  282. err = fmt.Errorf("remove charts err: %v", e)
  283. return
  284. }
  285. sql = ` DELETE FROM chart_edb_mapping WHERE chart_info_id IN ?`
  286. if e = tx.Exec(sql, chartIds).Error; e != nil {
  287. err = fmt.Errorf("remove chart mappings err: %v", e)
  288. return
  289. }
  290. return
  291. }
  292. // UpdateAiPredictModelIndexSortByClassifyId 根据分类id更新排序
  293. func UpdateAiPredictModelIndexSortByClassifyId(classifyId, nowSort int, prevEdbInfoId int, updateSort string) (err error) {
  294. o := global.DbMap[utils.DbNameIndex]
  295. sql := ` UPDATE ai_predict_model_index SET sort = ` + updateSort + ` WHERE classify_id = ?`
  296. if prevEdbInfoId > 0 {
  297. sql += ` AND ( sort > ? or ( ai_predict_model_index_id > ` + fmt.Sprint(prevEdbInfoId) + ` and sort=` + fmt.Sprint(nowSort) + ` )) `
  298. } else {
  299. sql += ` AND ( sort > ? )`
  300. }
  301. err = o.Exec(sql, classifyId, nowSort).Error
  302. return
  303. }
  304. // GetFirstAiPredictModelIndexByClassifyId 获取当前分类下,且排序数相同 的排序第一条的数据
  305. func GetFirstAiPredictModelIndexByClassifyId(classifyId int) (item *AiPredictModelIndex, err error) {
  306. o := global.DbMap[utils.DbNameIndex]
  307. sql := ` SELECT * FROM ai_predict_model_index WHERE classify_id = ? order by sort asc,ai_predict_model_index_id asc limit 1`
  308. err = o.Raw(sql, classifyId).First(&item).Error
  309. return
  310. }
// AiPredictModelImportData is one target to import, consumed by ImportIndexAndData.
type AiPredictModelImportData struct {
	// Target row to create or update.
	Index *AiPredictModelIndex
	// Data rows; ImportIndexAndData sets their index id, code and timestamp.
	Data []*AiPredictModelData
	// Charts to create; only used for newly created targets.
	Charts []*AiPredictModelImportCharts
}

// AiPredictModelImportCharts is a chart to create together with its edb mappings.
type AiPredictModelImportCharts struct {
	ChartInfo *data_manage.ChartInfo
	// Mappings whose chart id, edb id and unique code are filled during import.
	EdbMappings []*data_manage.ChartEdbMapping
}
  320. // ImportIndexAndData 导入数据
  321. func (m *AiPredictModelIndex) ImportIndexAndData(createIndexes, updateIndexes []*AiPredictModelImportData, updateCols []string) (chartIds []int, err error) {
  322. if len(createIndexes) == 0 && len(updateIndexes) == 0 {
  323. return
  324. }
  325. o := global.DbMap[utils.DbNameIndex]
  326. tx := o.Begin()
  327. defer func() {
  328. if err != nil {
  329. _ = tx.Rollback()
  330. return
  331. }
  332. _ = tx.Commit()
  333. }()
  334. if len(updateIndexes) > 0 {
  335. for _, v := range updateIndexes {
  336. // 更新指标
  337. e := tx.Select(updateCols).Updates(v.Index).Error
  338. if e != nil {
  339. err = fmt.Errorf("update index err: %v", e)
  340. return
  341. }
  342. var hasDaily, hasMonthly bool
  343. for _, d := range v.Data {
  344. d.AiPredictModelIndexId = v.Index.AiPredictModelIndexId
  345. d.IndexCode = v.Index.IndexCode
  346. d.DataTimestamp = d.DataTime.UnixNano() / 1e6
  347. if d.Source == ModelDataSourceDaily {
  348. hasDaily = true
  349. }
  350. if d.Source == ModelDataSourceMonthly {
  351. hasMonthly = true
  352. }
  353. }
  354. // 哪个有数据就先清空然后重新写入,没数据就保留旧数据, 都没就忽略
  355. if !hasDaily && !hasMonthly {
  356. continue
  357. }
  358. removeCond := ``
  359. removePars := make([]interface{}, 0)
  360. removePars = append(removePars, v.Index.AiPredictModelIndexId)
  361. if hasDaily && !hasMonthly {
  362. removeCond += ` AND source = ?`
  363. removePars = append(removePars, ModelDataSourceDaily)
  364. }
  365. if !hasDaily && hasMonthly {
  366. removeCond += ` AND source = ?`
  367. removePars = append(removePars, ModelDataSourceMonthly)
  368. }
  369. // 清空指标并新增
  370. sql := fmt.Sprintf(`DELETE FROM ai_predict_model_data WHERE ai_predict_model_index_id = ? %s`, removeCond)
  371. e = tx.Exec(sql, removePars...).Error
  372. if e != nil {
  373. err = fmt.Errorf("clear index data err: %v", e)
  374. return
  375. }
  376. e = tx.CreateInBatches(v.Data, utils.MultiAddNum).Error
  377. if e != nil {
  378. err = fmt.Errorf("insert index data err: %v", e)
  379. return
  380. }
  381. }
  382. }
  383. if len(createIndexes) > 0 {
  384. for _, v := range createIndexes {
  385. if e := tx.Create(v.Index).Error; e != nil {
  386. err = fmt.Errorf("insert index err: %v", e)
  387. return
  388. }
  389. indexId := v.Index.AiPredictModelIndexId
  390. for _, d := range v.Data {
  391. d.AiPredictModelIndexId = indexId
  392. d.IndexCode = v.Index.IndexCode
  393. d.DataTimestamp = d.DataTime.UnixNano() / 1e6
  394. }
  395. if e := tx.CreateInBatches(v.Data, utils.MultiAddNum).Error; e != nil {
  396. err = fmt.Errorf("insert index data err: %v", e)
  397. return
  398. }
  399. // 图表
  400. if len(v.Charts) == 0 {
  401. continue
  402. }
  403. for _, ct := range v.Charts {
  404. if e := tx.Create(ct.ChartInfo).Error; e != nil {
  405. err = fmt.Errorf("insert chart err: %v", e)
  406. return
  407. }
  408. for _, cm := range ct.EdbMappings {
  409. cm.ChartInfoId = ct.ChartInfo.ChartInfoId
  410. cm.EdbInfoId = indexId
  411. time.Sleep(time.Microsecond)
  412. cm.UniqueCode = utils.MD5(fmt.Sprint(utils.CHART_PREFIX, "_", indexId, "_", strconv.FormatInt(time.Now().UnixNano(), 10)))
  413. }
  414. if e := tx.CreateInBatches(ct.EdbMappings, utils.MultiAddNum).Error; e != nil {
  415. err = fmt.Errorf("insert chart mapping err: %v", e)
  416. return
  417. }
  418. chartIds = append(chartIds, ct.ChartInfo.ChartInfoId)
  419. }
  420. }
  421. }
  422. return
  423. }
// AiPredictModelDetailResp is the detail view of a target: its table data plus
// the monthly and daily prediction charts.
type AiPredictModelDetailResp struct {
	TableData []*AiPredictModelDataItem `description:"表格数据"`
	// Monthly prediction chart.
	ChartView *data_manage.ChartInfoDetailResp `description:"月度预测数据图表"`
	// Daily prediction chart.
	DailyChartView *data_manage.ChartInfoDetailResp `description:"日度预测数据图表"`
}

// AiPredictModelIndexSaveReq is a request to save the chart settings of a target.
type AiPredictModelIndexSaveReq struct {
	IndexId int `description:"指标ID"`
	// Monthly chart settings.
	MonthlyChart *AiPredictModelIndexSaveChart `description:"月度图表信息"`
	// Daily chart settings.
	DailyChart *AiPredictModelIndexSaveChart `description:"日度图表信息"`
}

// AiPredictModelIndexSaveChart holds the settings of one chart:
// left-axis minimum and maximum, and unit.
type AiPredictModelIndexSaveChart struct {
	LeftMin string `description:"图表左侧最小值"`
	LeftMax string `description:"图表左侧最大值"`
	Unit    string `description:"单位"`
}
// AiPredictModelIndexExtraConfig holds per-target chart settings. It is
// presumably what gets serialized into AiPredictModelIndex.ExtraConfig —
// confirm with callers.
type AiPredictModelIndexExtraConfig struct {
	MonthlyChart MonthlyChartConfig
	DailyChart   DailyChartConfig
}

// MonthlyChartConfig holds the monthly chart's left-axis range and unit.
type MonthlyChartConfig struct {
	LeftMin string `description:"图表左侧最小值"`
	LeftMax string `description:"图表左侧最大值"`
	Unit    string `description:"单位"`
}

// DailyChartConfig holds the daily chart's left-axis range and unit, plus the
// legend name of the predicted series (usually "Predicted", per the tag).
type DailyChartConfig struct {
	LeftMin           string `description:"图表左侧最小值"`
	LeftMax           string `description:"图表左侧最大值"`
	Unit              string `description:"单位"`
	PredictLegendName string `description:"预测图例的名称(通常为Predicted)"`
}
  454. func (m *AiPredictModelIndex) GetSortMax() (sort int, err error) {
  455. o := global.DbMap[utils.DbNameIndex]
  456. sql := `SELECT COALESCE(MAX(sort), 0) AS sort FROM ai_predict_model_index`
  457. err = o.Raw(sql).Scan(&sort).Error
  458. if err != nil {
  459. return
  460. }
  461. // 查询分类的最大排序
  462. sql = `SELECT COALESCE(MAX(sort), 0) AS sort FROM ai_predict_model_classify`
  463. var classifySort int
  464. err = o.Raw(sql).Scan(&classifySort).Error
  465. if err != nil {
  466. return
  467. }
  468. if classifySort > sort {
  469. sort = classifySort
  470. }
  471. return
  472. }