package llm

import (
	"encoding/json"
	"fmt"
	"runtime"
	"sort"
	"strings"
	"sync"
	"time"

	"eta/eta_api/models/llm/dfa"
	"eta/eta_api/utils"
	dfaUtils "eta/eta_api/utils/llm/algorithm/dfa"

	"github.com/panjf2000/ants"
)

var (
	dafOnce    sync.Once
	dafHandler *DAFHandler
)

// TagName is the label attached to an article once enough of its
// associated sensitive words are found in the text.
type TagName string

// SensitiveWord is a keyword matched against article text by the DFA.
type SensitiveWord string

// DAFHandler scans article text for sensitive words with a DFA and maps
// the hits to tags. It is used as a process-wide singleton; obtain it
// via GetDAFHandlerInstance.
type DAFHandler struct {
	dfa              *dfaUtils.DFA
	WorkerPool       *ants.Pool
	SensitiveWordMap map[SensitiveWord]TagName
}

// DAFService is the public surface of the DFA tagging handler.
type DAFService interface {
	// ReloadSensitiveWordMap reloads the sensitive-word -> tag mapping.
	ReloadSensitiveWordMap()
	// GetTextTagLabels scans one article and fills in its tag fields.
	GetTextTagLabels(task *dfa.ArticleDfaTagMapping)
	// RejectCallback is the worker-pool panic handler.
	RejectCallback(p interface{})
	// BatchSubmitTasks submits tasks to the pool and persists results.
	BatchSubmitTasks(tasks []*dfa.ArticleDfaTagMapping)
}

// BatchSubmitTasks runs GetTextTagLabels for every task on the worker
// pool, waits for all of them to finish, then batch-inserts the
// processed results.
func (d *DAFHandler) BatchSubmitTasks(tasks []*dfa.ArticleDfaTagMapping) {
	// Buffered to capacity so workers never block sending a result.
	results := make(chan *dfa.ArticleDfaTagMapping, len(tasks))
	var wg sync.WaitGroup
	for _, task := range tasks {
		wg.Add(1)
		// Copy the task so each submitted closure owns its own value
		// (avoids sharing the loop variable across goroutines).
		currentTask := *task
		// Throttle submissions to ease pressure on the pool / DB.
		time.Sleep(200 * time.Millisecond)
		err := d.WorkerPool.Submit(func() {
			defer wg.Done()
			// FIX: source was corrupted ("¤tTask"); restored &currentTask.
			d.GetTextTagLabels(&currentTask)
			results <- &currentTask
		})
		if err != nil {
			utils.FileLog.Error("提交任务失败:", err.Error())
			// The closure will never run, so release its WaitGroup slot.
			wg.Done()
		}
	}
	// Close the channel once every accepted worker has finished, so the
	// collection loop below terminates.
	go func() {
		wg.Wait()
		close(results)
	}()
	var processedTasks []*dfa.ArticleDfaTagMapping
	for task := range results {
		processedTasks = append(processedTasks, task)
	}
	fmt.Printf("处理完成,开始批量插入数据...,%d\n", len(processedTasks))
	if len(processedTasks) > 0 {
		if err := d.BatchInsert(processedTasks); err != nil {
			utils.FileLog.Error("批量插入文章系统提示词结果失败:", err.Error())
		}
	}
}

// RejectCallback is installed as the ants pool panic handler; it logs
// the recovered value. (FIX: was fmt.Printf(fmt.Sprintf(...)) — double
// formatting misbehaves when the panic value contains '%'.)
func (d *DAFHandler) RejectCallback(p interface{}) {
	fmt.Printf("任务被拒绝: %v", p)
}

// SensitiveWordStruct records, per text section, which sensitive words
// matched and how many times; it is serialized into the task as JSON.
type SensitiveWordStruct struct {
	Title   map[string]int
	Content map[string]int
}

// GetTextTagLabels scans the task's title and content for sensitive
// words, scores the mapped tags (title hits weigh 10x), drops tags
// scoring below 20, and writes the joined tag names plus the raw match
// detail (JSON) back onto the task.
func (d *DAFHandler) GetTextTagLabels(task *dfa.ArticleDfaTagMapping) {
	tagScores := make(map[TagName]int)
	var titleResult, contentResult map[string]int
	if task.ArticleContent != "" {
		contentResult = d.dfa.Search(task.ArticleContent)
		for word, count := range contentResult {
			if tag, ok := d.SensitiveWordMap[SensitiveWord(word)]; ok {
				tagScores[tag] += count
			}
		}
	}
	if task.ArticleTitle != "" {
		titleResult = d.dfa.Search(task.ArticleTitle)
		for word, count := range titleResult {
			if tag, ok := d.SensitiveWordMap[SensitiveWord(word)]; ok {
				// Title matches are weighted 10x relative to content.
				tagScores[tag] += count * 10
			}
		}
	}
	// Drop tags below the score threshold. Deleting entries while
	// ranging over a map is well-defined in Go.
	for tag, score := range tagScores {
		if score < 20 {
			delete(tagScores, tag)
		}
	}
	names := make([]string, 0, len(tagScores))
	for tag := range tagScores {
		names = append(names, string(tag))
	}
	// Sort so the stored tag string is deterministic (map iteration
	// order is random).
	sort.Strings(names)
	task.TagName = strings.Join(names, ",")
	detail := SensitiveWordStruct{
		Title:   titleResult,
		Content: contentResult,
	}
	raw, err := json.Marshal(detail)
	if err != nil {
		// Cannot realistically fail for map[string]int, but log rather
		// than silently drop the error.
		utils.FileLog.Error("序列化敏感词结果失败:", err.Error())
	}
	task.SensitiveWords = string(raw)
}

// ReloadSensitiveWordMap replaces the in-memory sensitive-word -> tag
// mapping with a fresh load from storage.
func (d *DAFHandler) ReloadSensitiveWordMap() {
	d.SensitiveWordMap = initSensitiveWords()
}

// BatchInsert persists the processed tasks in one batch.
func (d *DAFHandler) BatchInsert(tasks []*dfa.ArticleDfaTagMapping) (err error) {
	return dfa.BatchInsertArticleDfaTagMapping(tasks)
}

// initSensitiveWords loads the sensitive-word list from storage and
// builds the word -> tag lookup map. Returns nil (and logs) on error.
func initSensitiveWords() (sensitiveWordMap map[SensitiveWord]TagName) {
	list, err := dfa.GetList()
	if err != nil {
		utils.FileLog.Error("加载敏感词列表失败:", err.Error())
		return nil
	}
	sensitiveWordMap = make(map[SensitiveWord]TagName, len(list))
	for _, item := range list {
		sensitiveWordMap[SensitiveWord(item.SensitiveWord)] = TagName(item.TagName)
	}
	return
}

// initThreadPool creates the ants worker pool sized to the CPU count.
// The panic handler is passed in explicitly so this function does not
// depend on the package-level singleton being initialized yet.
func initThreadPool(panicHandler func(interface{})) (threadPool *ants.Pool) {
	cpuCores := runtime.NumCPU()
	var err error
	threadPool, err = ants.NewPool(cpuCores,
		ants.WithExpiryDuration(10*time.Second),  // idle worker expiry
		ants.WithPreAlloc(true),                  // pre-allocate worker memory
		ants.WithMaxBlockingTasks(2*cpuCores),    // max queued Submit calls
		ants.WithNonblocking(false),              // blocking mode: Submit waits when the pool is full
		ants.WithPanicHandler(panicHandler),      // recover handler for worker panics
	)
	if err != nil {
		utils.FileLog.Error("创建协程池失败:", err.Error())
	}
	return
}

// RejectCallback is an unused package-level placeholder for handling
// rejected requests; kept for backward compatibility.
func RejectCallback(err error) {
	// 处理被拒绝的请求
	// 请求被拒绝处理
}

// GetDAFHandlerInstance lazily builds and returns the process-wide DFA
// tagging handler. sync.Once alone guarantees single initialization;
// the previous `if dafHandler == nil` pre-check was a data race.
func GetDAFHandlerInstance() DAFService {
	dafOnce.Do(func() {
		h := &DAFHandler{
			SensitiveWordMap: initSensitiveWords(),
		}
		// Build the pool after h exists so the panic handler is bound
		// to a non-nil receiver.
		h.WorkerPool = initThreadPool(h.RejectCallback)
		words := make([]string, 0, len(h.SensitiveWordMap))
		for w := range h.SensitiveWordMap {
			words = append(words, string(w))
		}
		// Local renamed from "dfa" to avoid shadowing the imported
		// package of the same name.
		automaton := dfaUtils.NewDFA()
		automaton.Build(words)
		h.dfa = automaton
		dafHandler = h
	})
	return dafHandler
}