|
@@ -0,0 +1,178 @@
|
|
|
+package llm
|
|
|
+
|
|
|
+import (
|
|
|
+ "encoding/json"
|
|
|
+ "eta/eta_api/models/llm/dfa"
|
|
|
+ "eta/eta_api/utils"
|
|
|
+ dfaUtils "eta/eta_api/utils/llm/algorithm/dfa"
|
|
|
+ "fmt"
|
|
|
+ "github.com/panjf2000/ants"
|
|
|
+ "runtime"
|
|
|
+ "strings"
|
|
|
+ "sync"
|
|
|
+ "time"
|
|
|
+)
|
|
|
+
|
|
|
+var (
|
|
|
+ dafOnce sync.Once
|
|
|
+
|
|
|
+ dafHandler *DAFHandler
|
|
|
+)
|
|
|
+
|
|
|
+type TagName string
|
|
|
+type SensitiveWord string
|
|
|
+type DAFHandler struct {
|
|
|
+ dfa *dfaUtils.DFA
|
|
|
+ WorkerPool *ants.Pool
|
|
|
+ SensitiveWordMap map[SensitiveWord]TagName
|
|
|
+}
|
|
|
+type DAFService interface {
|
|
|
+ // ReloadSensitiveWordMap 重新加载敏感词
|
|
|
+ ReloadSensitiveWordMap()
|
|
|
+ // GetTextTagLabels 处理文本标签
|
|
|
+ GetTextTagLabels(task *dfa.ArticleDfaTagMapping)
|
|
|
+ // RejectCallback 携程池拒绝策略
|
|
|
+ RejectCallback(p interface{})
|
|
|
+ // BatchSubmitTasks 批量提交任务
|
|
|
+ BatchSubmitTasks(tasks []*dfa.ArticleDfaTagMapping)
|
|
|
+}
|
|
|
+
|
|
|
+func (d *DAFHandler) BatchSubmitTasks(tasks []*dfa.ArticleDfaTagMapping) {
|
|
|
+ // 创建结果收集通道
|
|
|
+ results := make(chan *dfa.ArticleDfaTagMapping, len(tasks))
|
|
|
+ var wg sync.WaitGroup
|
|
|
+ for _, task := range tasks {
|
|
|
+ wg.Add(1)
|
|
|
+ // 复制task以避免并发问题
|
|
|
+ currentTask := *task
|
|
|
+ time.Sleep(200 * time.Millisecond)
|
|
|
+ err := d.WorkerPool.Submit(func() {
|
|
|
+ defer wg.Done()
|
|
|
+ d.GetTextTagLabels(¤tTask)
|
|
|
+ results <- ¤tTask
|
|
|
+ })
|
|
|
+ if err != nil {
|
|
|
+ utils.FileLog.Error("提交任务失败:", err.Error())
|
|
|
+ wg.Done()
|
|
|
+ }
|
|
|
+ }
|
|
|
+ // 等待所有任务完成
|
|
|
+ go func() {
|
|
|
+ wg.Wait()
|
|
|
+ close(results)
|
|
|
+ }()
|
|
|
+
|
|
|
+ // 收集所有结果
|
|
|
+ var processedTasks []*dfa.ArticleDfaTagMapping
|
|
|
+ for task := range results {
|
|
|
+ processedTasks = append(processedTasks, task)
|
|
|
+ }
|
|
|
+ fmt.Println(fmt.Sprintf("处理完成,开始批量插入数据...,%d", len(processedTasks)))
|
|
|
+ // 批量插入
|
|
|
+ if len(processedTasks) > 0 {
|
|
|
+ if err := d.BatchInsert(processedTasks); err != nil {
|
|
|
+ utils.FileLog.Error("批量插入文章系统提示词结果失败:", err.Error())
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+func (d *DAFHandler) RejectCallback(p interface{}) {
|
|
|
+ fmt.Printf(fmt.Sprintf("任务被拒绝: %v", p))
|
|
|
+}
|
|
|
+
|
|
|
+type SensitiveWordStruct struct {
|
|
|
+ Title map[string]int
|
|
|
+ Content map[string]int
|
|
|
+}
|
|
|
+
|
|
|
+func (d *DAFHandler) GetTextTagLabels(task *dfa.ArticleDfaTagMapping) {
|
|
|
+ var TagResult = make(map[TagName]int)
|
|
|
+ var titleResult, contentResult map[string]int
|
|
|
+ if task.ArticleContent != "" {
|
|
|
+ contentResult = d.dfa.Search(task.ArticleContent)
|
|
|
+ for k, v := range contentResult {
|
|
|
+ if tag, ok := d.SensitiveWordMap[SensitiveWord(k)]; ok {
|
|
|
+ TagResult[tag] = TagResult[tag] + v
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ if task.ArticleTitle != "" {
|
|
|
+ titleResult = d.dfa.Search(task.ArticleTitle)
|
|
|
+ for k, v := range titleResult {
|
|
|
+ if tag, ok := d.SensitiveWordMap[SensitiveWord(k)]; ok {
|
|
|
+ TagResult[tag] = TagResult[tag] + v*10
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ item := SensitiveWordStruct{
|
|
|
+ Content: contentResult,
|
|
|
+ Title: titleResult,
|
|
|
+ }
|
|
|
+ for k, v := range TagResult {
|
|
|
+ if v < 20 {
|
|
|
+ delete(TagResult, k)
|
|
|
+ }
|
|
|
+ }
|
|
|
+ var allValues []string
|
|
|
+ for k, _ := range TagResult {
|
|
|
+ allValues = append(allValues, string(k))
|
|
|
+ }
|
|
|
+ task.TagName = strings.Join(allValues, ",")
|
|
|
+ str, _ := json.Marshal(item)
|
|
|
+ task.SensitiveWords = string(str)
|
|
|
+}
|
|
|
+func (d *DAFHandler) ReloadSensitiveWordMap() {
|
|
|
+ d.SensitiveWordMap = initSensitiveWords()
|
|
|
+}
|
|
|
+
|
|
|
+func (d *DAFHandler) BatchInsert(tasks []*dfa.ArticleDfaTagMapping) (err error) {
|
|
|
+ return dfa.BatchInsertArticleDfaTagMapping(tasks)
|
|
|
+
|
|
|
+}
|
|
|
+func initSensitiveWords() (sensitiveWordMap map[SensitiveWord]TagName) {
|
|
|
+ list, err := dfa.GetList()
|
|
|
+ if err != nil {
|
|
|
+ return nil
|
|
|
+ }
|
|
|
+ sensitiveWordMap = make(map[SensitiveWord]TagName)
|
|
|
+ for _, item := range list {
|
|
|
+ sensitiveWordMap[SensitiveWord(item.SensitiveWord)] = TagName(item.TagName)
|
|
|
+ }
|
|
|
+ return
|
|
|
+}
|
|
|
+
|
|
|
+func initThreadPool() (threadPool *ants.Pool) {
|
|
|
+ cpuCores := runtime.NumCPU()
|
|
|
+ threadPool, _ = ants.NewPool(cpuCores,
|
|
|
+ ants.WithExpiryDuration(10*time.Second), // worker的过期时间
|
|
|
+ ants.WithPreAlloc(true), // 预分配内存
|
|
|
+ ants.WithMaxBlockingTasks(2*cpuCores), // 最大阻塞任务数
|
|
|
+ ants.WithNonblocking(false), // 非阻塞模式
|
|
|
+ ants.WithPanicHandler(dafHandler.RejectCallback), // 自定义panic处理
|
|
|
+ )
|
|
|
+ return
|
|
|
+}
|
|
|
+
|
|
|
+func RejectCallback(err error) {
|
|
|
+ // 处理被拒绝的请求
|
|
|
+ //请求被拒绝处理
|
|
|
+
|
|
|
+}
|
|
|
+
|
|
|
+func GetDAFHandlerInstance() DAFService {
|
|
|
+ if dafHandler == nil {
|
|
|
+ dafOnce.Do(func() {
|
|
|
+ dafHandler = &DAFHandler{
|
|
|
+ SensitiveWordMap: initSensitiveWords(),
|
|
|
+ WorkerPool: initThreadPool(),
|
|
|
+ }
|
|
|
+ var sensitiveWords []string
|
|
|
+ for k := range dafHandler.SensitiveWordMap {
|
|
|
+ sensitiveWords = append(sensitiveWords, string(k))
|
|
|
+ }
|
|
|
+ dfa := dfaUtils.NewDFA()
|
|
|
+ dfa.Build(sensitiveWords)
|
|
|
+ dafHandler.dfa = dfa
|
|
|
+ })
|
|
|
+ }
|
|
|
+ return dafHandler
|
|
|
+}
|