Roc 6 일 전
부모
커밋
20df7d94da

+ 162 - 136
controllers/data_manage/ai_predict_model/index.go

@@ -168,135 +168,6 @@ func (this *AiPredictModelIndexController) List() {
 	br.Msg = "获取成功"
 }
 
-// GetAll
-// @Title 获取标的全量数据
-// @Description 获取标的全量数据
-// @Param   ClassifyId   query   int   false   "分类id"
-// @Param   IndexId   query   int   false   "模型标的ID"
-// @Param   Keyword   query   string   false   "搜索关键词"
-// @Success 200 {object} data_manage.ChartListResp
-// @router /index/all [get]
-func (this *AiPredictModelIndexController) GetAll() {
-	br := new(models.BaseResponse).Init()
-	defer func() {
-		this.Data["json"] = br
-		this.ServeJSON()
-	}()
-	sysUser := this.SysUser
-	if sysUser == nil {
-		br.Msg = "请登录"
-		br.ErrMsg = "请登录,SysUser Is Empty"
-		br.Ret = 408
-		return
-	}
-	classifyId, _ := this.GetInt("ClassifyId")
-	indexId, _ := this.GetInt("IndexId")
-	keyword := this.GetString("KeyWord")
-	if keyword == "" {
-		keyword = this.GetString("Keyword")
-	}
-	keyword = strings.TrimSpace(keyword)
-	resp := new(aiPredictModel.AiPredictModelIndexPageListResp)
-
-	// 分类
-	classifyIdName := make(map[int]string)
-	{
-		classifyOb := new(aiPredictModel.AiPredictModelClassify)
-		list, e := classifyOb.GetItemsByCondition("", make([]interface{}, 0), []string{}, "")
-		if e != nil {
-			br.Msg = "获取失败"
-			br.ErrMsg = fmt.Sprintf("获取分类失败, %v", e)
-			return
-		}
-		for _, v := range list {
-			classifyIdName[v.AiPredictModelClassifyId] = v.ClassifyName
-		}
-	}
-
-	// 筛选条件
-	highlightMap := make(map[int]string)
-	indexOb := new(aiPredictModel.AiPredictModelIndex)
-	var cond string
-	var pars []interface{}
-	{
-		if indexId > 0 {
-			cond += fmt.Sprintf(" AND %s = ?", indexOb.Cols().PrimaryId)
-			pars = append(pars, indexId)
-		}
-		if classifyId > 0 {
-			cond += fmt.Sprintf(" AND %s = ?", indexOb.Cols().ClassifyId)
-			pars = append(pars, classifyId)
-		}
-
-		// 有关键词从es中搜索
-		if keyword != "" {
-			// 使用scroll API获取所有匹配的数据
-			scrollId, list, e := elastic.SearchDataSourceIndexWithScroll(utils.EsDataSourceIndexName, keyword, utils.DATA_SOURCE_AI_PREDICT_MODEL, 0, []int{}, []int{}, []string{}, 1000)
-			if e != nil {
-				br.Msg = "获取失败"
-				br.ErrMsg = fmt.Sprintf("ES-搜索AI预测模型列表失败, %v", e)
-				return
-			}
-
-			// 如果scrollId不为空,说明还有更多数据,继续获取
-			for scrollId != "" {
-				nextScrollId, nextList, e := elastic.ScrollDataSourceIndex(utils.EsDataSourceIndexName, scrollId)
-				if e != nil {
-					br.Msg = "获取失败"
-					br.ErrMsg = fmt.Sprintf("ES-获取更多数据失败, %v", e)
-					return
-				}
-				if len(nextList) > 0 {
-					list = append(list, nextList...)
-				}
-				scrollId = nextScrollId
-			}
-
-			if len(list) == 0 {
-				resp.List = make([]*aiPredictModel.AiPredictModelIndexItem, 0)
-				br.Ret = 200
-				br.Success = true
-				br.Msg = "获取成功"
-				br.Data = resp
-				return
-			}
-			var ids []int
-			for _, v := range list {
-				ids = append(ids, v.PrimaryId)
-				highlightMap[v.PrimaryId] = v.SearchText
-			}
-			cond += fmt.Sprintf(` AND %s IN (%s)`, indexOb.Cols().PrimaryId, utils.GetOrmInReplace(len(ids)))
-			pars = append(pars, ids)
-		}
-	}
-
-	// 获取列表
-	list, e := indexOb.GetItemsByCondition(cond, pars, []string{}, "")
-	if e != nil {
-		br.Msg = "获取失败"
-		br.ErrMsg = fmt.Sprintf("获取列表失败, %v", e)
-		return
-	}
-	pageList := make([]*aiPredictModel.AiPredictModelIndexItem, 0)
-	for _, v := range list {
-		t := v.Format2Item()
-		t.ClassifyName = classifyIdName[v.ClassifyId]
-		// 搜索高亮
-		t.SearchText = v.IndexName
-		s := highlightMap[v.AiPredictModelIndexId]
-		if s != "" {
-			t.SearchText = s
-		}
-		pageList = append(pageList, t)
-	}
-
-	resp.List = pageList
-	br.Data = resp
-	br.Ret = 200
-	br.Success = true
-	br.Msg = "获取成功"
-}
-
 // Import
 // @Title 导入标的和数据
 // @Description 导入标的和数据
@@ -1377,13 +1248,10 @@ func (this *AiPredictModelIndexController) GetCurrentRunningAiPredictModelIndexC
 // @Description 获取当前正在运行中的模型数量
 // @Success 200 Ret=200 保存成功
 // @Success 200 {object} response.CurrentRunningCountResp
-// @router /index/run [get]
+// @router /index/run [post]
 func (this *AiPredictModelIndexController) Run() {
 	br := new(models.BaseResponse).Init()
 	defer func() {
-		if br.ErrMsg == "" {
-			br.IsSendEmail = false
-		}
 		this.Data["json"] = br
 		this.ServeJSON()
 	}()
@@ -1395,6 +1263,17 @@ func (this *AiPredictModelIndexController) Run() {
 		return
 	}
 
+	var req request.AiPredictModelIndexRunReq
+	if e := json.Unmarshal(this.Ctx.Input.RequestBody, &req); e != nil {
+		br.Msg = "参数解析异常"
+		br.ErrMsg = fmt.Sprintf("参数解析异常, %v", e)
+		return
+	}
+
+	classifyId := req.ClassifyId
+	//indexId, _ := this.GetInt("IndexId")
+	keyword := strings.TrimSpace(req.Keyword)
+
 	// 查找当前标的是否存在待训练/训练中的模型
 	count, err := services.GetCurrentRunningAiPredictModelIndexCount()
 	if err != nil {
@@ -1402,13 +1281,160 @@ func (this *AiPredictModelIndexController) Run() {
 		br.ErrMsg = "训练失败,查找待训练的模型失败,Err:" + err.Error()
 		return
 	}
+	if count > 0 {
+		br.Msg = "当前有模型正在训练/运行中,请勿重复训练"
+		br.ErrMsg = "当前有模型正在训练/运行中,请勿重复训练"
+		br.IsSendEmail = false
+		return
+	}
 
-	resp := response.CurrentRunningCountResp{
-		Total: count,
+	// 分类
+	classifyIdName := make(map[int]string)
+	{
+		classifyOb := new(aiPredictModel.AiPredictModelClassify)
+		list, e := classifyOb.GetItemsByCondition("", make([]interface{}, 0), []string{}, "")
+		if e != nil {
+			br.Msg = "获取失败"
+			br.ErrMsg = fmt.Sprintf("获取分类失败, %v", e)
+			return
+		}
+		for _, v := range list {
+			classifyIdName[v.AiPredictModelClassifyId] = v.ClassifyName
+		}
 	}
 
-	br.Data = resp
+	// 筛选条件
+	highlightMap := make(map[int]string)
+	indexOb := new(aiPredictModel.AiPredictModelIndex)
+	var cond string
+	var pars []interface{}
+
+	cond += fmt.Sprintf(` AND %s NOT IN (?)  AND %s NOT IN (?) `, indexOb.Cols().TrainStatus, indexOb.Cols().RunStatus)
+	pars = append(pars, []string{aiPredictModel.TrainStatusWaiting, aiPredictModel.TrainStatusTraining}, []string{aiPredictModel.RunStatusWaiting, aiPredictModel.RunStatusRunning})
+
+	if req.SelectAll {
+		// 如果列表全选
+		if classifyId > 0 {
+			cond += fmt.Sprintf(" AND %s = ?", indexOb.Cols().ClassifyId)
+			pars = append(pars, classifyId)
+		}
+
+		// 有关键词从es中搜索
+		if keyword != "" {
+			// 使用scroll API获取所有匹配的数据
+			list, e := getAllSearchDataSource(keyword, utils.DATA_SOURCE_AI_PREDICT_MODEL, 0)
+			if e != nil {
+				br.Msg = "获取失败"
+				br.ErrMsg = fmt.Sprintf("获取失败,Err:%v", e)
+				return
+			}
+
+			if len(list) == 0 {
+				br.Msg = "没有找到可以运行的标的"
+				br.IsSendEmail = false
+				return
+			}
+			var ids []int
+			for _, v := range list {
+				ids = append(ids, v.PrimaryId)
+				highlightMap[v.PrimaryId] = v.SearchText
+			}
+			cond += fmt.Sprintf(` AND %s IN (%s)`, indexOb.Cols().PrimaryId, utils.GetOrmInReplace(len(ids)))
+			pars = append(pars, ids)
+		}
+
+		// 不勾选的标的
+		if len(req.NotIndexIdList) > 0 {
+			var ids []int
+			for _, v := range req.NotIndexIdList {
+				ids = append(ids, v)
+			}
+			cond += fmt.Sprintf(` AND %s NOT IN (%s)`, indexOb.Cols().PrimaryId, utils.GetOrmInReplace(len(ids)))
+			pars = append(pars, ids)
+		}
+	} else {
+		// 如果不是列表全选
+		if len(req.IndexIdList) <= 0 {
+			br.Msg = `请选择标的`
+			br.IsSendEmail = false
+			return
+		}
+		var ids []int
+		for _, v := range req.IndexIdList {
+			ids = append(ids, v)
+		}
+		cond += fmt.Sprintf(` AND %s IN (%s)`, indexOb.Cols().PrimaryId, utils.GetOrmInReplace(len(ids)))
+		pars = append(pars, ids)
+	}
+
+	// 获取列表
+	list, e := indexOb.GetItemsByCondition(cond, pars, []string{`ai_predict_model_index_id`}, "")
+	if e != nil {
+		br.Msg = "获取失败"
+		br.ErrMsg = fmt.Sprintf("获取列表失败, %v", e)
+		return
+	}
+
+	indexIdList := make([]int, 0)
+	for _, v := range list {
+		indexIdList = append(indexIdList, v.AiPredictModelIndexId)
+	}
+
+	br.Data = indexIdList
 	br.Ret = 200
 	br.Success = true
 	br.Msg = "获取成功"
 }
+
+// getAllSearchDataSource
+// @Description: 根据条件获取ES中的所有数据
+// @author: Roc
+// @datetime 2025-05-08 11:15:27
+// @param keyword string
+// @param source int
+// @param subSource int
+// @return list []*dataSourceModel.SearchDataSourceItem
+// @return err error
+func getAllSearchDataSource(keyword string, source, subSource int) (list []*dataSourceModel.SearchDataSourceItem, err error) {
+	dataLimit := 1000 // 每页获取的数据量
+	var scrollId string
+	// 使用scroll API获取所有匹配的数据
+	scrollId, list, e := elastic.SearchDataSourceIndexWithScroll(utils.EsDataSourceIndexName, keyword, source, subSource, []int{}, []int{}, []string{}, dataLimit)
+	if e != nil {
+		err = fmt.Errorf("ES-搜索列表失败, %v", e)
+		return
+	}
+	// 最终清理用的 scrollId 放到一个指针里,保证 defer 拿到最新值
+	finalScrollId := &scrollId
+
+	defer func() {
+		// 清除滚动查询的缓存
+		if *finalScrollId != "" {
+			elastic.ClearScrollDataSourceIndex(*finalScrollId)
+		}
+	}()
+
+	// 如果scrollId不为空,说明还有更多数据,继续获取
+	for scrollId != "" {
+		nextScrollId, nextList, e := elastic.ScrollDataSourceIndex(scrollId)
+		if e != nil {
+			err = fmt.Errorf("ES-获取更多数据失败, %v", e)
+			return
+		}
+
+		// 如果没有更多数据,则退出循环
+		if len(nextList) <= 0 {
+			break
+		}
+
+		list = append(list, nextList...)
+
+		if nextScrollId != `` {
+			scrollId = nextScrollId
+			*finalScrollId = scrollId // 更新 finalScrollId 的指向内容
+		}
+
+	}
+
+	return
+}

+ 4 - 4
models/ai_predict_model/ai_predict_model_index.go

@@ -13,10 +13,10 @@ import (
 
 // 训练状态
 const (
-	RunStatusWaiting  = "待运行"
-	RunStatusTraining = "运行中"
-	RunStatusSuccess  = "运行成功"
-	RunStatusFailed   = "运行失败"
+	RunStatusWaiting = "待运行"
+	RunStatusRunning = "运行中"
+	RunStatusSuccess = "运行成功"
+	RunStatusFailed  = "运行失败"
 )
 
 // AiPredictModelIndex AI预测模型标的

+ 9 - 0
models/ai_predict_model/request/index.go

@@ -4,3 +4,12 @@ type AiPredictModelIndexSaveScriptPathReq struct {
 	IndexId    int    `description:"指标ID"`
 	ScriptPath string `description:"脚本的路径"`
 }
+
+// AiPredictModelIndexRunReq AI预测模型运行请求参数
+type AiPredictModelIndexRunReq struct {
+	ClassifyId     int    `description:"分类ID"`
+	Keyword        string `description:"关键词-指标ID/指标名称"`
+	IndexIdList    []int  `description:"选中的模型ID列表,SelectAll-false时,会用到这个字段"`
+	NotIndexIdList []int  `description:"排除的模型ID列表,SelectAll-true时,会用到这个字段"`
+	SelectAll      bool   `description:"列表全选"`
+}

+ 9 - 0
routers/commentsRouter.go

@@ -466,6 +466,15 @@ func init() {
             Filters: nil,
             Params: nil})
 
+    beego.GlobalControllerRouter["eta/eta_api/controllers/data_manage/ai_predict_model:AiPredictModelIndexController"] = append(beego.GlobalControllerRouter["eta/eta_api/controllers/data_manage/ai_predict_model:AiPredictModelIndexController"],
+        beego.ControllerComments{
+            Method: "Run",
+            Router: `/index/run`,
+            AllowHTTPMethods: []string{"post"},
+            MethodParams: param.Make(),
+            Filters: nil,
+            Params: nil})
+
     beego.GlobalControllerRouter["eta/eta_api/controllers/data_manage/ai_predict_model:AiPredictModelIndexController"] = append(beego.GlobalControllerRouter["eta/eta_api/controllers/data_manage/ai_predict_model:AiPredictModelIndexController"],
         beego.ControllerComments{
             Method: "GetCurrentRunningAiPredictModelIndexCount",

+ 1 - 1
services/ai_predict_model_index.go

@@ -651,6 +651,6 @@ func GetAiPredictConfigChartDetailByData(indexName string, indexConfigItem *aiPr
 // @return err error
 func GetCurrentRunningAiPredictModelIndexCount() (total int, err error) {
 	obj := new(aiPredictModel.AiPredictModelIndex)
-	total, err = obj.GetCountByCondition(" AND (train_status in (?) OR run_status in (?) ) ", []interface{}{[]string{aiPredictModel.TrainStatusTraining, aiPredictModel.TrainStatusWaiting}, []string{aiPredictModel.RunStatusTraining, aiPredictModel.RunStatusWaiting}})
+	total, err = obj.GetCountByCondition(" AND (train_status in (?) OR run_status in (?) ) ", []interface{}{[]string{aiPredictModel.TrainStatusTraining, aiPredictModel.TrainStatusWaiting}, []string{aiPredictModel.RunStatusRunning, aiPredictModel.RunStatusWaiting}})
 	return
 }

+ 33 - 1
services/elastic/elastic.go

@@ -10,6 +10,7 @@ import (
 	dataSourceModel "eta/eta_api/models/data_source"
 	"eta/eta_api/utils"
 	"fmt"
+	"io"
 	"strconv"
 	"strings"
 
@@ -2354,7 +2355,7 @@ func SearchDataSourceIndexWithScroll(indexName, keyword string, source, subSourc
 }
 
 // ScrollDataSourceIndex 使用scroll API获取下一页数据
-func ScrollDataSourceIndex(indexName, scrollId string) (nextScrollId string, list []*dataSourceModel.SearchDataSourceItem, err error) {
+func ScrollDataSourceIndex(scrollId string) (nextScrollId string, list []*dataSourceModel.SearchDataSourceItem, err error) {
 	list = make([]*dataSourceModel.SearchDataSourceItem, 0)
 	defer func() {
 		if err != nil {
@@ -2368,6 +2369,10 @@ func ScrollDataSourceIndex(indexName, scrollId string) (nextScrollId string, lis
 	request := client.Scroll().ScrollId(scrollId).Scroll("5m")
 	searchResp, e := request.Do(context.Background())
 	if e != nil {
+		if errors.Is(e, io.EOF) {
+			// 结束了,没有更多的数据了
+			return
+		}
 		err = fmt.Errorf("scroll do err: %v", e)
 		return
 	}
@@ -2379,6 +2384,8 @@ func ScrollDataSourceIndex(indexName, scrollId string) (nextScrollId string, lis
 		return
 	}
 
+	//理论上 scrollId 在整个 Scroll 生命周期中是固定不变的。
+	//不过为了代码可读性和未来可能的扩展性,仍保留 nextScrollId = searchResp.ScrollId 这一行代码,这样即使逻辑迁移或改动,也不会出错。
 	nextScrollId = searchResp.ScrollId
 	searchMap := make(map[string]string)
 	for _, v := range searchResp.Hits.Hits {
@@ -2406,3 +2413,28 @@ func ScrollDataSourceIndex(indexName, scrollId string) (nextScrollId string, lis
 	}
 	return
 }
+
+// ClearScrollDataSourceIndex
+// @Description: 清除scroll查询
+// @author: Roc
+// @datetime 2025-05-08 10:17:01
+// @param scrollId string
+// @return nextScrollId string
+// @return list []*dataSourceModel.SearchDataSourceItem
+// @return err error
+func ClearScrollDataSourceIndex(scrollId string) (nextScrollId string, list []*dataSourceModel.SearchDataSourceItem, err error) {
+	list = make([]*dataSourceModel.SearchDataSourceItem, 0)
+	defer func() {
+		if err != nil {
+			tips := fmt.Sprintf("ScrollDataSourceIndex err: %v", err)
+			utils.FileLog.Info(tips)
+		}
+	}()
+	client := utils.EsClient
+
+	// 使用scroll API获取下一页
+	request := client.ClearScroll().ScrollId(scrollId)
+	_, err = request.Do(context.Background())
+
+	return
+}