dfa_handler.go 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178
  1. package llm
import (
	"encoding/json"
	"fmt"
	"runtime"
	"sort"
	"strings"
	"sync"
	"time"

	"eta/eta_api/models/llm/dfa"
	"eta/eta_api/utils"
	dfaUtils "eta/eta_api/utils/llm/algorithm/dfa"
	"github.com/panjf2000/ants"
)
  14. var (
  15. dafOnce sync.Once
  16. dafHandler *DAFHandler
  17. )
  18. type TagName string
  19. type SensitiveWord string
  20. type DAFHandler struct {
  21. dfa *dfaUtils.DFA
  22. WorkerPool *ants.Pool
  23. SensitiveWordMap map[SensitiveWord]TagName
  24. }
  25. type DAFService interface {
  26. // ReloadSensitiveWordMap 重新加载敏感词
  27. ReloadSensitiveWordMap()
  28. // GetTextTagLabels 处理文本标签
  29. GetTextTagLabels(task *dfa.ArticleDfaTagMapping)
  30. // RejectCallback 携程池拒绝策略
  31. RejectCallback(p interface{})
  32. // BatchSubmitTasks 批量提交任务
  33. BatchSubmitTasks(tasks []*dfa.ArticleDfaTagMapping)
  34. }
  35. func (d *DAFHandler) BatchSubmitTasks(tasks []*dfa.ArticleDfaTagMapping) {
  36. // 创建结果收集通道
  37. results := make(chan *dfa.ArticleDfaTagMapping, len(tasks))
  38. var wg sync.WaitGroup
  39. for _, task := range tasks {
  40. wg.Add(1)
  41. // 复制task以避免并发问题
  42. currentTask := *task
  43. time.Sleep(200 * time.Millisecond)
  44. err := d.WorkerPool.Submit(func() {
  45. defer wg.Done()
  46. d.GetTextTagLabels(&currentTask)
  47. results <- &currentTask
  48. })
  49. if err != nil {
  50. utils.FileLog.Error("提交任务失败:", err.Error())
  51. wg.Done()
  52. }
  53. }
  54. // 等待所有任务完成
  55. go func() {
  56. wg.Wait()
  57. close(results)
  58. }()
  59. // 收集所有结果
  60. var processedTasks []*dfa.ArticleDfaTagMapping
  61. for task := range results {
  62. processedTasks = append(processedTasks, task)
  63. }
  64. fmt.Println(fmt.Sprintf("处理完成,开始批量插入数据...,%d", len(processedTasks)))
  65. // 批量插入
  66. if len(processedTasks) > 0 {
  67. if err := d.BatchInsert(processedTasks); err != nil {
  68. utils.FileLog.Error("批量插入文章系统提示词结果失败:", err.Error())
  69. }
  70. }
  71. }
  72. func (d *DAFHandler) RejectCallback(p interface{}) {
  73. fmt.Printf(fmt.Sprintf("任务被拒绝: %v", p))
  74. }
  75. type SensitiveWordStruct struct {
  76. Title map[string]int
  77. Content map[string]int
  78. }
  79. func (d *DAFHandler) GetTextTagLabels(task *dfa.ArticleDfaTagMapping) {
  80. var TagResult = make(map[TagName]int)
  81. var titleResult, contentResult map[string]int
  82. if task.ArticleContent != "" {
  83. contentResult = d.dfa.Search(task.ArticleContent)
  84. for k, v := range contentResult {
  85. if tag, ok := d.SensitiveWordMap[SensitiveWord(k)]; ok {
  86. TagResult[tag] = TagResult[tag] + v
  87. }
  88. }
  89. }
  90. if task.ArticleTitle != "" {
  91. titleResult = d.dfa.Search(task.ArticleTitle)
  92. for k, v := range titleResult {
  93. if tag, ok := d.SensitiveWordMap[SensitiveWord(k)]; ok {
  94. TagResult[tag] = TagResult[tag] + v*10
  95. }
  96. }
  97. }
  98. item := SensitiveWordStruct{
  99. Content: contentResult,
  100. Title: titleResult,
  101. }
  102. for k, v := range TagResult {
  103. if v < 20 {
  104. delete(TagResult, k)
  105. }
  106. }
  107. var allValues []string
  108. for k, _ := range TagResult {
  109. allValues = append(allValues, string(k))
  110. }
  111. task.TagName = strings.Join(allValues, ",")
  112. str, _ := json.Marshal(item)
  113. task.SensitiveWords = string(str)
  114. }
  115. func (d *DAFHandler) ReloadSensitiveWordMap() {
  116. d.SensitiveWordMap = initSensitiveWords()
  117. }
  118. func (d *DAFHandler) BatchInsert(tasks []*dfa.ArticleDfaTagMapping) (err error) {
  119. return dfa.BatchInsertArticleDfaTagMapping(tasks)
  120. }
  121. func initSensitiveWords() (sensitiveWordMap map[SensitiveWord]TagName) {
  122. list, err := dfa.GetList()
  123. if err != nil {
  124. return nil
  125. }
  126. sensitiveWordMap = make(map[SensitiveWord]TagName)
  127. for _, item := range list {
  128. sensitiveWordMap[SensitiveWord(item.SensitiveWord)] = TagName(item.TagName)
  129. }
  130. return
  131. }
  132. func initThreadPool() (threadPool *ants.Pool) {
  133. cpuCores := runtime.NumCPU()
  134. threadPool, _ = ants.NewPool(cpuCores,
  135. ants.WithExpiryDuration(10*time.Second), // worker的过期时间
  136. ants.WithPreAlloc(true), // 预分配内存
  137. ants.WithMaxBlockingTasks(2*cpuCores), // 最大阻塞任务数
  138. ants.WithNonblocking(false), // 非阻塞模式
  139. ants.WithPanicHandler(dafHandler.RejectCallback), // 自定义panic处理
  140. )
  141. return
  142. }
  143. func RejectCallback(err error) {
  144. // 处理被拒绝的请求
  145. //请求被拒绝处理
  146. }
  147. func GetDAFHandlerInstance() DAFService {
  148. if dafHandler == nil {
  149. dafOnce.Do(func() {
  150. dafHandler = &DAFHandler{
  151. SensitiveWordMap: initSensitiveWords(),
  152. WorkerPool: initThreadPool(),
  153. }
  154. var sensitiveWords []string
  155. for k := range dafHandler.SensitiveWordMap {
  156. sensitiveWords = append(sensitiveWords, string(k))
  157. }
  158. dfa := dfaUtils.NewDFA()
  159. dfa.Build(sensitiveWords)
  160. dafHandler.dfa = dfa
  161. })
  162. }
  163. return dafHandler
  164. }