dfa_handler.go 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188
  1. package llm
  2. import (
  3. "encoding/json"
  4. "eta/eta_api/models/llm/dfa"
  5. "eta/eta_api/utils"
  6. dfaUtils "eta/eta_api/utils/llm/algorithm/dfa"
  7. "fmt"
  8. "github.com/panjf2000/ants"
  9. "runtime"
  10. "strings"
  11. "sync"
  12. "time"
  13. )
  14. var (
  15. dafOnce sync.Once
  16. dafHandler *DAFHandler
  17. )
  18. type TagName string
  19. type SensitiveWord string
  20. type DAFHandler struct {
  21. dfa *dfaUtils.DFA
  22. WorkerPool *ants.Pool
  23. SensitiveWordMap map[SensitiveWord]TagName
  24. }
  25. type DAFService interface {
  26. // ReloadSensitiveWordMap 重新加载敏感词
  27. ReloadSensitiveWordMap()
  28. // GetTextTagLabels 处理文本标签
  29. GetTextTagLabels(task *dfa.ArticleDfaTagMapping)
  30. // RejectCallback 携程池拒绝策略
  31. RejectCallback(p interface{})
  32. // BatchSubmitTasks 批量提交任务
  33. BatchSubmitTasks(tasks []*dfa.ArticleDfaTagMapping)
  34. FindTextTagLabels(text string) []string
  35. }
  36. func (d *DAFHandler) FindTextTagLabels(text string) (labels []string) {
  37. result := d.dfa.Search(text)
  38. if len(result) > 0 {
  39. for k, _ := range result {
  40. labels = append(labels, k)
  41. }
  42. }
  43. return
  44. }
  45. func (d *DAFHandler) BatchSubmitTasks(tasks []*dfa.ArticleDfaTagMapping) {
  46. // 创建结果收集通道
  47. results := make(chan *dfa.ArticleDfaTagMapping, len(tasks))
  48. var wg sync.WaitGroup
  49. for _, task := range tasks {
  50. wg.Add(1)
  51. // 复制task以避免并发问题
  52. currentTask := *task
  53. time.Sleep(200 * time.Millisecond)
  54. err := d.WorkerPool.Submit(func() {
  55. defer wg.Done()
  56. d.GetTextTagLabels(&currentTask)
  57. results <- &currentTask
  58. })
  59. if err != nil {
  60. utils.FileLog.Error("提交任务失败:", err.Error())
  61. wg.Done()
  62. }
  63. }
  64. // 等待所有任务完成
  65. go func() {
  66. wg.Wait()
  67. close(results)
  68. }()
  69. // 收集所有结果
  70. var processedTasks []*dfa.ArticleDfaTagMapping
  71. for task := range results {
  72. processedTasks = append(processedTasks, task)
  73. }
  74. fmt.Println(fmt.Sprintf("处理完成,开始批量插入数据...,%d", len(processedTasks)))
  75. // 批量插入
  76. if len(processedTasks) > 0 {
  77. if err := d.BatchInsert(processedTasks); err != nil {
  78. utils.FileLog.Error("批量插入文章系统提示词结果失败:", err.Error())
  79. }
  80. }
  81. }
  82. func (d *DAFHandler) RejectCallback(p interface{}) {
  83. fmt.Printf(fmt.Sprintf("任务被拒绝: %v", p))
  84. }
  85. type SensitiveWordStruct struct {
  86. Title map[string]int
  87. Content map[string]int
  88. }
  89. func (d *DAFHandler) GetTextTagLabels(task *dfa.ArticleDfaTagMapping) {
  90. var TagResult = make(map[TagName]int)
  91. var titleResult, contentResult map[string]int
  92. if task.ArticleContent != "" {
  93. contentResult = d.dfa.Search(task.ArticleContent)
  94. for k, v := range contentResult {
  95. if tag, ok := d.SensitiveWordMap[SensitiveWord(k)]; ok {
  96. TagResult[tag] = TagResult[tag] + v
  97. }
  98. }
  99. }
  100. if task.ArticleTitle != "" {
  101. titleResult = d.dfa.Search(task.ArticleTitle)
  102. for k, v := range titleResult {
  103. if tag, ok := d.SensitiveWordMap[SensitiveWord(k)]; ok {
  104. TagResult[tag] = TagResult[tag] + v*10
  105. }
  106. }
  107. }
  108. item := SensitiveWordStruct{
  109. Content: contentResult,
  110. Title: titleResult,
  111. }
  112. for k, v := range TagResult {
  113. if v < 20 {
  114. delete(TagResult, k)
  115. }
  116. }
  117. var allValues []string
  118. for k, _ := range TagResult {
  119. allValues = append(allValues, string(k))
  120. }
  121. task.TagName = strings.Join(allValues, ",")
  122. str, _ := json.Marshal(item)
  123. task.SensitiveWords = string(str)
  124. }
  125. func (d *DAFHandler) ReloadSensitiveWordMap() {
  126. d.SensitiveWordMap = initSensitiveWords()
  127. }
  128. func (d *DAFHandler) BatchInsert(tasks []*dfa.ArticleDfaTagMapping) (err error) {
  129. return dfa.BatchInsertArticleDfaTagMapping(tasks)
  130. }
  131. func initSensitiveWords() (sensitiveWordMap map[SensitiveWord]TagName) {
  132. list, err := dfa.GetList()
  133. if err != nil {
  134. return nil
  135. }
  136. sensitiveWordMap = make(map[SensitiveWord]TagName)
  137. for _, item := range list {
  138. sensitiveWordMap[SensitiveWord(item.SensitiveWord)] = TagName(item.TagName)
  139. }
  140. return
  141. }
  142. func initThreadPool() (threadPool *ants.Pool) {
  143. cpuCores := runtime.NumCPU()
  144. threadPool, _ = ants.NewPool(cpuCores,
  145. ants.WithExpiryDuration(10*time.Second), // worker的过期时间
  146. ants.WithPreAlloc(true), // 预分配内存
  147. ants.WithMaxBlockingTasks(2*cpuCores), // 最大阻塞任务数
  148. ants.WithNonblocking(false), // 非阻塞模式
  149. ants.WithPanicHandler(dafHandler.RejectCallback), // 自定义panic处理
  150. )
  151. return
  152. }
  153. func RejectCallback(err error) {
  154. // 处理被拒绝的请求
  155. //请求被拒绝处理
  156. }
  157. func GetDAFHandlerInstance() DAFService {
  158. if dafHandler == nil {
  159. dafOnce.Do(func() {
  160. dafHandler = &DAFHandler{
  161. SensitiveWordMap: initSensitiveWords(),
  162. WorkerPool: initThreadPool(),
  163. }
  164. var sensitiveWords []string
  165. for k := range dafHandler.SensitiveWordMap {
  166. sensitiveWords = append(sensitiveWords, string(k))
  167. }
  168. dfa := dfaUtils.NewDFA()
  169. dfa.Build(sensitiveWords)
  170. dafHandler.dfa = dfa
  171. })
  172. }
  173. return dafHandler
  174. }