Roc 1 tydzień temu
rodzic
commit
f77ace06e4

+ 86 - 1
controllers/data_manage/ai_predict_model/index.go

@@ -6,6 +6,7 @@ import (
 	"eta/eta_api/models"
 	aiPredictModel "eta/eta_api/models/ai_predict_model"
 	"eta/eta_api/models/ai_predict_model/request"
+	"eta/eta_api/models/ai_predict_model/response"
 	"eta/eta_api/models/data_manage"
 	dataSourceModel "eta/eta_api/models/data_source"
 	"eta/eta_api/models/system"
@@ -110,7 +111,7 @@ func (this *AiPredictModelIndexController) List() {
 			_, list, e := elastic.SearchDataSourceIndex(utils.EsDataSourceIndexName, keyword, utils.DATA_SOURCE_AI_PREDICT_MODEL, 0, []int{}, []int{}, []string{}, startSize, pageSize)
 			if e != nil {
 				br.Msg = "获取失败"
-				br.ErrMsg = fmt.Sprintf("ES-搜索手工指标列表失败, %v", e)
+				br.ErrMsg = fmt.Sprintf("ES-搜索AI预测模型列表失败, %v", e)
 				return
 			}
 			if len(list) == 0 {
@@ -280,6 +281,8 @@ func (this *AiPredictModelIndexController) Import() {
 				imports[indexName].Index.IndexName = indexName
 				imports[indexName].Index.CreateTime = time.Now()
 				imports[indexName].Index.ModifyTime = time.Now()
+				imports[indexName].Index.TrainStatus = `训练成功`
+				imports[indexName].Index.RunStatus = `运行成功`
 
 				// 分类
 				classifyName := strings.TrimSpace(cells[1].String())
@@ -1197,3 +1200,85 @@ func (this *AiPredictModelIndexController) ScriptPathSave() {
 	br.Msg = "操作成功"
 	br.Success = true
 }
+
+// GetCurrentRunningAiPredictModelIndexCount
+// @Title 获取当前正在运行中的模型数量
+// @Description 获取当前正在运行中的模型数量
+// @Success 200 Ret=200 保存成功
+// @Success 200 {object} response.CurrentRunningCountResp
+// @router /index/running/count [get]
+func (this *AiPredictModelIndexController) GetCurrentRunningAiPredictModelIndexCount() {
+	br := new(models.BaseResponse).Init()
+	defer func() {
+		if br.ErrMsg == "" {
+			br.IsSendEmail = false
+		}
+		this.Data["json"] = br
+		this.ServeJSON()
+	}()
+	sysUser := this.SysUser
+	if sysUser == nil {
+		br.Msg = "请登录"
+		br.ErrMsg = "请登录,SysUser Is Empty"
+		br.Ret = 408
+		return
+	}
+
+	// 查找当前标的是否存在待训练/训练中的模型
+	count, err := services.GetCurrentRunningAiPredictModelIndexCount()
+	if err != nil {
+		br.Msg = "训练失败"
+		br.ErrMsg = "训练失败,查找待训练的模型失败,Err:" + err.Error()
+		return
+	}
+
+	resp := response.CurrentRunningCountResp{
+		Total: count,
+	}
+
+	br.Data = resp
+	br.Ret = 200
+	br.Success = true
+	br.Msg = "获取成功"
+}
+
+// Run
+// @Title 获取当前正在运行中的模型数量
+// @Description 获取当前正在运行中的模型数量
+// @Success 200 Ret=200 保存成功
+// @Success 200 {object} response.CurrentRunningCountResp
+// @router /index/run [get]
+func (this *AiPredictModelIndexController) Run() {
+	br := new(models.BaseResponse).Init()
+	defer func() {
+		if br.ErrMsg == "" {
+			br.IsSendEmail = false
+		}
+		this.Data["json"] = br
+		this.ServeJSON()
+	}()
+	sysUser := this.SysUser
+	if sysUser == nil {
+		br.Msg = "请登录"
+		br.ErrMsg = "请登录,SysUser Is Empty"
+		br.Ret = 408
+		return
+	}
+
+	// 查找当前标的是否存在待训练/训练中的模型
+	count, err := services.GetCurrentRunningAiPredictModelIndexCount()
+	if err != nil {
+		br.Msg = "训练失败"
+		br.ErrMsg = "训练失败,查找待训练的模型失败,Err:" + err.Error()
+		return
+	}
+
+	resp := response.CurrentRunningCountResp{
+		Total: count,
+	}
+
+	br.Data = resp
+	br.Ret = 200
+	br.Success = true
+	br.Msg = "获取成功"
+}

+ 10 - 1
controllers/data_manage/ai_predict_model/index_config.go

@@ -433,7 +433,7 @@ func (c *AiPredictModelIndexConfigController) Train() {
 	obj := new(data_manage.AiPredictModelIndexConfig)
 
 	// 查找当前标的是否存在待训练/训练中的模型
-	count, err := obj.GetCountByCondition(fmt.Sprintf(` AND %s = ? AND %s IN (?,?)`, data_manage.AiPredictModelIndexConfigColumns.AiPredictModelIndexId, data_manage.AiPredictModelIndexConfigColumns.TrainStatus), []interface{}{indexItem.AiPredictModelIndexId, data_manage.TrainStatusWaiting, data_manage.TrainStatusTraining})
+	count, err := services.GetCurrentRunningAiPredictModelIndexCount()
 	if err != nil {
 		br.Msg = "训练失败"
 		br.ErrMsg = "训练失败,查找待训练的模型失败,Err:" + err.Error()
@@ -478,6 +478,7 @@ func (c *AiPredictModelIndexConfigController) Train() {
 			br.ErrMsg = "训练失败,Err:" + err.Error()
 			return
 		}
+
 	} else {
 		// 新增训练模型
 		item := &data_manage.AiPredictModelIndexConfig{
@@ -504,6 +505,14 @@ func (c *AiPredictModelIndexConfigController) Train() {
 		}
 	}
 
+	indexItem.TrainStatus = data_manage.TrainStatusWaiting
+	indexItem.ModifyTime = time.Now()
+	err = indexItem.Update([]string{"train_status", "modify_time"})
+	if err != nil {
+		br.Msg = "训练失败"
+		br.ErrMsg = "训练失败,Err:" + err.Error()
+		return
+	}
 	// TODO 加入训练任务中
 
 	br.Ret = 200

+ 18 - 0
models/ai_predict_model/ai_predict_model_index.go

@@ -11,6 +11,14 @@ import (
 	"time"
 )
 
+// 训练状态
+const (
+	RunStatusWaiting  = "待运行"
+	RunStatusTraining = "运行中"
+	RunStatusSuccess  = "运行成功"
+	RunStatusFailed   = "运行失败"
+)
+
 // AiPredictModelIndex AI预测模型标的
 type AiPredictModelIndex struct {
 	AiPredictModelIndexId       int       `orm:"column(ai_predict_model_index_id);pk" gorm:"primaryKey"`
@@ -33,6 +41,8 @@ type AiPredictModelIndex struct {
 	ModifyTime                  time.Time `description:"修改时间"`
 	AiPredictModelIndexConfigId int       `gorm:"column:ai_predict_model_index_config_id" description:"标的当前的配置id"`
 	ScriptPath                  string    `gorm:"column:script_path" description:"脚本的路径"`
+	TrainStatus                 string    `gorm:"column:train_status" description:"训练状态,枚举值:待训练,训练中,训练成功,训练失败"`
+	RunStatus                   string    `gorm:"column:run_status" description:"运行状态,枚举值:待运行,运行中,运行成功,运行失败"`
 }
 
 func (m *AiPredictModelIndex) TableName() string {
@@ -59,6 +69,8 @@ type AiPredictModelIndexCols struct {
 	ModifyTime                  string
 	AiPredictModelIndexConfigId string
 	ScriptPath                  string
+	TrainStatus                 string
+	RunStatus                   string
 }
 
 func (m *AiPredictModelIndex) Cols() AiPredictModelIndexCols {
@@ -82,6 +94,8 @@ func (m *AiPredictModelIndex) Cols() AiPredictModelIndexCols {
 		ModifyTime:                  "modify_time",
 		AiPredictModelIndexConfigId: "ai_predict_model_index_config_id",
 		ScriptPath:                  "script_path",
+		TrainStatus:                 "train_status",
+		RunStatus:                   "run_status",
 	}
 }
 
@@ -225,6 +239,8 @@ type AiPredictModelIndexItem struct {
 	SearchText                  string  `description:"搜索结果(含高亮)"`
 	AiPredictModelIndexConfigId int     `gorm:"column:ai_predict_model_index_config_id" description:"标的当前的配置id"`
 	ScriptPath                  string  `gorm:"column:script_path" description:"脚本的路径"`
+	TrainStatus                 string  `gorm:"column:train_status" description:"训练状态,枚举值:待训练,训练中,训练成功,训练失败"`
+	RunStatus                   string  `gorm:"column:run_status" description:"运行状态,枚举值:待运行,运行中,运行成功,运行失败"`
 }
 
 func (m *AiPredictModelIndex) Format2Item() (item *AiPredictModelIndexItem) {
@@ -244,6 +260,8 @@ func (m *AiPredictModelIndex) Format2Item() (item *AiPredictModelIndexItem) {
 	item.SysUserRealName = m.SysUserRealName
 	item.AiPredictModelIndexConfigId = m.AiPredictModelIndexConfigId
 	item.ScriptPath = m.ScriptPath
+	item.TrainStatus = m.TrainStatus
+	item.RunStatus = m.RunStatus
 	item.CreateTime = utils.TimeTransferString(utils.FormatDateTime, m.CreateTime)
 	item.ModifyTime = utils.TimeTransferString(utils.FormatDateTime, m.ModifyTime)
 	return

+ 7 - 0
models/ai_predict_model/response/index.go

@@ -0,0 +1,7 @@
+package response
+
+// CurrentRunningCountResp
+// @Description: 当前正在运行/训练的模型总数量返回
+type CurrentRunningCountResp struct {
+	Total int `description:"当前正在运行/训练的模型总数量"`
+}

+ 9 - 0
routers/commentsRouter.go

@@ -466,6 +466,15 @@ func init() {
             Filters: nil,
             Params: nil})
 
+    beego.GlobalControllerRouter["eta/eta_api/controllers/data_manage/ai_predict_model:AiPredictModelIndexController"] = append(beego.GlobalControllerRouter["eta/eta_api/controllers/data_manage/ai_predict_model:AiPredictModelIndexController"],
+        beego.ControllerComments{
+            Method: "GetCurrentRunningAiPredictModelIndexCount",
+            Router: `/index/running/count`,
+            AllowHTTPMethods: []string{"get"},
+            MethodParams: param.Make(),
+            Filters: nil,
+            Params: nil})
+
     beego.GlobalControllerRouter["eta/eta_api/controllers/data_manage/ai_predict_model:AiPredictModelIndexController"] = append(beego.GlobalControllerRouter["eta/eta_api/controllers/data_manage/ai_predict_model:AiPredictModelIndexController"],
         beego.ControllerComments{
             Method: "Save",

+ 23 - 0
services/ai_predict_model_index.go

@@ -439,6 +439,17 @@ func FixAiPredictCharts() {
 	}
 	return
 }
+
+// GetAiPredictConfigChartDetailByData
+// @Description: 获取AI预测模型训练完成后的数据
+// @author: Roc
+// @datetime 2025-05-07 15:29:11
+// @param indexName string
+// @param indexConfigItem *aiPredictModel.AiPredictModelIndexConfig
+// @param indexData []*aiPredictModel.AiPredictModelIndexConfigTrainData
+// @param source int
+// @return resp *data_manage.ChartInfoDetailResp
+// @return err error
 func GetAiPredictConfigChartDetailByData(indexName string, indexConfigItem *aiPredictModel.AiPredictModelIndexConfig, indexData []*aiPredictModel.AiPredictModelIndexConfigTrainData, source int) (resp *data_manage.ChartInfoDetailResp, err error) {
 	resp = new(data_manage.ChartInfoDetailResp)
 	// 标的配置
@@ -631,3 +642,15 @@ func GetAiPredictConfigChartDetailByData(indexName string, indexConfigItem *aiPr
 	resp.EdbInfoList = edbList
 	return
 }
+
+// GetCurrentRunningAiPredictModelIndexCount
+// @Description: 获取正在处理中的AI预测模型数量
+// @author: Roc
+// @datetime 2025-05-07 15:30:29
+// @return total int
+// @return err error
+func GetCurrentRunningAiPredictModelIndexCount() (total int, err error) {
+	obj := new(aiPredictModel.AiPredictModelIndex)
+	total, err = obj.GetCountByCondition(" AND (train_status in (?) OR run_status in (?) ) ", []interface{}{[]string{aiPredictModel.TrainStatusTraining, aiPredictModel.TrainStatusWaiting}, []string{aiPredictModel.RunStatusTraining, aiPredictModel.RunStatusWaiting}})
+	return
+}