123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178 |
- package llm
- import (
- "encoding/json"
- "eta/eta_api/models/llm/dfa"
- "eta/eta_api/utils"
- dfaUtils "eta/eta_api/utils/llm/algorithm/dfa"
- "fmt"
- "github.com/panjf2000/ants"
- "runtime"
- "strings"
- "sync"
- "time"
- )
- var (
- dafOnce sync.Once
- dafHandler *DAFHandler
- )
- type TagName string
- type SensitiveWord string
- type DAFHandler struct {
- dfa *dfaUtils.DFA
- WorkerPool *ants.Pool
- SensitiveWordMap map[SensitiveWord]TagName
- }
- type DAFService interface {
- // ReloadSensitiveWordMap 重新加载敏感词
- ReloadSensitiveWordMap()
- // GetTextTagLabels 处理文本标签
- GetTextTagLabels(task *dfa.ArticleDfaTagMapping)
- // RejectCallback 携程池拒绝策略
- RejectCallback(p interface{})
- // BatchSubmitTasks 批量提交任务
- BatchSubmitTasks(tasks []*dfa.ArticleDfaTagMapping)
- }
- func (d *DAFHandler) BatchSubmitTasks(tasks []*dfa.ArticleDfaTagMapping) {
- // 创建结果收集通道
- results := make(chan *dfa.ArticleDfaTagMapping, len(tasks))
- var wg sync.WaitGroup
- for _, task := range tasks {
- wg.Add(1)
- // 复制task以避免并发问题
- currentTask := *task
- time.Sleep(200 * time.Millisecond)
- err := d.WorkerPool.Submit(func() {
- defer wg.Done()
- d.GetTextTagLabels(¤tTask)
- results <- ¤tTask
- })
- if err != nil {
- utils.FileLog.Error("提交任务失败:", err.Error())
- wg.Done()
- }
- }
- // 等待所有任务完成
- go func() {
- wg.Wait()
- close(results)
- }()
- // 收集所有结果
- var processedTasks []*dfa.ArticleDfaTagMapping
- for task := range results {
- processedTasks = append(processedTasks, task)
- }
- fmt.Println(fmt.Sprintf("处理完成,开始批量插入数据...,%d", len(processedTasks)))
- // 批量插入
- if len(processedTasks) > 0 {
- if err := d.BatchInsert(processedTasks); err != nil {
- utils.FileLog.Error("批量插入文章系统提示词结果失败:", err.Error())
- }
- }
- }
- func (d *DAFHandler) RejectCallback(p interface{}) {
- fmt.Printf(fmt.Sprintf("任务被拒绝: %v", p))
- }
- type SensitiveWordStruct struct {
- Title map[string]int
- Content map[string]int
- }
- func (d *DAFHandler) GetTextTagLabels(task *dfa.ArticleDfaTagMapping) {
- var TagResult = make(map[TagName]int)
- var titleResult, contentResult map[string]int
- if task.ArticleContent != "" {
- contentResult = d.dfa.Search(task.ArticleContent)
- for k, v := range contentResult {
- if tag, ok := d.SensitiveWordMap[SensitiveWord(k)]; ok {
- TagResult[tag] = TagResult[tag] + v
- }
- }
- }
- if task.ArticleTitle != "" {
- titleResult = d.dfa.Search(task.ArticleTitle)
- for k, v := range titleResult {
- if tag, ok := d.SensitiveWordMap[SensitiveWord(k)]; ok {
- TagResult[tag] = TagResult[tag] + v*10
- }
- }
- }
- item := SensitiveWordStruct{
- Content: contentResult,
- Title: titleResult,
- }
- for k, v := range TagResult {
- if v < 20 {
- delete(TagResult, k)
- }
- }
- var allValues []string
- for k, _ := range TagResult {
- allValues = append(allValues, string(k))
- }
- task.TagName = strings.Join(allValues, ",")
- str, _ := json.Marshal(item)
- task.SensitiveWords = string(str)
- }
- func (d *DAFHandler) ReloadSensitiveWordMap() {
- d.SensitiveWordMap = initSensitiveWords()
- }
- func (d *DAFHandler) BatchInsert(tasks []*dfa.ArticleDfaTagMapping) (err error) {
- return dfa.BatchInsertArticleDfaTagMapping(tasks)
- }
- func initSensitiveWords() (sensitiveWordMap map[SensitiveWord]TagName) {
- list, err := dfa.GetList()
- if err != nil {
- return nil
- }
- sensitiveWordMap = make(map[SensitiveWord]TagName)
- for _, item := range list {
- sensitiveWordMap[SensitiveWord(item.SensitiveWord)] = TagName(item.TagName)
- }
- return
- }
- func initThreadPool() (threadPool *ants.Pool) {
- cpuCores := runtime.NumCPU()
- threadPool, _ = ants.NewPool(cpuCores,
- ants.WithExpiryDuration(10*time.Second), // worker的过期时间
- ants.WithPreAlloc(true), // 预分配内存
- ants.WithMaxBlockingTasks(2*cpuCores), // 最大阻塞任务数
- ants.WithNonblocking(false), // 非阻塞模式
- ants.WithPanicHandler(dafHandler.RejectCallback), // 自定义panic处理
- )
- return
- }
- func RejectCallback(err error) {
- // 处理被拒绝的请求
- //请求被拒绝处理
- }
- func GetDAFHandlerInstance() DAFService {
- if dafHandler == nil {
- dafOnce.Do(func() {
- dafHandler = &DAFHandler{
- SensitiveWordMap: initSensitiveWords(),
- WorkerPool: initThreadPool(),
- }
- var sensitiveWords []string
- for k := range dafHandler.SensitiveWordMap {
- sensitiveWords = append(sensitiveWords, string(k))
- }
- dfa := dfaUtils.NewDFA()
- dfa.Build(sensitiveWords)
- dafHandler.dfa = dfa
- })
- }
- return dafHandler
- }
|