report.go 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458
  1. package report
  2. import (
  3. "errors"
  4. logger "eta/eta_mini_ht_api/common/component/log"
  5. "eta/eta_mini_ht_api/common/utils/date"
  6. silce_utils "eta/eta_mini_ht_api/common/utils/silce"
  7. "eta/eta_mini_ht_api/models"
  8. permissionDao "eta/eta_mini_ht_api/models/config"
  9. "fmt"
  10. "gorm.io/gorm"
  11. "gorm.io/gorm/clause"
  12. "time"
  13. )
  // ReportStatus is the lifecycle value stored in reports.status.
type ReportStatus string

// ReportSource identifies the upstream system a report row was imported from
// (stored in reports.source).
type ReportSource string

// SendStatus is the value stored in reports.send_status. InsertOrUpdateReport
// keeps a stored report's published_time once it is SEND.
type SendStatus string

// Send states.
const (
	SEND   SendStatus = "SEND"
	UNSEND SendStatus = "UNSEND"
)

// Report sources; the string values double as the keys of the orgIds maps
// accepted by the permission-filtered queries.
const (
	SourceETA ReportSource = "ETA"
	SourceHT  ReportSource = "HT"
)

// Report statuses. Note that StatusPublish is an untyped string constant while
// the other two are ReportStatus; callers elsewhere may depend on it being
// usable as a plain string, so do not add a type without checking them.
const (
	StatusPublish                = "PUBLISH"
	StatusUnPublish ReportStatus = "UNPUBLISH"
	StatusDeleted   ReportStatus = "DELETED"
)

// Query settings.
const (
	// MaxBatchNum is the batch size handed to CreateInBatches.
	MaxBatchNum = 1000
	// CommonColumns is the column list selected by most report queries.
	// It does not include send_status, created_time or updated_time.
	CommonColumns = "id,org_id,author,abstract,title,source,cover_src,published_time,status,plate_name,classify_id"
	// taskColumns is the reduced column list used by GetNewReportByPublishTime.
	taskColumns = "id,author,published_time,status,plate_name"
)
  29. type Report struct {
  30. ID int `gorm:"column:id;primary_key;comment:'id'" json:"id"`
  31. OrgID int `gorm:"column:org_id;comment:'原始id'" json:"org_id"`
  32. Source ReportSource `gorm:"column:source;comment:'研报来源1:eta 2:海通'" json:"source"`
  33. Title string `gorm:"column:title;comment:'标题'" json:"title"`
  34. Abstract string `gorm:"column:abstract;comment:'摘要'" json:"abstract"`
  35. ClassifyId int `gorm:"column:classify_id"`
  36. PlateName string `gorm:"column:plate_name;comment:'板块'" json:"plate_name"`
  37. Author string `gorm:"column:author;comment:'作者'" json:"author"`
  38. CoverSrc int `gorm:"column:cover_src;comment:'封面图片'" json:"cover_src"`
  39. Status ReportStatus `gorm:"column:status;comment:'报告状态 publish:发布 unpublish:未发布" json:"status"`
  40. SendStatus SendStatus `gorm:"column:send_status;comment:'发送状态'" json:"send_status"`
  41. PublishedTime string `gorm:"column:published_time;comment:'发布时间'" json:"published_time"`
  42. CreatedTime time.Time `gorm:"column:created_time;comment:'创建时间'" json:"created_time"`
  43. UpdatedTime time.Time `gorm:"column:updated_time;comment:'修改时间'" json:"updated_time"`
  44. }
  45. func GetOrgIdsByPlateNames(plateName []string) (ids []int, err error) {
  46. db := models.Main()
  47. err = db.Model(&Report{}).Select("distinct org_id").Where(" source ='HT' and plate_name in ? ", plateName).Scan(&ids).Error
  48. return
  49. }
  50. func BatchInsertReport(list *[]Report) (err error) {
  51. db := models.Main()
  52. //手动事务
  53. tx := db.Begin()
  54. err = tx.CreateInBatches(list, MaxBatchNum).Error
  55. if err != nil {
  56. logger.Error("批量插入研报失败:%v", err)
  57. tx.Rollback()
  58. return
  59. }
  60. tx.Commit()
  61. return nil
  62. }
  63. func InsertOrUpdateReport(list []Report, source string) (result []Report, err error) {
  64. var orgIds []int
  65. //现有的作者名字
  66. for _, report := range list {
  67. orgIds = append(orgIds, report.OrgID)
  68. }
  69. orgIds = silce_utils.RemoveDuplicates(orgIds)
  70. db := models.Main()
  71. //数据库找到作者的名字
  72. if err != nil {
  73. logger.Error("查询研报失败:%v", err)
  74. }
  75. for i := 0; i < len(list); i++ {
  76. var dbReport Report
  77. err = db.Model(&Report{}).Select(CommonColumns).Where("org_id = ? and source =? ", list[i].OrgID, source).First(&dbReport).Error
  78. if err != nil {
  79. logger.Error("查询数据库失败研报失败:%v,发布时间不更新", err)
  80. }
  81. if dbReport.SendStatus == SEND {
  82. list[i].PublishedTime = dbReport.PublishedTime
  83. }
  84. }
  85. //手动事务
  86. tx := db.Begin()
  87. OnConflictFunc := clause.OnConflict{
  88. Columns: []clause.Column{{Name: "org_id"}, {Name: "source"}},
  89. DoUpdates: clause.AssignmentColumns([]string{"abstract", "title", "author", "published_time", "status", "classify_id"}),
  90. }
  91. // 执行批量插入或更新操作
  92. err = tx.Clauses(OnConflictFunc).Create(&list).Error
  93. //if deleteAuthors != nil {
  94. // for _, deleteAuthor := range deleteAuthors {
  95. // err = tx.Where("org_id in ? and source =? and author=?", orgIds, source, deleteAuthor).Delete(&Report{}).Error
  96. // if err != nil {
  97. // logger.Error("批量删除研报失败:%v", err)
  98. // tx.Rollback()
  99. // return
  100. // }
  101. // }
  102. //}
  103. if err != nil {
  104. logger.Error("批量插入或更新研报失败:%v", err)
  105. tx.Rollback()
  106. return
  107. }
  108. tx.Commit()
  109. err = db.Select(CommonColumns).Where("org_id in ? and source =? ", orgIds, source).Find(&result).Error
  110. if err != nil {
  111. logger.Error("查询更新的研报数据失败:%v", err)
  112. err = nil
  113. }
  114. return
  115. }
  116. func (t *Report) BeforeCreate(_ *gorm.DB) (err error) {
  117. t.CreatedTime = time.Now()
  118. return
  119. }
  120. func GetAuthorByOrgId(orgId int, source string) (names []string, err error) {
  121. db := models.Main()
  122. err = db.Model(&Report{}).Select("author").Where("org_id = ? and source =? ", orgId, source).Scan(&names).Error
  123. return
  124. }
  125. func GetReportByOrgId(orgId int, source string) (reports Report, err error) {
  126. db := models.Main()
  127. err = db.Select(CommonColumns).Where("org_id = ? and source =? ", orgId, source).Find(&reports).Error
  128. return
  129. }
  130. func GetReportById(reportId int) (report Report, err error) {
  131. db := models.Main()
  132. err = db.Select(CommonColumns).Where("id = ?", reportId).First(&report).Error
  133. if err != nil {
  134. logger.Error("查询报告失败:%v", err)
  135. }
  136. return
  137. }
  138. func GetLatestReportIdBySource(source ReportSource) (id int, err error) {
  139. sql := "select IFNULL(max(org_id),0) from reports where source = ?"
  140. err = DoSql(sql, &id, source)
  141. return
  142. }
  143. func DoSql(sql string, result interface{}, values ...interface{}) (err error) {
  144. db := models.Main()
  145. err = db.Raw(sql, values...).Scan(result).Error
  146. if err != nil {
  147. logger.Error("执行sql[%v]失败:%v", sql, err)
  148. }
  149. return
  150. }
  151. func GetListOrderByCondition(week bool, column string, limit int, order models.Order) (reports []Report, err error) {
  152. db := models.Main()
  153. if week {
  154. current := time.Now()
  155. begin := date.GetBeginOfTheWeek(current, time.Monday).Format(time.DateOnly)
  156. end := current.Format(time.DateOnly)
  157. err = db.Select(CommonColumns).Where("status = ?", StatusPublish).Where(" STR_TO_DATE(published_time,'%Y-%m-%d') BETWEEN ? AND ?", begin, end).Order(fmt.Sprintf("%s %s", column, order)).Limit(limit).Find(&reports).Error
  158. } else {
  159. err = db.Select(CommonColumns).Where("status = ?", StatusPublish).Order(fmt.Sprintf("%s %s", column, order)).Limit(limit).Find(&reports).Error
  160. }
  161. if err != nil {
  162. logger.Error("查询报告列表失败:%v", err)
  163. }
  164. if reports == nil {
  165. return []Report{}, nil
  166. }
  167. return
  168. }
  169. func GetListByCondition[T any](column string, values []T) (reports []Report, err error) {
  170. if len(values) == 0 {
  171. logger.Error("查询条件的值不能为空")
  172. return []Report{}, nil
  173. }
  174. db := models.Main()
  175. err = db.Select(CommonColumns).Where(fmt.Sprintf("%s in ?", column), values).Find(&reports).Error
  176. if err != nil {
  177. logger.Error("查询报告列表失败:%v", err)
  178. }
  179. if reports == nil {
  180. return []Report{}, nil
  181. }
  182. return
  183. }
  184. func GetReportIdListByOrgIds(orgIds map[string][]int) (ids []int, err error) {
  185. db := models.Main()
  186. if len(orgIds["ETA"]) == 0 && len(orgIds["HT"]) == 0 {
  187. return
  188. }
  189. if len(orgIds["ETA"]) == 0 {
  190. err = db.Model(&Report{}).Select("id").Where("status = ?", StatusPublish).Where(" source='HT' and org_id in ?", orgIds["HT"]).Scan(&ids).Error
  191. if err != nil {
  192. logger.Error("获取报告ID列表失败:%v", err)
  193. return
  194. }
  195. return
  196. }
  197. if len(orgIds["HT"]) == 0 {
  198. err = db.Model(&Report{}).Select("id").Where("status = ?", StatusPublish).Where(" source='ETA' and org_id in ?", orgIds["ETA"]).Scan(&ids).Error
  199. if err != nil {
  200. logger.Error("获取报告ID列表失败:%v", err)
  201. return
  202. }
  203. return
  204. }
  205. err = db.Model(&Report{}).Select("id").Where("status = ?", StatusPublish).Where(" source='ETA' and org_id in ?", orgIds["ETA"]).Or("source='HT' and org_id in ?", orgIds["HT"]).Scan(&ids).Error
  206. if err != nil {
  207. logger.Error("获取报告ID列表失败:%v", err)
  208. return
  209. }
  210. return
  211. }
  212. func GetMaxIdByPermissionIds(orgIds map[string][]int, disCardIds []int) (total int64, maxId int64, err error) {
  213. db := models.Main()
  214. if len(orgIds["ETA"]) == 0 && len(orgIds["HT"]) == 0 {
  215. maxId = 0
  216. total = 0
  217. return
  218. }
  219. countQuery := db.Model(&Report{}).Select("count(*)").Where("status = ? ", StatusPublish)
  220. maxQuery := db.Model(&Report{}).Select("MAX(id) id").Where("status = ? ", StatusPublish)
  221. if len(disCardIds) > 0 {
  222. countQuery.Where("id not in ?", disCardIds)
  223. maxQuery.Where("id not in ?", disCardIds)
  224. }
  225. if len(orgIds["ETA"]) == 0 {
  226. err = countQuery.Where(" source='HT' and org_id in ?", orgIds["HT"]).Scan(&total).Error
  227. if err != nil {
  228. logger.Error("获取记录条数失败:%v", err)
  229. return
  230. }
  231. err = maxQuery.Where(" source='HT' and org_id in ? ", orgIds["HT"]).Scan(&maxId).Error
  232. if err != nil {
  233. logger.Error("获取报告最大ID失败:%v", err)
  234. return
  235. }
  236. return
  237. }
  238. if len(orgIds["HT"]) == 0 {
  239. err = countQuery.Where(" source='ETA' and org_id in ? ", orgIds["ETA"]).Scan(&total).Error
  240. if err != nil {
  241. logger.Error("获取报告最大ID失败:%v", err)
  242. return
  243. }
  244. err = maxQuery.Where(" source='ETA' and org_id in ? ", orgIds["ETA"]).Scan(&maxId).Error
  245. if err != nil {
  246. logger.Error("获取报告最大ID失败:%v", err)
  247. return
  248. }
  249. return
  250. }
  251. err = countQuery.Where(" source='ETA' and org_id in ? ", orgIds["ETA"]).Or("source='HT' and org_id in ?", orgIds["HT"]).Scan(&total).Error
  252. if err != nil {
  253. logger.Error("获取报告最大ID失败:%v", err)
  254. return
  255. }
  256. err = maxQuery.Where(" source='ETA' and org_id in ?", orgIds["ETA"]).Or("source='HT' and org_id in ?", orgIds["HT"]).Scan(&maxId).Error
  257. if err != nil {
  258. logger.Error("获取报告最大ID失败:%v", err)
  259. return
  260. }
  261. return
  262. }
  263. func GetTotalPageCount() (total int64, latestId int64, err error) {
  264. db := models.Main()
  265. err = db.Model(&Report{}).Select("MAX(id) id").Where("status = ?", StatusPublish).Scan(&latestId).Error
  266. if err != nil {
  267. logger.Error("获取最大id失败:%v", err)
  268. }
  269. err = db.Model(&Report{}).Where("status = ?", StatusPublish).Count(&total).Error
  270. if err != nil {
  271. logger.Error("统计报告数量失败:%v", err)
  272. }
  273. return
  274. }
  275. func GetReportsByAnalyst(analyst string) (ids []int, err error) {
  276. db := models.Main()
  277. err = db.Model(&Report{}).Select("id").Where("status = ?", StatusPublish).Where("author like ?", "%"+analyst+"%").Scan(&ids).Error
  278. if err != nil {
  279. return
  280. }
  281. return
  282. }
  283. func GetTotalPageCountByAnalyst(analyst string) (total int64, latestId int64) {
  284. db := models.Main()
  285. err := db.Model(&Report{}).Where("status = ?", StatusPublish).Where("author like ?", "%"+analyst+"%").Count(&total).Error
  286. if err != nil {
  287. return
  288. }
  289. err = db.Model(&Report{}).Select("Max(id)").Where("status = ?", StatusPublish).Where("author like ?", "%"+analyst+"%").Scan(&latestId).Error
  290. if err != nil {
  291. return
  292. }
  293. return
  294. }
  295. func GetReportPage(latestId int64, limit int, offset int) (list []Report, err error) {
  296. if latestId < 0 {
  297. err = errors.New("非法的id参数")
  298. logger.Error("非法的id参数:%d", latestId)
  299. return
  300. }
  301. if limit <= 0 {
  302. err = errors.New("非法的limit参数")
  303. logger.Error("非法的limit参数:%d", limit)
  304. }
  305. db := models.Main()
  306. err = db.Select(CommonColumns).Where("id<= ?", latestId).Where("status = ?", StatusPublish).Order("published_time desc").Limit(limit).Offset(offset).Find(&list).Error
  307. return
  308. }
  309. func GetReportPageByOrgIds(latestId int64, limit int, offset int, orgIds map[string][]int, discardIds []int) (list []Report, err error) {
  310. if latestId < 0 {
  311. err = errors.New("非法的id参数")
  312. logger.Error("非法的id参数:%d", latestId)
  313. return
  314. }
  315. if limit <= 0 {
  316. err = errors.New("非法的limit参数")
  317. logger.Error("非法的limit参数:%d", limit)
  318. }
  319. db := models.Main()
  320. listQuery := db.Model(&Report{}).Select(CommonColumns).Where("status = ? ", StatusPublish).Where("id<= ?", latestId)
  321. if len(discardIds) > 0 {
  322. listQuery.Where("id not in ?", discardIds)
  323. }
  324. if len(orgIds["ETA"]) == 0 {
  325. err = listQuery.Where(" source='HT' and org_id in ?", orgIds["HT"]).Order("published_time desc").Limit(limit).Offset(offset).Find(&list).Error
  326. return
  327. }
  328. if len(orgIds["HT"]) == 0 {
  329. err = listQuery.Where("source='ETA' and org_id in ?", orgIds["ETA"]).Order("published_time desc").Limit(limit).Offset(offset).Find(&list).Error
  330. return
  331. }
  332. err = listQuery.Where("(source='ETA' and org_id in ? ) or (source='HT' and org_id in ?) ", orgIds["ETA"], orgIds["HT"]).Order("published_time desc").Limit(limit).Offset(offset).Find(&list).Error
  333. return
  334. }
  335. func GetNewReportByPublishTime(time time.Time) (list []Report) {
  336. db := models.Main()
  337. err := db.Select(taskColumns).Where("status = ?", StatusPublish).Where("published_time >= ?", time).Order("published_time desc").Find(&list).Error
  338. if err != nil {
  339. logger.Error("查询新发布的报告列表失败:%v", err)
  340. }
  341. return
  342. }
  343. func GetReportPageByAnalyst(latestId int64, limit int, offset int, analyst string, reportIds []int) (list []Report, err error) {
  344. if len(reportIds) == 0 {
  345. logger.Info("reportIds为空")
  346. return
  347. }
  348. if latestId < 0 {
  349. err = errors.New("非法的id参数")
  350. logger.Error("非法的id参数:%d", latestId)
  351. return
  352. }
  353. if limit <= 0 {
  354. err = errors.New("非法的limit参数")
  355. logger.Error("非法的limit参数:%d", limit)
  356. }
  357. db := models.Main()
  358. err = db.Select(CommonColumns).Where("status = ?", StatusPublish).Where("id<= ? and id in ? and author like ?", latestId, reportIds, "%"+analyst+"%").Order("published_time desc").Limit(limit).Offset(offset).Find(&list).Error
  359. return
  360. }
  361. func GetETAReportIdsByClassifyIds(classifyIds []int) (orgIds []int, err error) {
  362. db := models.Main()
  363. err = db.Model(&Report{}).Select("DISTINCT org_id").Where("source=? and classify_id in (?)", SourceETA, classifyIds).Scan(&orgIds).Error
  364. return
  365. }
  366. func GetETAReportByClassifyIds(classifyIds []int) (reportList []Report, err error) {
  367. db := models.Main()
  368. err = db.Model(&Report{}).Select(CommonColumns).Where("source=? and classify_id in (?)", SourceETA, classifyIds).Find(&reportList).Error
  369. return
  370. }
  371. func GetReportClassifyById(id int) (classifyId int, err error) {
  372. db := models.Main()
  373. err = db.Model(&Report{}).
  374. Select("classify_id").
  375. Where("org_id=? and source ='ETA' ", id).Scan(&classifyId).Error
  376. return
  377. }
  378. func CountPermissionWeight(ids []int) (list []permissionDao.PermissionWeight, err error) {
  379. db := models.Main()
  380. sql := `select a.permission_id,count(*) from
  381. (select p.permission_id from reports r LEFT JOIN(SELECT *from permissions) p on p.name=r.plate_name where r.id in (?) and r.source='HT'
  382. UNION ALL
  383. select cpm.permission_id from reports r LEFT JOIN (SELECT classify_id,permission_id FROM permission_classify_mapping) cpm on cpm.classify_id=r.classify_id where r.id in (?) and r.source='ETA') a GROUP BY a.permission_id`
  384. err = db.Raw(sql, ids, ids).Find(&list).Error
  385. return
  386. }
  387. func GetHiddenReportIds(classifyIds []int, plateNames []string) (reportIds []int, err error) {
  388. db := models.Main()
  389. exc := db.Model(&Report{}).Select("id")
  390. if len(classifyIds) > 0 {
  391. exc.Where("(source='ETA' and classify_id in ?)", classifyIds)
  392. }
  393. if len(plateNames) > 0 {
  394. exc.Or("(source='HT' and plate_name in ?)", plateNames)
  395. }
  396. err = exc.Scan(&reportIds).Error
  397. return
  398. }
  399. func FilterReportIds(ids []int) (total int64, reportIds []int, err error) {
  400. db := models.Main()
  401. etaSubQuery := `
  402. select a.classify_id
  403. from (
  404. select classify_id, IFNULL( GROUP_CONCAT(permissions.risk_level SEPARATOR ','),'') as risks
  405. from permission_classify_mapping
  406. left join permissions on permissions.permission_id = permission_classify_mapping.permission_id
  407. group by classify_id
  408. ) a
  409. where a.risks = ''
  410. `
  411. htSubQuery := `
  412. select a.name from (SELECT name,IFNULL(risk_level,'') as risk_level FROM permissions) a WHERE a.risk_level =''
  413. `
  414. err = db.Model(&Report{}).Select("id").
  415. Where("id in ?", ids).
  416. Where("(classify_id not in (?) and source=?) or (plate_name not in (?) and source=?)", gorm.Expr(etaSubQuery), SourceETA, gorm.Expr(htSubQuery), SourceHT).
  417. Where("Status = ?", StatusPublish).
  418. Scan(&reportIds).Error
  419. if err != nil {
  420. logger.Error("查询过滤后的报告失败: %v", err)
  421. }
  422. total = int64(len(reportIds))
  423. return
  424. }
  425. func GetReportListById(ids []int) (reports []Report, err error) {
  426. db := models.Main()
  427. err = db.Select(CommonColumns).
  428. Where("id in ?", ids).Find(&reports).Error
  429. return
  430. }
  431. func DeleteReport(id int) error {
  432. db := models.Main()
  433. return db.Model(&Report{}).Where("id=?", id).Update("status", StatusDeleted).Error
  434. }