dfa_handler.go 5.3 KB


  1. package llm
  import (
	"encoding/json"
	"fmt"
	"runtime"
	"sort"
	"strings"
	"sync"
	"time"

	"eta/eta_api/models/llm/dfa"
	"eta/eta_api/utils"
	dfaUtils "eta/eta_api/utils/llm/algorithm/dfa"
	"github.com/panjf2000/ants"
)
  14. var (
  15. dafOnce sync.Once
  16. dafHandler *DAFHandler
  17. )
  type (
	// TagName is a tag label; each sensitive word in DAFHandler.SensitiveWordMap
	// maps to one TagName.
	TagName string
	// SensitiveWord is a keyword the DFA matcher searches for in article text.
	SensitiveWord string
)
  20. type DAFHandler struct {
  21. dfa *dfaUtils.DFA
  22. WorkerPool *ants.Pool
  23. SensitiveWordMap map[SensitiveWord]TagName
  24. }
  25. type DAFService interface {
  26. // ReloadSensitiveWordMap 重新加载敏感词
  27. ReloadSensitiveWordMap()
  28. // GetTextTagLabels 处理文本标签
  29. GetTextTagLabels(task *dfa.ArticleDfaTagMapping)
  30. // RejectCallback 携程池拒绝策略
  31. RejectCallback(p interface{})
  32. // BatchSubmitTasks 批量提交任务
  33. BatchSubmitTasks(tasks []*dfa.ArticleDfaTagMapping)
  34. FindTextTagLabels(text string) []string
  35. SubmitTask(tasks *dfa.ArticleDfaTagMapping)
  36. }
  37. func (d *DAFHandler) FindTextTagLabels(text string) (labels []string) {
  38. result := d.dfa.Search(text)
  39. if len(result) > 0 {
  40. for k, _ := range result {
  41. labels = append(labels, k)
  42. }
  43. }
  44. return
  45. }
  46. func (d *DAFHandler) SubmitTask(task *dfa.ArticleDfaTagMapping) {
  47. currentTask := task
  48. err := d.WorkerPool.Submit(func() {
  49. d.GetTextTagLabels(currentTask)
  50. insertErr := d.Insert(currentTask)
  51. if insertErr != nil {
  52. utils.FileLog.Error("插入dfa算法标签失败:", insertErr.Error())
  53. }
  54. })
  55. if err != nil {
  56. utils.FileLog.Error("提交任务失败:", err.Error())
  57. }
  58. }
  59. func (d *DAFHandler) BatchSubmitTasks(tasks []*dfa.ArticleDfaTagMapping) {
  60. // 创建结果收集通道
  61. results := make(chan *dfa.ArticleDfaTagMapping, len(tasks))
  62. var wg sync.WaitGroup
  63. for _, task := range tasks {
  64. wg.Add(1)
  65. // 复制task以避免并发问题
  66. currentTask := *task
  67. time.Sleep(200 * time.Millisecond)
  68. err := d.WorkerPool.Submit(func() {
  69. defer wg.Done()
  70. d.GetTextTagLabels(&currentTask)
  71. results <- &currentTask
  72. })
  73. if err != nil {
  74. utils.FileLog.Error("提交任务失败:", err.Error())
  75. wg.Done()
  76. }
  77. }
  78. // 等待所有任务完成
  79. go func() {
  80. wg.Wait()
  81. close(results)
  82. }()
  83. // 收集所有结果
  84. var processedTasks []*dfa.ArticleDfaTagMapping
  85. for task := range results {
  86. processedTasks = append(processedTasks, task)
  87. }
  88. fmt.Println(fmt.Sprintf("处理完成,开始批量插入数据...,%d", len(processedTasks)))
  89. // 批量插入
  90. if len(processedTasks) > 0 {
  91. if err := d.BatchInsert(processedTasks); err != nil {
  92. utils.FileLog.Error("批量插入文章系统提示词结果失败:", err.Error())
  93. }
  94. }
  95. }
  96. func (d *DAFHandler) RejectCallback(p interface{}) {
  97. fmt.Printf(fmt.Sprintf("任务被拒绝: %v", p))
  98. }
  // SensitiveWordStruct is the JSON payload stored in a task's SensitiveWords:
// per-keyword hit counts for the title and for the content. Field order is
// kept so the marshalled key order stays Title, Content.
type SensitiveWordStruct struct {
	Title, Content map[string]int
}
  103. func (d *DAFHandler) GetTextTagLabels(task *dfa.ArticleDfaTagMapping) {
  104. var TagResult = make(map[TagName]int)
  105. var titleResult, contentResult map[string]int
  106. if task.ArticleContent != "" {
  107. contentResult = d.dfa.Search(task.ArticleContent)
  108. for k, v := range contentResult {
  109. if tag, ok := d.SensitiveWordMap[SensitiveWord(k)]; ok {
  110. TagResult[tag] = TagResult[tag] + v
  111. }
  112. }
  113. }
  114. if task.ArticleTitle != "" {
  115. titleResult = d.dfa.Search(task.ArticleTitle)
  116. for k, v := range titleResult {
  117. if tag, ok := d.SensitiveWordMap[SensitiveWord(k)]; ok {
  118. TagResult[tag] = TagResult[tag] + v*10
  119. }
  120. }
  121. }
  122. item := SensitiveWordStruct{
  123. Content: contentResult,
  124. Title: titleResult,
  125. }
  126. for k, v := range TagResult {
  127. if v < 20 {
  128. delete(TagResult, k)
  129. }
  130. }
  131. var allValues []string
  132. for k, _ := range TagResult {
  133. allValues = append(allValues, string(k))
  134. }
  135. task.TagName = strings.Join(allValues, ",")
  136. str, _ := json.Marshal(item)
  137. task.SensitiveWords = string(str)
  138. }
  139. func (d *DAFHandler) ReloadSensitiveWordMap() {
  140. d.SensitiveWordMap = initSensitiveWords()
  141. }
  142. func (d *DAFHandler) BatchInsert(tasks []*dfa.ArticleDfaTagMapping) (err error) {
  143. return dfa.BatchInsertArticleDfaTagMapping(tasks)
  144. }
  145. func (d *DAFHandler) Insert(task *dfa.ArticleDfaTagMapping) (err error) {
  146. return dfa.InsertArticleDfaTagMapping(task)
  147. }
  148. func initSensitiveWords() (sensitiveWordMap map[SensitiveWord]TagName) {
  149. list, err := dfa.GetList()
  150. if err != nil {
  151. return nil
  152. }
  153. sensitiveWordMap = make(map[SensitiveWord]TagName)
  154. for _, item := range list {
  155. sensitiveWordMap[SensitiveWord(item.SensitiveWord)] = TagName(item.TagName)
  156. }
  157. return
  158. }
  159. func initThreadPool() (threadPool *ants.Pool) {
  160. cpuCores := runtime.NumCPU()
  161. threadPool, _ = ants.NewPool(cpuCores,
  162. ants.WithExpiryDuration(10*time.Second), // worker的过期时间
  163. ants.WithPreAlloc(true), // 预分配内存
  164. ants.WithMaxBlockingTasks(2*cpuCores), // 最大阻塞任务数
  165. ants.WithNonblocking(false), // 非阻塞模式
  166. ants.WithPanicHandler(dafHandler.RejectCallback), // 自定义panic处理
  167. )
  168. return
  169. }
  // RejectCallback is a placeholder hook for rejected requests; it currently
// does nothing with err.
func RejectCallback(err error) {}
  174. func GetDAFHandlerInstance() DAFService {
  175. if dafHandler == nil {
  176. dafOnce.Do(func() {
  177. dafHandler = &DAFHandler{
  178. SensitiveWordMap: initSensitiveWords(),
  179. WorkerPool: initThreadPool(),
  180. }
  181. var sensitiveWords []string
  182. for k := range dafHandler.SensitiveWordMap {
  183. sensitiveWords = append(sensitiveWords, string(k))
  184. }
  185. dfa := dfaUtils.NewDFA()
  186. dfa.Build(sensitiveWords)
  187. dafHandler.dfa = dfa
  188. })
  189. }
  190. return dafHandler
  191. }