Browse Source

研究研报告媒体风险等级筛选

kobe6258 6 months ago
parent
commit
2520064c44

+ 43 - 0
api/ht_account_facade.go

@@ -0,0 +1,43 @@
+package api
+
+import (
+	"eta/eta_mini_ht_api/common/component/config"
+	"eta/eta_mini_ht_api/common/utils/client"
+	"fmt"
+	"sync"
+)
+
+const (
+	clientSuitInfoUrl = "getClientSuitInfo"
+)
+
+var (
+	htFacadeOnce sync.Once
+
+	htFacade *HTAccountFacade
+)
+
+type HTAccountFacade struct {
+	htConfig *config.HTBizConfig
+	// HTTP请求客户端
+	client *client.HttpClient
+}
+
+func (f *HTAccountFacade) GetInstance() *HTAccountFacade {
+	htFacadeOnce.Do(func() {
+		htFacade = &HTAccountFacade{
+			htConfig: config.GetConfig("HT").(*config.HTBizConfig),
+			client:   client.DefaultClient()}
+	})
+	return htFacade
+}
+
+func (f *HTAccountFacade) GetCustomerRiskLevelInfo() string {
+	url := f.htConfig.GetAccountApiUrl() + clientSuitInfoUrl
+	resp, err := f.client.Post(url, nil)
+	if err != nil {
+		return ""
+	}
+	fmt.Sprintln(resp)
+	return ""
+}

+ 16 - 10
common/component/config/ht_biz_config.go

@@ -4,11 +4,12 @@ import "eta/eta_mini_ht_api/common/contants"
 
 // ESOpts es连接属性
 type HTOpts struct {
-	ReportIndex string
-	MediaIndex  string
-	Encode      string
-	DesCode     string
-	Task        string
+	ReportIndex   string
+	MediaIndex    string
+	Encode        string
+	DesCode       string
+	Task          string
+	AccountApiUrl string
 }
 type HTBizConfig struct {
 	BaseConfig
@@ -35,16 +36,21 @@ func (e *HTBizConfig) EnableTask() bool {
 	}
 	return false
 }
+
+func (e *HTBizConfig) GetAccountApiUrl() string {
+	return e.opts.AccountApiUrl
+}
 func (e *HTBizConfig) GetDesCode() string {
 	return e.opts.DesCode
 }
 func (e *HTBizConfig) InitConfig() {
 	opts := HTOpts{
-		ReportIndex: e.GetString("es_report_index"),
-		MediaIndex:  e.GetString("es_media_index"),
-		Encode:      e.GetString("response.encode"),
-		DesCode:     e.GetString("response.des_code"),
-		Task:        e.GetString("task"),
+		ReportIndex:   e.GetString("es_report_index"),
+		MediaIndex:    e.GetString("es_media_index"),
+		Encode:        e.GetString("response.encode"),
+		DesCode:       e.GetString("response.des_code"),
+		Task:          e.GetString("task"),
+		AccountApiUrl: e.GetString("api.account_url"),
 	}
 	e.opts = opts
 }

+ 24 - 5
common/utils/client/http_client.go

@@ -2,10 +2,14 @@ package client
 
 import (
 	"context"
+	"encoding/json"
+	"errors"
+	logger "eta/eta_mini_ht_api/common/component/log"
 	"fmt"
 	"io"
 	"log"
 	"net/http"
+	"strings"
 	"time"
 )
 
@@ -43,13 +47,18 @@ func defaultRetryDelayFunc(attempt int) time.Duration {
 
 type RetryDelayFunc func(attempt int) time.Duration
 
+func retryErr(err error) bool {
+	return errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled)
+}
+
 // DoWithRetry 发送带有重试机制的HTTP请求,允许用户自定义重试延迟逻辑
-func (hc *HttpClient) DoWithRetry(ctx context.Context, req *http.Request) (*http.Response, error) {
+func (hc *HttpClient) DoWithRetry(ctx context.Context, req *http.Request) (resp *http.Response, err error) {
 	attempt := 0
 	for {
-		resp, err := hc.Do(req.WithContext(ctx))
-		if err != nil {
+		resp, err = hc.Do(req.WithContext(ctx))
+		if err != nil && retryErr(err) {
 			if attempt >= hc.maxRetries {
+
 				return nil, fmt.Errorf("请求失败: %w", err)
 			}
 			attempt++
@@ -57,11 +66,21 @@ func (hc *HttpClient) DoWithRetry(ctx context.Context, req *http.Request) (*http
 			time.Sleep(delay)
 			continue
 		}
-		return resp, nil
+		return
 	}
 }
 
-func (hc *HttpClient) Post(url string, contentType string, buf io.Reader) (resp *http.Response, err error) {
+func (hc *HttpClient) Post(url string, data interface{}) (resp *http.Response, err error) {
+	dataStr, err := json.Marshal(data)
+	if err != nil {
+		logger.Error("请求data json序列化失败,err:" + err.Error())
+	}
+	body := io.NopCloser(strings.NewReader(string(dataStr)))
+	req, err := http.NewRequest(http.MethodPost, url, body)
+	if err != nil {
+		logger.Error("创建POST请求失败: %v", err)
+	}
+	resp, err = hc.DoWithRetry(req.Context(), req)
 	return
 }
 

+ 0 - 6
controllers/report/report_controller.go

@@ -182,7 +182,6 @@ func (r *ReportController) HotRanked(permissionIds string, limit int) {
 				for _, permissionIdWIthRisk := range permissionIdsWithRisk {
 					if permissionId == permissionIdWIthRisk {
 						filterPermissionIds = append(filterPermissionIds, permissionId)
-						break
 					}
 				}
 			}
@@ -190,7 +189,6 @@ func (r *ReportController) HotRanked(permissionIds string, limit int) {
 				for _, permissionId := range filterPermissionIds {
 					if _, ok := item.SecondPermissions[permissionId]; ok {
 						filterList = append(filterList, item)
-						break
 					}
 				}
 			}
@@ -199,7 +197,6 @@ func (r *ReportController) HotRanked(permissionIds string, limit int) {
 				for _, permissionId := range permissionIdsWithRisk {
 					if _, ok := item.SecondPermissions[permissionId]; ok {
 						filterList = append(filterList, item)
-						break
 					}
 				}
 			}
@@ -260,7 +257,6 @@ func (r *ReportController) PublishRanked(permissionIds string, limit int, week b
 				for _, permissionIdWIthRisk := range permissionIdsWithRisk {
 					if permissionId == permissionIdWIthRisk {
 						filterPermissionIds = append(filterPermissionIds, permissionId)
-						break
 					}
 				}
 			}
@@ -268,7 +264,6 @@ func (r *ReportController) PublishRanked(permissionIds string, limit int, week b
 				for _, permissionId := range filterPermissionIds {
 					if _, ok := item.SecondPermissions[permissionId]; ok {
 						filterList = append(filterList, item)
-						break
 					}
 				}
 			}
@@ -277,7 +272,6 @@ func (r *ReportController) PublishRanked(permissionIds string, limit int, week b
 				for _, permissionId := range permissionIdsWithRisk {
 					if _, ok := item.SecondPermissions[permissionId]; ok {
 						filterList = append(filterList, item)
-						break
 					}
 				}
 			}

+ 22 - 4
controllers/user/analyst_controller.go

@@ -81,15 +81,24 @@ func (an *AnalystController) AnalystReportList(analystName string) {
 			Current:  an.PageInfo.Current,
 			PageSize: an.PageInfo.PageSize,
 		}
+		userInfo := an.Data["user"].(user.User)
+		var reportIds []int
+		pageRes.Total, pageRes.LatestId, reportIds = report.RangeSearchByAnalyst(analystName, userInfo.Id)
+		if len(reportIds) == 0 {
+			reports := new(page.PageResult)
+			reports.Data = []interface{}{}
+			reports.Page = pageRes
+			an.SuccessResult("分页获取研究员报告列表成功", reports, result)
+		}
 		if an.PageInfo.LatestId == 0 {
-			pageRes.Total, pageRes.LatestId = report.GetTotalPageCountByAnalyst(analystName)
+			//pageRes.Total, pageRes.LatestId = report.GetTotalPageCountByAnalyst(analystName)
 			an.PageInfo.LatestId = pageRes.LatestId
 		} else {
 			pageRes.LatestId = an.PageInfo.LatestId
 			pageRes.Total = an.PageInfo.Total
 		}
 		pageRes.TotalPage = page.TotalPages(pageRes.Total, pageRes.PageSize)
-		list, err := report.GetReportPageByAnalyst(an.PageInfo, analystName)
+		list, err := report.GetReportPageByAnalyst(an.PageInfo, analystName, reportIds)
 		if err != nil {
 			an.FailedResult("分页获取研究员报告列表失败", result)
 			return
@@ -118,8 +127,17 @@ func (an *AnalystController) MediaList(mediaType string, analystId int) {
 			Current:  an.PageInfo.Current,
 			PageSize: an.PageInfo.PageSize,
 		}
+		userInfo := an.Data["user"].(user.User)
+		var mediaIds []int
+		pageRes.Total, pageRes.LatestId, mediaIds = media.RangeSearchByAnalyst(mediaType, analystId, userInfo.Id)
+		if len(mediaIds) == 0 {
+			mediaList := new(page.PageResult)
+			mediaList.Data = []interface{}{}
+			mediaList.Page = pageRes
+			an.SuccessResult("分页查询研究员媒体列表成功", mediaList, result)
+		}
 		if an.PageInfo.LatestId == 0 {
-			pageRes.Total, pageRes.LatestId = media.GetTotalPageCountByAnalystId(mediaType, analystId)
+			//	pageRes.Total, pageRes.LatestId = media.GetTotalPageCountByAnalystId(mediaType, analystId)
 			an.PageInfo.LatestId = pageRes.LatestId
 			an.PageInfo.Total = pageRes.Total
 		} else {
@@ -127,7 +145,7 @@ func (an *AnalystController) MediaList(mediaType string, analystId int) {
 			pageRes.Total = an.PageInfo.Total
 		}
 		pageRes.TotalPage = page.TotalPages(pageRes.Total, pageRes.PageSize)
-		list, err := media.GetMediaPageByAnalystId(mediaType, an.PageInfo, analystId)
+		list, err := media.GetMediaPageByAnalystId(mediaType, an.PageInfo, analystId, mediaIds)
 		if err != nil {
 			an.FailedResult("分页查询研究员媒体列表失败", result)
 			return

+ 1 - 1
controllers/ws_controller.go

@@ -28,7 +28,7 @@ func (c *WebSocketController) Connect() {
 
 	for {
 		var msg string
-		_, _, err := ws.ReadMessage()
+		_, _, err = ws.ReadMessage()
 		if err != nil {
 			//	beego.Error(err)
 			break

+ 14 - 2
domian/media/media_service.go

@@ -125,6 +125,18 @@ func GetMediaPermissionMappingByPermissionIds(mediaType string, permissionIds []
 	})
 	return int64(len(ids)), int64(ids[0]), ids
 }
+func GetAnalystMediaPermissionMappingByPermissionIds(mediaType string, permissionIds []int, analystId int) (total int64, latestId int64, ids []int) {
+	ids, err := mediaDao.GetMediaPermissionMappingByPermissionId(mediaType, permissionIds)
+	if err != nil {
+		logger.Error("获取当前最大媒体id失败:%v", err)
+		return 0, 0, ids
+	}
+	ids, err = mediaDao.GetAnalystMediaRangeReportIds(ids, analystId)
+	sort.Slice(ids, func(i, j int) bool {
+		return ids[i] > ids[j]
+	})
+	return int64(len(ids)), int64(ids[0]), ids
+}
 func GetTotalPageCount(mediaType string) (count int64, latestId int64) {
 	return mediaDao.GetCountByMediaType(mediaType)
 }
@@ -143,9 +155,9 @@ func GetMediaPageByIds(mediaType string, pageInfo page.PageInfo, mediaIds []int)
 	return
 }
 
-func GetMediaPageByAnalystId(mediaType string, pageInfo page.PageInfo, analystId int) (list []MediaDTO, err error) {
+func GetMediaPageByAnalystId(mediaType string, pageInfo page.PageInfo, analystId int, mediaIds []int) (list []MediaDTO, err error) {
 	offset := page.StartIndex(pageInfo.Current, pageInfo.PageSize)
-	medias, err := mediaDao.GetMediaPageByAnalystId(pageInfo.LatestId, pageInfo.PageSize, offset, mediaType, analystId)
+	medias, err := mediaDao.GetMediaPageByAnalystId(pageInfo.LatestId, pageInfo.PageSize, offset, mediaType, analystId, mediaIds)
 	if err == nil && medias != nil {
 		for _, media := range medias {
 			dto := convertMediaDTO(media, false)

+ 72 - 4
domian/report/report_service.go

@@ -131,8 +131,76 @@ func GetGetReportById(reportId int) (ReportDTO ReportDTO, err error) {
 func GetTotalPageCount() (total int64, latestId int64, err error) {
 	return reportDao.GetTotalPageCount()
 }
-func GetTotalPageCountByAnalyst(analyst string) (total int64, latestId int64) {
-	return reportDao.GetTotalPageCountByAnalyst(analyst)
+func GetTotalPageCountByAnalyst(analyst string, permissionIds []int) (total int64, latestId int64, ids []int) {
+	ids, err := reportDao.GetReportsByAnalyst(analyst)
+	if err != nil {
+		logger.Error("查询研究研报告列表id失败:%v", err)
+		return
+	}
+	//查询这些包含在列表中的权限的报告ids
+	htOrgIds, err := GetHTReportIdsByPermissionIds(permissionIds)
+	if err != nil {
+		logger.Error("品种筛选ht报告id失败:%v", err)
+		htOrgIds = []int{}
+	}
+	etaOrgIds, err := GetETAReportIdsByPermissionIds(permissionIds)
+	if err != nil {
+		logger.Error("品种筛选eta报告id失败:%v", err)
+		etaOrgIds = []int{}
+	}
+	if len(etaOrgIds) == 0 && len(htOrgIds) == 0 {
+		logger.Info("没有符合权限的研报")
+		return
+	}
+	orgIds := make(map[string][]int, 2)
+	if len(etaOrgIds) == 0 {
+		orgIds["ETA"] = []int{}
+	} else {
+		orgIds["ETA"] = etaOrgIds
+	}
+	if len(htOrgIds) == 0 {
+		orgIds["HT"] = []int{}
+	} else {
+		orgIds["HT"] = htOrgIds
+	}
+	permitReportIds, err := reportDao.GetReportIdListByOrgIds(orgIds)
+	if err != nil {
+		logger.Error("根据原始报告id获取报告id列表失败:%v", err)
+		return
+	}
+	var filterReportIds []int
+	for _, id := range ids {
+		for _, permitReportId := range permitReportIds {
+			if id == permitReportId {
+				filterReportIds = append(filterReportIds, id)
+			}
+		}
+	}
+	if len(filterReportIds) == 0 {
+		logger.Info("没有符合权限的研究员研报")
+		return
+	}
+	ids = filterReportIds
+	total = int64(len(filterReportIds))
+	latestId = int64(findMax(filterReportIds))
+	return
+}
+
+// findMaxWithError 函数用于找到整型数组中的最大值,并返回错误信息
+func findMax(nums []int) (max int) {
+	if len(nums) == 0 {
+		return 0
+	}
+	// 初始化最大值为数组的第一个元素
+	max = nums[0]
+
+	// 遍历数组,找到最大值
+	for _, num := range nums {
+		if num > max {
+			max = num
+		}
+	}
+	return
 }
 func SearchMaxReportId(key string) (total int64, reportId int64) {
 	sort := []string{"reportId:desc"}
@@ -196,9 +264,9 @@ func SearchReportList(key string, ids []int, from int, size int, max int64) (rep
 	}
 	return
 }
-func GetReportPageByAnalyst(pageInfo page.PageInfo, analyst string) (list []ReportDTO, err error) {
+func GetReportPageByAnalyst(pageInfo page.PageInfo, analyst string, reportIds []int) (list []ReportDTO, err error) {
 	offset := page.StartIndex(pageInfo.Current, pageInfo.PageSize)
-	reports, err := reportDao.GetReportPageByAnalyst(pageInfo.LatestId, pageInfo.PageSize, offset, analyst)
+	reports, err := reportDao.GetReportPageByAnalyst(pageInfo.LatestId, pageInfo.PageSize, offset, analyst, reportIds)
 	if err != nil {
 		logger.Error("分页查询报告列表失败:%v", err)
 		return

+ 2 - 0
middleware/auth_middleware.go

@@ -72,6 +72,8 @@ var privateRoutes = []string{
 	"/user/message",
 	"/analyst/analystDetail",
 	"/analyst/list",
+	"/analyst/reportList",
+	"/analyst/mediaList",
 	"/media/count",
 	"/report/count",
 }

+ 13 - 3
models/media/media.go

@@ -104,7 +104,17 @@ func GetMediaPageByIds(latestId int64, limit int, offset int, mediaType string,
 	err = db.Select(CommonColumns).Where(" id<= ? and media_type= ? and id in ? and deleted =?", latestId, mediaType, mediasIds, false).Order("created_time desc").Limit(limit).Offset(offset).Find(&mediaList).Error
 	return
 }
-func GetMediaPageByAnalystId(latestId int64, limit int, offset int, mediaType string, analystId int) (mediaList []Media, err error) {
+
+func GetAnalystMediaRangeReportIds(srcIds []int, analystId int) (mediaIds []int, err error) {
+	if len(srcIds) == 0 {
+		logger.Info("过滤的媒体ID为空")
+		return
+	}
+	db := models.Main()
+	err = db.Select(CommonColumns).Where(" id in ? and author_id = ? and id in ? and deleted =?", mediaIds, analystId, false).Find(&mediaIds).Error
+	return
+}
+func GetMediaPageByAnalystId(latestId int64, limit int, offset int, mediaType string, analystId int, mediaIds []int) (mediaList []Media, err error) {
 	if latestId < 0 {
 		err = errors.New("非法的id参数")
 		logger.Error("非法的id参数:%d", latestId)
@@ -115,12 +125,12 @@ func GetMediaPageByAnalystId(latestId int64, limit int, offset int, mediaType st
 		logger.Error("非法的limit参数:%d", limit)
 	}
 	db := models.Main()
-	err = db.Select(CommonColumns).Where(" id<= ? and media_type= ? and author_id = ?", latestId, mediaType, analystId).Order("published_time desc").Limit(limit).Offset(offset).Find(&mediaList).Error
+	err = db.Select(CommonColumns).Where(" id<= ?  and media_type= ? and author_id = ? and id in ? and deleted = ?", latestId, mediaType, analystId, mediaIds, false).Order("published_time desc").Limit(limit).Offset(offset).Find(&mediaList).Error
 	return
 }
 func GetMediaById(mediaType string, mediaId int) (media Media, err error) {
 	db := models.Main()
-	err = db.Select(DetailColumns).Where("id =? and media_type=?", mediaId, mediaType).First(&media).Error
+	err = db.Select(DetailColumns).Where("id =? and media_type=? and deleted =?", mediaId, mediaType, false).First(&media).Error
 	return
 }
 

+ 3 - 2
models/media/media_permission_mapping.go

@@ -1,6 +1,8 @@
 package media
 
-import "eta/eta_mini_ht_api/models"
+import (
+	"eta/eta_mini_ht_api/models"
+)
 
 // MediaPermissionMapping 表示媒体权限映射
 type MediaPermissionMapping struct {
@@ -15,7 +17,6 @@ func GetMediaPermissionMappingByPermissionId(mediaType string, permissionIds []i
 	err = db.Model(&MediaPermissionMapping{}).Select("DISTINCT media_id").Where("media_type = ? and deleted =? and permission_id in ?", mediaType, false, permissionIds).Scan(&mediaIds).Error
 	return
 }
-
 func InsertMediaPermissionMapping(mediaType MediaPermissionMapping) {
 	db := models.Main()
 	_ = db.Create(&mediaType).Error

+ 15 - 3
models/report/report.go

@@ -268,6 +268,7 @@ func GetMaxIdByPermissionIds(orgIds map[string][]int) (total int64, maxId int64,
 	}
 	return
 }
+
 func GetTotalPageCount() (total int64, latestId int64, err error) {
 	db := models.Main()
 	err = db.Model(&Report{}).Select("MAX(id) id").Where("status = ?", StatusPublish).Scan(&latestId).Error
@@ -280,7 +281,14 @@ func GetTotalPageCount() (total int64, latestId int64, err error) {
 	}
 	return
 }
-
+func GetReportsByAnalyst(analyst string) (ids []int, err error) {
+	db := models.Main()
+	err = db.Model(&Report{}).Select("id").Where("status = ?", StatusPublish).Where("author like ?", "%"+analyst+"%").Scan(&ids).Error
+	if err != nil {
+		return
+	}
+	return
+}
 func GetTotalPageCountByAnalyst(analyst string) (total int64, latestId int64) {
 	db := models.Main()
 	err := db.Model(&Report{}).Where("status = ?", StatusPublish).Where("author like ?", "%"+analyst+"%").Count(&total).Error
@@ -346,7 +354,11 @@ func GetNewReportByPublishTime(time time.Time) (list []Report) {
 	}
 	return
 }
-func GetReportPageByAnalyst(latestId int64, limit int, offset int, analyst string) (list []Report, err error) {
+func GetReportPageByAnalyst(latestId int64, limit int, offset int, analyst string, reportIds []int) (list []Report, err error) {
+	if len(reportIds) == 0 {
+		logger.Info("reportIds为空")
+		return
+	}
 	if latestId < 0 {
 		err = errors.New("非法的id参数")
 		logger.Error("非法的id参数:%d", latestId)
@@ -357,6 +369,6 @@ func GetReportPageByAnalyst(latestId int64, limit int, offset int, analyst strin
 		logger.Error("非法的limit参数:%d", limit)
 	}
 	db := models.Main()
-	err = db.Select(CommonColumns).Where("status = ?", StatusPublish).Where("id<= ? and author like  ?", latestId, "%"+analyst+"%").Order("published_time desc").Limit(limit).Offset(offset).Find(&list).Error
+	err = db.Select(CommonColumns).Where("status = ?", StatusPublish).Where("id<= ?").Where("id in ? and author like  ?", latestId, reportIds, "%"+analyst+"%").Order("published_time desc").Limit(limit).Offset(offset).Find(&list).Error
 	return
 }

+ 43 - 2
service/media/media_service.go

@@ -65,6 +65,47 @@ func GetTotalPageCountByAnalystId(mediaType string, analystId int) (total int64,
 	return mediaService.GetTotalPageCountByAnalystId(mediaType, analystId)
 }
 
+func RangeSearchByAnalyst(mediaType string, analystId int, userId int) (total int64, latestId int64, ids []int) {
+	var err error
+	//登录了需要校验风险等级,如果风险等级没做或者过期直接返回空,做了就筛选风险等级
+	userProfile, userErr := user.GetUserProfile(userId)
+	if userErr != nil {
+		if errors.Is(userErr, gorm.ErrRecordNotFound) {
+			err = exception.New(exception.TemplateUserNotFound)
+		} else {
+			err = exception.New(exception.TemplateUserFoundFailed)
+		}
+		logger.Error("分页查询报告列表失败:%v", err)
+		return
+	}
+	//获取产品风险等级
+	if userProfile.RiskLevel == user.RiskUnTest {
+		logger.Error("客户未做风险等级测评,mobile:%d", userProfile.Mobile)
+		return
+	}
+	if userProfile.RiskLevelStatus == user.RiskExpired {
+		logger.Error("客户风险等级已过期,mobile:%d", userProfile.Mobile)
+		return
+	}
+	mapping, mappingErr := permissionService.GetRiskMappingByCustomerRiskLevel(userProfile.RiskLevel)
+	if mappingErr != nil {
+		logger.Error("查询产品风险等级映射失败:%v", mappingErr)
+		return
+	}
+	var permissionList []permissionService.PermissionDTO
+	//获取所有设置风险等级的品种
+	permissionList, err = permissionService.GetPermissionListWithRisk()
+	permissionList = filterPermissionsByRisk(permissionList, mapping.ProductRiskLevel)
+	if len(permissionList) == 0 {
+		return
+	}
+	var filterPermissionIds []int
+	for _, permission := range permissionList {
+		filterPermissionIds = append(filterPermissionIds, permission.PermissionId)
+	}
+	return mediaService.GetAnalystMediaPermissionMappingByPermissionIds(mediaType, filterPermissionIds, analystId)
+}
+
 func RangeSearch(mediaType string, isLogin bool, userId int) (total int64, latestId int64, ids []int) {
 	var err error
 	//登录了需要校验风险等级,如果风险等级没做或者过期直接返回空,做了就筛选风险等级
@@ -341,8 +382,8 @@ func getLowestRiskLevel(permissions []permissionService.PermissionDTO) (riskLeve
 	}
 	return
 }
-func GetMediaPageByAnalystId(mediaType string, pageInfo page.PageInfo, analystId int) (list []mediaService.MediaDTO, err error) {
-	list, err = mediaService.GetMediaPageByAnalystId(mediaType, pageInfo, analystId)
+func GetMediaPageByAnalystId(mediaType string, pageInfo page.PageInfo, analystId int, mediaIds []int) (list []mediaService.MediaDTO, err error) {
+	list, err = mediaService.GetMediaPageByAnalystId(mediaType, pageInfo, analystId, mediaIds)
 	if err != nil {
 		err = exception.New(exception.GetAnalystMediaListFailed)
 		return

+ 45 - 5
service/report/report_service.go

@@ -361,7 +361,47 @@ func SearchMaxReportId(key string) (total int64, id int64) {
 
 	return reportService.SearchMaxReportId(key)
 }
+func RangeSearchByAnalyst(analystName string, userId int) (total int64, latestId int64, ids []int) {
+	var err error
+	//登录了需要校验风险等级,如果风险等级没做或者过期直接返回空,做了就筛选风险等级
+	userProfile, userErr := user.GetUserProfile(userId)
+	if userErr != nil {
+		if errors.Is(userErr, gorm.ErrRecordNotFound) {
+			err = exception.New(exception.TemplateUserNotFound)
+		} else {
+			err = exception.New(exception.TemplateUserFoundFailed)
+		}
+		logger.Error("分页查询报告列表失败:%v", err)
+		return
+	}
+	//获取产品风险等级
+	if userProfile.RiskLevel == user.RiskUnTest {
+		logger.Error("客户未做风险等级测评,mobile:%d", userProfile.Mobile)
+		return
+	}
+	if userProfile.RiskLevelStatus == user.RiskExpired {
+		logger.Error("客户风险等级已过期,mobile:%d", userProfile.Mobile)
+		return
+	}
+	mapping, mappingErr := permissionService.GetRiskMappingByCustomerRiskLevel(userProfile.RiskLevel)
+	if mappingErr != nil {
+		logger.Error("查询产品风险等级映射失败:%v", mappingErr)
+		return
+	}
+	var permissionList []permissionService.PermissionDTO
+	//获取所有设置风险等级的品种
+	permissionList, err = permissionService.GetPermissionListWithRisk()
+	permissionList = filterPermissionsByRisk(permissionList, mapping.ProductRiskLevel)
+	if len(permissionList) == 0 {
+		return
+	}
+	var filterPermissionIds []int
+	for _, permission := range permissionList {
+		filterPermissionIds = append(filterPermissionIds, permission.PermissionId)
+	}
 
+	return reportService.GetTotalPageCountByAnalyst(analystName, filterPermissionIds)
+}
 func RangeSearch(isLogin bool, userId int) (total int64, latestId int64, orgIds map[string][]int) {
 	var err error
 	//登录了需要校验风险等级,如果风险等级没做或者过期直接返回空,做了就筛选风险等级
@@ -446,11 +486,11 @@ func GetReportPage(pageInfo page.PageInfo, orgIds map[string][]int, searchAll bo
 	return
 }
 
-func GetTotalPageCountByAnalyst(analyst string) (total int64, latestId int64) {
-	return reportService.GetTotalPageCountByAnalyst(analyst)
-}
-func GetReportPageByAnalyst(pageInfo page.PageInfo, analyst string) (list []reportService.ReportDTO, err error) {
-	list, err = reportService.GetReportPageByAnalyst(pageInfo, analyst)
+//	func GetTotalPageCountByAnalyst(analyst string) (total int64, latestId int64) {
+//		return reportService.GetTotalPageCountByAnalyst(analyst)
+//	}
+func GetReportPageByAnalyst(pageInfo page.PageInfo, analyst string, reportIds []int) (list []reportService.ReportDTO, err error) {
+	list, err = reportService.GetReportPageByAnalyst(pageInfo, analyst, reportIds)
 	//并发获取研报的标签
 	var wg sync.WaitGroup
 	wg.Add(len(list))