瀏覽代碼

feat:训练

Roc 1 周之前
父節點
當前提交
8abd4fb56f

+ 111 - 34
controllers/data_manage/ai_predict_model/index_config.go

@@ -24,7 +24,7 @@ type AiPredictModelIndexConfigController struct {
 // @Description 列表
 // @Param   PageSize   query   int  true       "每页数据条数"
 // @Param   CurrentIndex   query   int  true       "当前页页码,从1开始"
-// @Param   KeyWord   query   string  true       "搜索关键词"
+// @Param   AiPredictModelIndexId   query   int  true       "标的id"
 // @Success 200 {object} []*data_manage.AiPredictModelIndexConfigView
 // @router /index_config/list [get]
 func (c *AiPredictModelIndexConfigController) List() {
@@ -42,6 +42,13 @@ func (c *AiPredictModelIndexConfigController) List() {
 	}
 	pageSize, _ := c.GetInt("PageSize")
 	currentIndex, _ := c.GetInt("CurrentIndex")
+	indexId, _ := c.GetInt("AiPredictModelIndexId")
+
+	if indexId <= 0 {
+		br.Msg = "标的id不能为空"
+		br.ErrMsg = "标的id不能为空"
+		return
+	}
 
 	var startSize int
 	if pageSize <= 0 {
@@ -58,6 +65,9 @@ func (c *AiPredictModelIndexConfigController) List() {
 	var condition string
 	var pars []interface{}
 
+	condition += fmt.Sprintf(` AND %s = ? `, data_manage.AiPredictModelIndexConfigColumns.AiPredictModelIndexId)
+	pars = append(pars, indexId)
+
 	obj := new(data_manage.AiPredictModelIndexConfig)
 	tmpTotal, list, err := obj.GetPageListByCondition(condition, pars, startSize, pageSize)
 	if err != nil {
@@ -148,7 +158,7 @@ func (c *AiPredictModelIndexConfigController) CurrVersion() {
 // @Description 设置为当前版本
 // @Param	request	body request.DelConfigReq true "type json string"
 // @Success 200 Ret=200 设置成功
-// @router /index_config/set_curr [post]
+// @router /index_config/version/set_curr [post]
 func (c *AiPredictModelIndexConfigController) SetCurr() {
 	br := new(models.BaseResponse).Init()
 	defer func() {
@@ -182,7 +192,7 @@ func (c *AiPredictModelIndexConfigController) SetCurr() {
 	indexOb := new(data_manage.AiPredictModelIndex)
 	indexItem, e := indexOb.GetItemById(configItem.AiPredictModelIndexId)
 	if e != nil {
-		br.Msg = "获取失败"
+		br.Msg = "操作失败"
 		br.ErrMsg = fmt.Sprintf("获取失败,根据配置ID获取标的信息失败, %v", e)
 		return
 	}
@@ -191,14 +201,14 @@ func (c *AiPredictModelIndexConfigController) SetCurr() {
 	indexItem.ModifyTime = time.Now()
 	err = indexItem.Update([]string{indexOb.Cols().ModifyTime, indexOb.Cols().AiPredictModelIndexConfigId})
 	if err != nil {
-		br.Msg = "配置失败"
+		br.Msg = "操作失败"
 		br.ErrMsg = fmt.Sprintf("配置失败,Err:%v", e)
 		return
 	}
 
 	br.Ret = 200
 	br.Success = true
-	br.Msg = "获取成功"
+	br.Msg = "操作成功"
 }
 
 // Del
@@ -257,6 +267,12 @@ func (c *AiPredictModelIndexConfigController) Del() {
 		}
 	}
 
+	if !utils.InArrayByStr([]string{data_manage.TrainStatusSuccess, data_manage.TrainStatusFailed}, item.TrainStatus) {
+		br.Msg = "删除失败,该版本配置正在训练中"
+		br.IsSendEmail = false
+		return
+	}
+
 	item.IsDeleted = 1
 	item.ModifyTime = time.Now()
 	err = item.Update([]string{"IsDeleted", "ModifyTime"})
@@ -367,7 +383,7 @@ func (c *AiPredictModelIndexConfigController) ChartDetail() {
 // Train
 // @Title 训练模型
 // @Description 训练模型
-// @Param	request	body request.DelConfigReq true "type json string"
+// @Param	request	body request.TrainReq true "type json string"
 // @Success 200 Ret=200 训练中
 // @router /index_config/train [post]
 func (c *AiPredictModelIndexConfigController) Train() {
@@ -376,60 +392,121 @@ func (c *AiPredictModelIndexConfigController) Train() {
 		c.Data["json"] = br
 		c.ServeJSON()
 	}()
-	var req request.DelConfigReq
+	var req request.TrainReq
 	err := json.Unmarshal(c.Ctx.Input.RequestBody, &req)
 	if err != nil {
 		br.Msg = "参数解析异常!"
 		br.ErrMsg = "参数解析失败,Err:" + err.Error()
 		return
 	}
-	if req.AiPredictModelIndexConfigId <= 0 {
-		br.Msg = "配置id不能为空"
+	if req.AiPredictModelIndexId <= 0 {
+		br.Msg = "标的id不能为空"
 		br.IsSendEmail = false
 		return
 	}
 
-	// 查找配置
-	obj := new(data_manage.AiPredictModelIndexConfig)
-	item, err := obj.GetById(req.AiPredictModelIndexConfigId)
+	paramsStrByte, err := json.Marshal(req.Params)
 	if err != nil {
-		br.Msg = "修改失败"
-		br.ErrMsg = "修改失败,查找配置失败,Err:" + err.Error()
+		br.Msg = "训练失败!"
+		br.ErrMsg = "训练失败,参数转json失败,Err:" + err.Error()
+		return
+	}
+
+	// 查询标的情况
+	indexOb := new(data_manage.AiPredictModelIndex)
+	indexItem, err := indexOb.GetItemById(req.AiPredictModelIndexId)
+	if err != nil {
+		br.Msg = "训练失败,查找标的失败"
+		br.ErrMsg = fmt.Sprintf("训练失败,查找标的失败, %v", err)
 		if utils.IsErrNoRow(err) {
-			br.Msg = "配置不存在"
+			br.Msg = "标的不存在"
 			br.IsSendEmail = false
 		}
 		return
 	}
+	if indexItem.ScriptPath == `` {
+		br.Msg = "训练失败,脚本路径不能为空"
+		br.IsSendEmail = false
+		return
+	}
 
-	// 查找是否被标的引用为默认模型
-	{
-		// 查询标的情况
-		indexOb := new(data_manage.AiPredictModelIndex)
-		count, e := indexOb.GetCountByCondition(fmt.Sprintf(` AND %s = ? `, indexOb.Cols().AiPredictModelIndexConfigId), []interface{}{item.AiPredictModelIndexConfigId})
-		if e != nil {
-			br.Msg = "删除失败"
-			br.ErrMsg = fmt.Sprintf("删除失败,根据配置ID获取标的信息失败, %v", e)
+	obj := new(data_manage.AiPredictModelIndexConfig)
+
+	// 查找当前标的是否存在待训练/训练中的模型
+	count, err := obj.GetCountByCondition(fmt.Sprintf(` AND %s = ? AND %s IN (?,?)`, data_manage.AiPredictModelIndexConfigColumns.AiPredictModelIndexId, data_manage.AiPredictModelIndexConfigColumns.TrainStatus), []interface{}{indexItem.AiPredictModelIndexId, data_manage.TrainStatusWaiting, data_manage.TrainStatusTraining})
+	if err != nil {
+		br.Msg = "训练失败"
+		br.ErrMsg = "训练失败,查找待训练的模型失败,Err:" + err.Error()
+		return
+	}
+	if count > 0 {
+		br.Msg = "该标的存在待训练/训练中的模型,不允许重复训练"
+		br.IsSendEmail = false
+		return
+	}
+
+	if req.AiPredictModelIndexConfigId > 0 {
+		// 查找配置
+		item, err := obj.GetById(req.AiPredictModelIndexConfigId)
+		if err != nil {
+			br.Msg = "训练失败"
+			br.ErrMsg = "训练失败,查找配置失败,Err:" + err.Error()
+			if utils.IsErrNoRow(err) {
+				br.Msg = "配置不存在"
+				br.IsSendEmail = false
+			}
 			return
 		}
 
-		if count > 0 {
-			br.Msg = "删除失败,该版本配置正在被使用"
+		if item.AiPredictModelIndexId != indexItem.AiPredictModelIndexId {
+			br.Msg = "训练失败"
+			br.ErrMsg = "训练失败,配置与标的不匹配"
+			return
+		}
+
+		if item.TrainStatus != data_manage.TrainStatusFailed {
+			br.Msg = "该模型训练状态异常,不允许重新训练"
+			br.ErrMsg = "该模型训练状态异常,不允许重新训练,当前状态:" + item.TrainStatus
 			br.IsSendEmail = false
 			return
 		}
+		item.Params = string(paramsStrByte)
+		item.ModifyTime = time.Now()
+		err = item.Update([]string{"params", "modify_time"})
+		if err != nil {
+			br.Msg = "训练失败"
+			br.ErrMsg = "训练失败,Err:" + err.Error()
+			return
+		}
+	} else {
+		// 新增训练模型
+		item := &data_manage.AiPredictModelIndexConfig{
+			AiPredictModelIndexConfigId: 0,
+			AiPredictModelIndexId:       indexItem.AiPredictModelIndexId,
+			TrainStatus:                 data_manage.TrainStatusWaiting,
+			Params:                      string(paramsStrByte),
+			TrainMse:                    "",
+			TrainR2:                     "",
+			TestMse:                     "",
+			TestR2:                      "",
+			Remark:                      "",
+			IsDeleted:                   0,
+			LeftMin:                     "",
+			LeftMax:                     "",
+			ModifyTime:                  time.Now(),
+			CreateTime:                  time.Now(),
+		}
+		err = item.Create()
+		if err != nil {
+			br.Msg = "训练失败"
+			br.ErrMsg = "训练失败,Err:" + err.Error()
+			return
+		}
 	}
 
-	item.IsDeleted = 1
-	item.ModifyTime = time.Now()
-	err = item.Update([]string{"IsDeleted", "ModifyTime"})
-	if err != nil {
-		br.Msg = "删除失败"
-		br.ErrMsg = "删除失败,Err:" + err.Error()
-		return
-	}
+	// TODO 加入训练任务中
 
 	br.Ret = 200
 	br.Success = true
-	br.Msg = `删除成功`
+	br.Msg = `训练中`
 }

+ 8 - 0
models/ai_predict_model/ai_predict_model_index_config.go

@@ -8,6 +8,14 @@ import (
 	"time"
 )
 
+// 训练状态
+const (
+	TrainStatusWaiting  = "待训练"
+	TrainStatusTraining = "训练中"
+	TrainStatusSuccess  = "训练成功"
+	TrainStatusFailed   = "训练失败"
+)
+
 // AiPredictModelIndexConfig ai预测模型训练配置
 type AiPredictModelIndexConfig struct {
 	AiPredictModelIndexConfigId int       `gorm:"primaryKey;column:ai_predict_model_index_config_id" description:"-"`

+ 27 - 0
models/ai_predict_model/request/index_config.go

@@ -3,3 +3,30 @@ package request
 type DelConfigReq struct {
 	AiPredictModelIndexConfigId int `description:"配置id"`
 }
+
+// TrainReq
+// @Description: 训练模型请求参数
+type TrainReq struct {
+	AiPredictModelIndexId       int          `description:"模型id"`
+	AiPredictModelIndexConfigId int          `description:"配置id"`
+	Params                      ConfigParams `description:"训练参数"`
+}
+
+// ConfigParams
+// @Description: 训练参数
+type ConfigParams struct {
+	Objective       string  `json:"objective" description:"目标(回归任务),枚举值,squarederror、multi、sofamax;默认值:squarederror" `
+	LearningRate    float64 `json:"learning_rate" description:"学习率,如:0.0881"`
+	MaxDepth        int     `json:"max_depth" description:"最大深度(控制树的深度,防止过拟合),如:4;正整数,必须大于0"`
+	MinChildWeight  float64 `json:"min_child_weight" description:"最小子节点权重(防止过拟合),如:6.0601"`
+	Subsample       float64 `json:"subsample" description:"随机采样(防止过拟合),如:0.9627"`
+	ColsampleBytree float64 `json:"colsample_bytree" description:"特征随机采样(防止过拟合),如:0.7046"`
+	Gamma           float64 `json:"gamma" description:"控制分裂,如:0.4100"`
+	RegAlpha        float64 `json:"reg_alpha" description:"L1正则化系数,如:0.3738"`
+	ReqLambda       float64 `json:"reg_lambda" description:"L2正则化系数,如:1.4775"`
+	EvalMetric      string  `json:"eval_metric" description:"评估指标,枚举值,rmse、auc logloss、merror;默认值:rmse"`
+	Seed            float64 `json:"seed" description:"随机种子,如:42"`
+	MaxDeltaStep    int     `json:"max_delta_step" description:"最大步长,如:5;正整数,必须大于0"`
+	TreeMethod      string  `json:"tree_method" description:"树构建方法,枚举值:auto、exact、approx、hist;默认值:auto"`
+	NumBoostRound   int     `json:"num_boost_round" description:"迭代次数;正整数,必须大于0"`
+}

+ 8 - 8
routers/commentsRouter.go

@@ -387,8 +387,8 @@ func init() {
 
     beego.GlobalControllerRouter["eta/eta_api/controllers/data_manage/ai_predict_model:AiPredictModelIndexConfigController"] = append(beego.GlobalControllerRouter["eta/eta_api/controllers/data_manage/ai_predict_model:AiPredictModelIndexConfigController"],
         beego.ControllerComments{
-            Method: "SetCurr",
-            Router: `/index_config/set_curr`,
+            Method: "Train",
+            Router: `/index_config/train`,
             AllowHTTPMethods: []string{"post"},
             MethodParams: param.Make(),
             Filters: nil,
@@ -396,18 +396,18 @@ func init() {
 
     beego.GlobalControllerRouter["eta/eta_api/controllers/data_manage/ai_predict_model:AiPredictModelIndexConfigController"] = append(beego.GlobalControllerRouter["eta/eta_api/controllers/data_manage/ai_predict_model:AiPredictModelIndexConfigController"],
         beego.ControllerComments{
-            Method: "Train",
-            Router: `/index_config/train`,
-            AllowHTTPMethods: []string{"post"},
+            Method: "CurrVersion",
+            Router: `/index_config/version/curr`,
+            AllowHTTPMethods: []string{"get"},
             MethodParams: param.Make(),
             Filters: nil,
             Params: nil})
 
     beego.GlobalControllerRouter["eta/eta_api/controllers/data_manage/ai_predict_model:AiPredictModelIndexConfigController"] = append(beego.GlobalControllerRouter["eta/eta_api/controllers/data_manage/ai_predict_model:AiPredictModelIndexConfigController"],
         beego.ControllerComments{
-            Method: "CurrVersion",
-            Router: `/index_config/version/curr`,
-            AllowHTTPMethods: []string{"get"},
+            Method: "SetCurr",
+            Router: `/index_config/version/set_curr`,
+            AllowHTTPMethods: []string{"post"},
             MethodParams: param.Make(),
             Filters: nil,
             Params: nil})