
Merge conflicts

kobe6258 committed 4 days ago
parent commit 6ae7abfbbc

+ 2 - 1
go.mod

@@ -36,6 +36,7 @@ require (
 	github.com/mojocn/base64Captcha v1.3.6
 	github.com/nosixtools/solarlunar v0.0.0-20211112060703-1b6dea7b4a19
 	github.com/olivere/elastic/v7 v7.0.32
+	github.com/panjf2000/ants v1.3.0
 	github.com/pdfcpu/pdfcpu v0.8.0
 	github.com/qiniu/qmgo v1.1.8
 	github.com/rdlucklib/rdluck_tools v1.0.3
@@ -50,6 +51,7 @@ require (
 	github.com/xuri/excelize/v2 v2.8.1
 	go.mongodb.org/mongo-driver v1.16.0
 	golang.org/x/net v0.27.0
+	golang.org/x/time v0.5.0
 	gopkg.in/gomail.v2 v2.0.0-20160411212932-81ebce5c23df
 	gorm.io/driver/mysql v1.5.7
 	gorm.io/gorm v1.25.12
@@ -163,7 +165,6 @@ require (
 	golang.org/x/sync v0.7.0 // indirect
 	golang.org/x/sys v0.22.0 // indirect
 	golang.org/x/text v0.16.0 // indirect
-	golang.org/x/time v0.5.0 // indirect
 	google.golang.org/protobuf v1.34.1 // indirect
 	gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc // indirect
 	gopkg.in/asn1-ber.v1 v1.0.0-20181015200546-f715ec2f112d // indirect

+ 2 - 0
go.sum

@@ -396,6 +396,8 @@ github.com/onsi/gomega v1.19.0/go.mod h1:LY+I3pBVzYsTBU1AnDwOSxaYi9WoWiqgwooUqq9
 github.com/opentracing/opentracing-go v1.2.1-0.20220228012449-10b1cf09e00b h1:FfH+VrHHk6Lxt9HdVS0PXzSXFyS2NbZKXv33FYPol0A=
 github.com/opentracing/opentracing-go v1.2.1-0.20220228012449-10b1cf09e00b/go.mod h1:AC62GU6hc0BrNm+9RK9VSiwa/EUe1bkIeFORAMcHvJU=
 github.com/openzipkin/zipkin-go v0.1.6/go.mod h1:QgAqvLzwWbR/WpD4A3cGpPtJrZXNIiJc5AZX7/PBEpw=
+github.com/panjf2000/ants v1.3.0 h1:8pQ+8leaLc9lys2viEEr8md0U4RN6uOSUCE9bOYjQ9M=
+github.com/panjf2000/ants v1.3.0/go.mod h1:AaACblRPzq35m1g3enqYcxspbbiOJJYaxU2wMpm1cXY=
 github.com/pdfcpu/pdfcpu v0.8.0 h1:SuEB4uVsPFz1nb802r38YpFpj9TtZh/oB0bGG34IRZw=
 github.com/pdfcpu/pdfcpu v0.8.0/go.mod h1:jj03y/KKrwigt5xCi8t7px2mATcKuOzkIOoCX62yMho=
 github.com/pelletier/go-toml v1.0.1/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic=

+ 60 - 0
models/llm/dfa/article_dfa_tag_mapping.go

@@ -0,0 +1,60 @@
+package dfa
+
+import (
+	"eta/eta_api/global"
+	"eta/eta_api/utils"
+	"time"
+)
+
+type ArticleSource string
+
+const (
+	articleDfaTagMappingTableName               = "article_dfa_tag_mapping"
+	ArticleSourceWechat           ArticleSource = "wechat"
+	ArticleSourceEta              ArticleSource = "eta"
+)
+
+type ArticleDfaTagMapping struct {
+	ID             int           `gorm:"column:id;primary_key"`
+	ArticleID      int           `gorm:"column:article_id"`
+	Source         ArticleSource `gorm:"column:source"`
+	TagName        string        `gorm:"column:tag_name"`
+	SensitiveWords string        `gorm:"column:sensitive_words"`
+	ArticleTitle   string        `gorm:"column:article_title"`
+	ArticleContent string        `gorm:"column:article_content"`
+	Remark         string        `gorm:"column:remark"`
+	CreatedTime    time.Time     `gorm:"column:created_time"`
+	UpdatedTime    time.Time     `gorm:"column:update_time"`
+}
+
+// TableName sets the insert table name for this struct type
+func (a *ArticleDfaTagMapping) TableName() string {
+	return articleDfaTagMappingTableName
+}
+func (a *ArticleDfaTagMapping) Insert() error {
+	return global.DbMap[utils.DbNameAI].Table(articleDfaTagMappingTableName).Create(a).Error
+}
+
+func GetArticleDfaTagMappingList() (items []*ArticleDfaTagMapping, err error) {
+	err = global.DbMap[utils.DbNameAI].Table(articleDfaTagMappingTableName).Select("article_id,source").Find(&items).Error
+	return
+}
+
+func BatchInsertArticleDfaTagMapping(tasks []*ArticleDfaTagMapping) error {
+	// Use a transaction so the batch insert is atomic
+	tx := global.DbMap[utils.DbNameAI].Begin()
+	defer func() {
+		if r := recover(); r != nil {
+			tx.Rollback()
+		}
+	}()
+	if err := tx.Error; err != nil {
+		return err
+	}
+	// Insert in batches of 100 rows
+	if err := tx.CreateInBatches(tasks, 100).Error; err != nil {
+		tx.Rollback()
+		return err
+	}
+	return tx.Commit().Error
+}
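Review note: GORM's CreateInBatches already runs all of its chunks inside the
default implicit transaction (assuming SkipDefaultTransaction is not set), so a
minimal variant could drop the manual Begin/Commit. Hypothetical sketch, not
part of this PR:

// BatchInsertArticleDfaTagMappingSimple relies on GORM's implicit transaction
// around CreateInBatches; all batches roll back together on error.
func BatchInsertArticleDfaTagMappingSimple(tasks []*ArticleDfaTagMapping) error {
	if len(tasks) == 0 {
		return nil // nothing to insert
	}
	return global.DbMap[utils.DbNameAI].
		Table(articleDfaTagMappingTableName).
		CreateInBatches(tasks, 100).Error
}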

+ 31 - 0
models/llm/dfa/dfa_tag_sensitive_word_mapping.go

@@ -0,0 +1,31 @@
+package dfa
+
+import (
+	"eta/eta_api/global"
+	"eta/eta_api/utils"
+	"time"
+)
+
+const (
+	dfaTagSensitiveWordMappingTableName = "dfa_tag_sensitive_word_mapping"
+	Deleted                             = 1
+	UnDeleted                           = 0
+)
+
+type DfaTagSensitiveWordMapping struct {
+	ID            int       `gorm:"column:id;primary_key"`
+	TagName       string    `gorm:"column:tag_name"`
+	SensitiveWord string    `gorm:"column:sensitive_word"`
+	IsDeleted     int       `gorm:"column:is_deleted"`
+	CreatedTime   time.Time `gorm:"column:created_time"`
+	UpdateTime    time.Time `gorm:"column:update_time"`
+}
+
+func (DfaTagSensitiveWordMapping) TableName() string {
+	return dfaTagSensitiveWordMappingTableName
+}
+
+func GetSensitiveWordList() (items []*DfaTagSensitiveWordMapping, err error) {
+	err = global.DbMap[utils.DbNameAI].Table(dfaTagSensitiveWordMappingTableName).Select("tag_name,sensitive_word").Where("is_deleted = ?", UnDeleted).Find(&items).Error
+	return
+}

+ 61 - 0
models/rag/wechat_article.go

@@ -305,3 +305,64 @@ func (m *WechatArticle) GetPageListByPlatformCondition(condition string, pars []
 
 	return
 }
+
+type DafWechatArticleItem struct {
+	WechatArticleId int    `gorm:"column:wechat_article_id;type:int(10) UNSIGNED;primaryKey;not null;" description:""`
+	Title           string `gorm:"column:title;type:varchar(255);comment:标题;" description:"标题"`
+	TextContent     string `gorm:"column:text_content;type:text;comment:文本内容;" description:"文本内容"`
+}
+
+func (m *WechatArticle) ToDAFView() *DafWechatArticleItem {
+	return &DafWechatArticleItem{
+		WechatArticleId: m.WechatArticleId,
+		Title:           m.Title,
+		TextContent:     m.TextContent,
+	}
+}
+func FetchArticleDafTagMappingList(lastId int, excludeIds []int, fetchSize int) (items []*DafWechatArticleItem, id int, err error) {
+	var articleList []*WechatArticle
+	if len(excludeIds) == 0 {
+		err = global.DbMap[utils.DbNameAI].Model(&WechatArticle{}).Select("wechat_article_id,title,text_content").Where("wechat_article_id > ? and is_deleted=?", lastId, 0).Order("wechat_article_id asc").Limit(fetchSize).Find(&articleList).Error
+	} else {
+		err = global.DbMap[utils.DbNameAI].Model(&WechatArticle{}).Select("wechat_article_id,title,text_content").Where("wechat_article_id > ? and wechat_article_id not in (?) and is_deleted=?", lastId, excludeIds, 0).Order("wechat_article_id asc").Limit(fetchSize).Find(&articleList).Error
+	}
+	if err != nil {
+		return
+	}
+	for _, article := range articleList {
+		items = append(items, article.ToDAFView())
+	}
+	if len(articleList) > 0 {
+		id = articleList[len(articleList)-1].WechatArticleId
+	}
+	return
+}
+
+func GetArticleByTags(labels []string, daysBefore int, fetchSize int) (items []*DafWechatArticleItem, err error) {
+	if len(labels) == 0 {
+		return
+	}
+	var tagSql = `select article_id from article_dfa_tag_mapping where (`
+	for i := 0; i < len(labels); i++ {
+		if i == 0 {
+			tagSql += fmt.Sprintf(`find_in_set('%s', tag_name)`, labels[i])
+		} else {
+			tagSql += fmt.Sprintf(` or find_in_set('%s', tag_name)`, labels[i])
+		}
+	}
+	tagSql += `)`
+	endDate := time.Now()
+	startDate := endDate.AddDate(0, 0, -daysBefore)
+	var limitedCondition string
+	if fetchSize > 0 {
+		limitedCondition = fmt.Sprintf("limit 0,%d", fetchSize)
+	}
+	var dateCondition = fmt.Sprintf("and article_create_time between '%s' and '%s'", startDate.Format(utils.FormatDate), endDate.Format(utils.FormatDate))
+	var sql = fmt.Sprintf(`select text_content from wechat_article where wechat_article_id in (%s) and is_deleted = 0 %s`, tagSql, limitedCondition)
+	var sqlWithDays = fmt.Sprintf(`select text_content from wechat_article where wechat_article_id in (%s) and is_deleted = 0 %s %s`, tagSql, dateCondition, limitedCondition)
+	// Prefer articles inside the recent window; fall back to all time if none match
+	err = global.DbMap[utils.DbNameAI].Raw(sqlWithDays).Find(&items).Error
+	if err != nil {
+		return
+	}
+	if len(items) == 0 {
+		err = global.DbMap[utils.DbNameAI].Raw(sql).Find(&items).Error
+	}
+	return
+}
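Review note: the label values are spliced directly into the SQL string. They
currently come from the tag table, but a parameterized variant would be safer.
A minimal sketch under that assumption (hypothetical helper; needs the strings
import in this file):

// getArticleIDsByTagsSafe builds one find_in_set(?) placeholder per label and
// passes the labels as bind arguments, so no label text reaches the SQL string.
func getArticleIDsByTagsSafe(labels []string) (ids []int, err error) {
	if len(labels) == 0 {
		return
	}
	conds := make([]string, 0, len(labels))
	args := make([]interface{}, 0, len(labels))
	for _, label := range labels {
		conds = append(conds, "find_in_set(?, tag_name)")
		args = append(args, label)
	}
	sql := "select article_id from article_dfa_tag_mapping where " + strings.Join(conds, " or ")
	// Scan assumes GORM can fill a single-column result into []int;
	// otherwise iterate Rows() manually.
	err = global.DbMap[utils.DbNameAI].Raw(sql, args...).Scan(&ids).Error
	return
}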

+ 98 - 0
services/llm/daf_service.go

@@ -0,0 +1,98 @@
+package llm
+
+import (
+	"eta/eta_api/models/llm/dfa"
+	"eta/eta_api/models/rag"
+	"eta/eta_api/utils"
+	"eta/eta_api/utils/llm"
+	"fmt"
+	"time"
+)
+
+const (
+	fetchSize            = 500
+	ETA_ARTICLE_TABLE    = "rag_eta_report"
+	WECHAT_ARTICLE_TABLE = "wechat_article"
+)
+
+var (
+	DAFHandler = llm.GetDAFHandlerInstance()
+)
+
+func DealHistoryArticleDafTags() {
+	utils.FileLog.Info("historical article DFA tagging task started")
+	// Fetch the articles that already have tags, so they can be excluded below
+	articleParts, err := getDealDafTagsArticleList()
+	if err != nil {
+		utils.FileLog.Error("historical article DFA tagging task aborted, err:", err.Error())
+		return
+	}
+	go func() {
+		if etaErr := DAFTagDeal(dfa.ArticleSourceEta, articleParts[dfa.ArticleSourceEta]); etaErr != nil {
+			utils.FileLog.Error("ETA article DFA tagging failed, err:", etaErr.Error())
+		}
+	}()
+	go func() {
+		if wechatErr := DAFTagDeal(dfa.ArticleSourceWechat, articleParts[dfa.ArticleSourceWechat]); wechatErr != nil {
+			utils.FileLog.Error("wechat article DFA tagging failed, err:", wechatErr.Error())
+		}
+	}()
+	// Both sources run asynchronously; this only marks the dispatch as done
+	utils.FileLog.Info("historical article DFA tagging task dispatched")
+}
+
+func getDealDafTagsArticleList() (articleIdsMap map[dfa.ArticleSource][]int, err error) {
+	list, err := dfa.GetArticleDfaTagMappingList()
+	if err != nil {
+		utils.FileLog.Error("failed to fetch DFA tag mapping list, err:", err)
+		err = fmt.Errorf("failed to fetch DFA tag mapping list, err: %v", err)
+		return
+	}
+	articleIdsMap = make(map[dfa.ArticleSource][]int)
+	for _, item := range list {
+		if _, ok := articleIdsMap[item.Source]; !ok {
+			articleIdsMap[item.Source] = make([]int, 0)
+		}
+		articleIdsMap[item.Source] = append(articleIdsMap[item.Source], item.ArticleID)
+	}
+	return
+}
+func DAFTagDeal(source dfa.ArticleSource, excludeArticleIds []int) (err error) {
+	switch source {
+	case dfa.ArticleSourceWechat:
+		return wechatDeal(excludeArticleIds)
+	case dfa.ArticleSourceEta:
+		// ETA reports are not handled yet
+		return
+	default:
+		utils.FileLog.Warn("unknown article source; skipping DFA tag processing")
+		return
+	}
+}
+func wechatDeal(excludeArticleIds []int) (err error) {
+	var id = 0
+	for {
+		var articleList []*rag.DafWechatArticleItem
+		var dafTagList []*dfa.ArticleDfaTagMapping
+		articleList, id, err = rag.FetchArticleDafTagMappingList(id, excludeArticleIds, fetchSize)
+		if err != nil {
+			utils.FileLog.Error("failed to fetch wechat article list, err:", err)
+			return
+		}
+		if len(articleList) == 0 {
+			break
+		}
+		for _, article := range articleList {
+			dafTagList = append(dafTagList, &dfa.ArticleDfaTagMapping{
+				ArticleID:      article.WechatArticleId,
+				Source:         dfa.ArticleSourceWechat,
+				ArticleTitle:   article.Title,
+				ArticleContent: article.TextContent,
+				TagName:        "微信文章",
+				CreatedTime:    time.Now(),
+			})
+		}
+		DAFHandler.BatchSubmitTasks(dafTagList)
+	}
+	return
+}
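Review note: DealHistoryArticleDafTags returns while both goroutines are still
running. If a caller ever needs to wait for completion, golang.org/x/sync
(already listed in go.mod) provides errgroup. Hypothetical sketch, names not
part of this PR; exclude lists elided for brevity:

// DealHistoryArticleDafTagsAndWait processes both sources and only logs
// completion once both have finished.
func DealHistoryArticleDafTagsAndWait() {
	var g errgroup.Group
	g.Go(func() error { return DAFTagDeal(dfa.ArticleSourceEta, nil) })
	g.Go(func() error { return DAFTagDeal(dfa.ArticleSourceWechat, nil) })
	if err := g.Wait(); err != nil {
		utils.FileLog.Error("historical article DFA tagging failed:", err.Error())
		return
	}
	utils.FileLog.Info("historical article DFA tagging finished")
}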

+ 2 - 0
services/task.go

@@ -87,6 +87,8 @@ func Task() {
 	// Listen for equity reports and store them
 	go AutoInsertRaiReport()
 
+	go llm.DealHistoryArticleDafTags()
+
 	// TODO: data repair
 	//FixNewEs()
 	fmt.Println("task end")

+ 92 - 0
utils/llm/algorithm/dfa/dfa_algorithm.go

@@ -0,0 +1,92 @@
+package dfa
+
+// DFANode is a node in a DFA (Deterministic Finite Automaton) keyword matcher
+type DFANode struct {
+	children map[rune]*DFANode
+	isEnd    bool
+	keyword  string
+}
+
+type DFA struct {
+	root *DFANode
+}
+
+func NewDFA() *DFA {
+	return &DFA{
+		root: &DFANode{
+			children: make(map[rune]*DFANode),
+		},
+	}
+}
+
+func (d *DFA) AddKeyword(keyword string) {
+	node := d.root
+	for _, r := range keyword {
+		if _, exists := node.children[r]; !exists {
+			node.children[r] = &DFANode{
+				children: make(map[rune]*DFANode),
+				isEnd:    false,
+			}
+		}
+		node = node.children[r]
+	}
+	node.isEnd = true
+	node.keyword = keyword
+}
+
+func (d *DFA) Build(keywords []string) {
+	for _, kw := range keywords {
+		d.AddKeyword(kw)
+	}
+}
+
+func (d *DFA) Search(text string) map[string]int {
+	result := make(map[string]int)
+	runes := []rune(text)
+	n := len(runes)
+
+	for i := 0; i < n; {
+		node := d.root
+		longestMatch := ""
+		matchEnd := i
+
+		// Look for the longest match starting at i
+		for j := i; j < n; j++ {
+			r := runes[j]
+			if nextNode, exists := node.children[r]; exists {
+				node = nextNode
+				if node.isEnd {
+					// Boundary check: Chinese text needs no strict word-boundary handling
+					longestMatch = node.keyword
+					matchEnd = j + 1
+				}
+			} else {
+				break
+			}
+		}
+
+		if longestMatch != "" {
+			result[longestMatch]++
+			i = matchEnd // skip to the end of the match
+		} else {
+			i++
+		}
+	}
+
+	return result
+}
+
+// For unit testing:
+//func main() {
+//	keywords := []string{"人工智能", "机器学习", "深度学习", "AI"}
+//	text := "人工智能人工智能(AI)是机器学习的重要分支,深度学习则是机器学习的一个子领域。AI技术正在快速发展。机器学习AI机器学习机器学习机器学习机器学习机器学习机器学习机器学习机器学习机器学习AIAIAI机器学习机器学习AI"
+//
+//	dfa := NewDFA()
+//	dfa.Build(keywords)
+//	result := dfa.Search(text)
+//
+//	fmt.Println("Keyword occurrence counts:")
+//	for kw, count := range result {
+//		fmt.Printf("%s: %d\n", kw, count)
+//	}
+//}
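Review note: the commented-out main above could live as a real test next to this
file. A minimal sketch (hypothetical dfa_algorithm_test.go in the same package):

package dfa

import "testing"

// Verifies greedy longest-match counting over overlapping Chinese keywords.
func TestSearchCountsLongestMatches(t *testing.T) {
	d := NewDFA()
	d.Build([]string{"人工智能", "机器学习", "AI"})
	got := d.Search("人工智能(AI)是机器学习的重要分支")
	if got["人工智能"] != 1 || got["AI"] != 1 || got["机器学习"] != 1 {
		t.Fatalf("unexpected counts: %v", got)
	}
}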

+ 188 - 0
utils/llm/dfa_handler.go

@@ -0,0 +1,188 @@
+package llm
+
+import (
+	"encoding/json"
+	"eta/eta_api/models/llm/dfa"
+	"eta/eta_api/utils"
+	dfaUtils "eta/eta_api/utils/llm/algorithm/dfa"
+	"fmt"
+	"github.com/panjf2000/ants"
+	"runtime"
+	"strings"
+	"sync"
+	"time"
+)
+
+var (
+	dafOnce sync.Once
+
+	dafHandler *DAFHandler
+)
+
+type TagName string
+type SensitiveWord string
+type DAFHandler struct {
+	dfa              *dfaUtils.DFA
+	WorkerPool       *ants.Pool
+	SensitiveWordMap map[SensitiveWord]TagName
+}
+type DAFService interface {
+	// ReloadSensitiveWordMap reloads the sensitive-word map from the database
+	ReloadSensitiveWordMap()
+	// GetTextTagLabels computes tag labels for one article task
+	GetTextTagLabels(task *dfa.ArticleDfaTagMapping)
+	// RejectCallback is the worker pool's panic handler
+	RejectCallback(p interface{})
+	// BatchSubmitTasks submits tasks to the pool and persists the results
+	BatchSubmitTasks(tasks []*dfa.ArticleDfaTagMapping)
+	FindTextTagLabels(text string) []string
+}
+
+func (d *DAFHandler) FindTextTagLabels(text string) (labels []string) {
+	result := d.dfa.Search(text)
+	if len(result) > 0 {
+		for k := range result {
+			labels = append(labels, k)
+		}
+	}
+	return
+}
+func (d *DAFHandler) BatchSubmitTasks(tasks []*dfa.ArticleDfaTagMapping) {
+	// Channel for collecting processed tasks
+	results := make(chan *dfa.ArticleDfaTagMapping, len(tasks))
+	var wg sync.WaitGroup
+	for _, task := range tasks {
+		wg.Add(1)
+		// Copy the task so the goroutine does not share loop state
+		currentTask := *task
+		// Throttle submissions
+		time.Sleep(200 * time.Millisecond)
+		err := d.WorkerPool.Submit(func() {
+			defer wg.Done()
+			d.GetTextTagLabels(&currentTask)
+			results <- &currentTask
+		})
+		if err != nil {
+			utils.FileLog.Error("failed to submit task:", err.Error())
+			wg.Done()
+		}
+	}
+	// Wait for all tasks to finish, then close the channel
+	go func() {
+		wg.Wait()
+		close(results)
+	}()
+
+	// Collect the results
+	var processedTasks []*dfa.ArticleDfaTagMapping
+	for task := range results {
+		processedTasks = append(processedTasks, task)
+	}
+	fmt.Printf("processing finished, batch inserting %d rows...\n", len(processedTasks))
+	// Batch insert
+	if len(processedTasks) > 0 {
+		if err := d.BatchInsert(processedTasks); err != nil {
+			utils.FileLog.Error("failed to batch insert article tag results:", err.Error())
+		}
+	}
+}
+func (d *DAFHandler) RejectCallback(p interface{}) {
+	utils.FileLog.Error(fmt.Sprintf("worker panicked: %v", p))
+}
+
+type SensitiveWordStruct struct {
+	Title   map[string]int
+	Content map[string]int
+}
+
+func (d *DAFHandler) GetTextTagLabels(task *dfa.ArticleDfaTagMapping) {
+	var TagResult = make(map[TagName]int)
+	var titleResult, contentResult map[string]int
+	if task.ArticleContent != "" {
+		contentResult = d.dfa.Search(task.ArticleContent)
+		for k, v := range contentResult {
+			if tag, ok := d.SensitiveWordMap[SensitiveWord(k)]; ok {
+				TagResult[tag] = TagResult[tag] + v
+			}
+		}
+	}
+	if task.ArticleTitle != "" {
+		titleResult = d.dfa.Search(task.ArticleTitle)
+		for k, v := range titleResult {
+			if tag, ok := d.SensitiveWordMap[SensitiveWord(k)]; ok {
+				// A title hit weighs 10x a body hit
+				TagResult[tag] = TagResult[tag] + v*10
+			}
+		}
+	}
+	item := SensitiveWordStruct{
+		Content: contentResult,
+		Title:   titleResult,
+	}
+	// Drop tags whose weighted score is below the threshold
+	for k, v := range TagResult {
+		if v < 20 {
+			delete(TagResult, k)
+		}
+	}
+	var allValues []string
+	for k := range TagResult {
+		allValues = append(allValues, string(k))
+	}
+	task.TagName = strings.Join(allValues, ",")
+	str, _ := json.Marshal(item)
+	task.SensitiveWords = string(str)
+}
+func (d *DAFHandler) ReloadSensitiveWordMap() {
+	d.SensitiveWordMap = initSensitiveWords()
+}
+
+func (d *DAFHandler) BatchInsert(tasks []*dfa.ArticleDfaTagMapping) (err error) {
+	return dfa.BatchInsertArticleDfaTagMapping(tasks)
+}
+func initSensitiveWords() (sensitiveWordMap map[SensitiveWord]TagName) {
+	list, err := dfa.GetSensitiveWordList()
+	if err != nil {
+		utils.FileLog.Error("failed to load sensitive words:", err.Error())
+		return nil
+	}
+	sensitiveWordMap = make(map[SensitiveWord]TagName)
+	for _, item := range list {
+		sensitiveWordMap[SensitiveWord(item.SensitiveWord)] = TagName(item.TagName)
+	}
+	return
+}
+
+func initThreadPool() (threadPool *ants.Pool) {
+	cpuCores := runtime.NumCPU()
+	threadPool, _ = ants.NewPool(cpuCores,
+		ants.WithExpiryDuration(10*time.Second),          // idle worker expiry
+		ants.WithPreAlloc(true),                          // pre-allocate worker memory
+		ants.WithMaxBlockingTasks(2*cpuCores),            // max number of blocked submitters
+		ants.WithNonblocking(false),                      // block when the pool is full
+		ants.WithPanicHandler(dafHandler.RejectCallback), // custom panic handler
+	)
+	return
+}
+
+func RejectCallback(err error) {
+	// Handle submissions rejected by the pool (currently a no-op)
+}
+
+func GetDAFHandlerInstance() DAFService {
+	// sync.Once alone guards initialization; checking dafHandler == nil
+	// outside the Once would be a data race.
+	dafOnce.Do(func() {
+		dafHandler = &DAFHandler{
+			SensitiveWordMap: initSensitiveWords(),
+		}
+		// Create the pool after dafHandler is assigned so the panic handler
+		// is not bound to a nil receiver.
+		dafHandler.WorkerPool = initThreadPool()
+		var sensitiveWords []string
+		for k := range dafHandler.SensitiveWordMap {
+			sensitiveWords = append(sensitiveWords, string(k))
+		}
+		matcher := dfaUtils.NewDFA()
+		matcher.Build(sensitiveWords)
+		dafHandler.dfa = matcher
+	})
+	return dafHandler
+}
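Review note: go.mod pins github.com/panjf2000/ants v1.3.0, while the functional
options used above (WithExpiryDuration, WithPreAlloc, and so on) are documented
for the v2 series, whose module path is github.com/panjf2000/ants/v2, so it is
worth confirming they exist on the pinned v1 tag. For reference, a hypothetical
caller of the handler (assumes the dfa and time imports of this file):

// tagOneArticle is a usage sketch, not part of this PR: it runs the DFA
// search in the worker pool and batch-inserts the resulting tag row.
func tagOneArticle() {
	handler := GetDAFHandlerInstance()
	task := &dfa.ArticleDfaTagMapping{
		ArticleID:      1,
		Source:         dfa.ArticleSourceWechat,
		ArticleTitle:   "示例标题",
		ArticleContent: "示例正文",
		CreatedTime:    time.Now(),
	}
	handler.BatchSubmitTasks([]*dfa.ArticleDfaTagMapping{task})
}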

+ 38 - 0
utils/llm/eta_llm/eta_llm_client.go

@@ -33,6 +33,7 @@ const (
 	DEFALUT_PROMPT_NAME            = "default"
 	CONTENT_TYPE_JSON              = "application/json"
 	KNOWLEDGE_BASE_CHAT_API        = "/chat/kb_chat"
+	COMPLETION_CHAT_API            = "/chat/chat/completions"
 	DOCUMENT_CHAT_API              = "/chat/file_chat"
 	KNOWLEDGE_BASE_SEARCH_DOCS_API = "/knowledge_base/search_docs"
 	UPLOAD_TEMP_DOCS_API           = "/knowledge_base/upload_temp_docs"
@@ -109,6 +110,43 @@ func (ds *ETALLMClient) DocumentChat(query string, KnowledgeId string, history [
 	return ds.DoStreamPost(DOCUMENT_CHAT_API, body)
 }
 
+func (ds *ETALLMClient) CompletionChat(query string, messages []json.RawMessage) (llmRes *http.Response, err error) {
+	ChatHistory := make([]eta_llm_http.HistoryContent, 0)
+	for _, historyItemStr := range messages {
+		var historyItem eta_llm_http.HistoryContentWeb
+		parseErr := json.Unmarshal(historyItemStr, &historyItem)
+		if parseErr != nil {
+			continue
+		}
+		ChatHistory = append(ChatHistory, eta_llm_http.HistoryContent{
+			Content: historyItem.Content,
+			Role:    historyItem.Role,
+		})
+	}
+	ChatHistory = append(ChatHistory, eta_llm_http.HistoryContent{
+		Content: query,
+		Role:    "user",
+	})
+	kbReq := eta_llm_http.CompletionChatRequest{
+		Mode:           KNOWLEDEG_CHAT_MODE,
+		Messages:       ChatHistory,
+		TopK:           3,
+		ScoreThreshold: 0.5,
+		Stream:         true,
+		Model:          ds.LlmModel,
+		Temperature:    0.7,
+		MaxTokens:      0,
+		PromptName:     DEFALUT_PROMPT_NAME,
+		ReturnDirect:   false,
+	}
+	body, err := json.Marshal(kbReq)
+	if err != nil {
+		return
+	}
+	return ds.DoStreamPost(COMPLETION_CHAT_API, body)
+}
+
 func (ds *ETALLMClient) KnowledgeBaseChat(query string, KnowledgeBaseName string, history []json.RawMessage) (llmRes *http.Response, err error) {
 	ChatHistory := make([]eta_llm_http.HistoryContent, 0)
 	for _, historyItemStr := range history {

+ 14 - 0
utils/llm/eta_llm/eta_llm_http/request.go

@@ -49,3 +49,17 @@ type UploadTempDocsRequest struct {
 	ChunkOverlap   string `json:"chunk_overlap"`
 	ZhTitleEnhance string `json:"zh_title_enhance"`
 }
+
+type CompletionChatRequest struct {
+	Mode           string           `json:"mode"`
+	KbName         string           `json:"kb_name"`
+	TopK           int              `json:"top_k"`
+	ScoreThreshold float32          `json:"score_threshold"`
+	Messages       []HistoryContent `json:"messages"`
+	Stream         bool             `json:"stream"`
+	Model          string           `json:"model"`
+	Temperature    float32          `json:"temperature"`
+	MaxTokens      int              `json:"max_tokens"`
+	PromptName     string           `json:"prompt_name"`
+	ReturnDirect   bool             `json:"return_direct"`
+}
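For reference, a hypothetical round-trip of the new request type (placeholder
mode and model values; the client fills these from its own constants; assumes
the encoding/json and fmt imports):

// ExampleCompletionChatRequest shows the wire format produced by CompletionChat.
func ExampleCompletionChatRequest() {
	req := CompletionChatRequest{
		Mode:       "local_kb", // placeholder; the client uses its KNOWLEDEG_CHAT_MODE constant
		Messages:   []HistoryContent{{Role: "user", Content: "你好"}},
		TopK:       3,
		Stream:     true,
		Model:      "chatglm3-6b", // placeholder model name
		PromptName: "default",
	}
	b, _ := json.Marshal(req)
	fmt.Println(string(b))
}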

+ 1 - 0
utils/llm/llm_client.go

@@ -28,4 +28,5 @@ type LLMService interface {
 	SearchKbDocs(query string, KnowledgeBaseName string) (data interface{}, err error)
 	UploadFileToTemplate(files []*os.File, param map[string]interface{}) (data interface{}, err error)
 	FileChat(query string, KnowledgeId string, llmModel string, history []json.RawMessage) (resp eta_llm_http.BaseResponse, err error)
+	CompletionChat(query string, messages []json.RawMessage) (llmRes *http.Response, err error)
 }

+ 25 - 1
utils/ws/session_manager.go

@@ -3,6 +3,7 @@ package ws
 import (
 	"encoding/json"
 	"errors"
+	"eta/eta_api/models/rag"
 	chatService "eta/eta_api/services/llm"
 	"eta/eta_api/utils"
 	"eta/eta_api/utils/llm"
@@ -11,6 +12,7 @@ import (
 	"fmt"
 	"github.com/gorilla/websocket"
 	"net/http"
+	"regexp"
 	"strings"
 	"sync"
 	"time"
@@ -120,7 +122,29 @@ func (manager *ConnectionManager) HandleMessage(userID int, sessionID string, me
 			}
 		}
 	}
-	resp, err := llmService.KnowledgeBaseChat(userMessage.Query, userMessage.KbName, session.History)
+	// If the query hits sensitive-word tags, answer directly from matching articles instead of the knowledge base
+	var resp *http.Response
+	labels := llm.GetDAFHandlerInstance().FindTextTagLabels(userMessage.Query)
+	if len(labels) > 0 {
+		articles, findErr := rag.GetArticleByTags(labels, 15, 10)
+		if findErr != nil || len(articles) == 0 {
+			utils.FileLog.Warn("no related reports found; falling back to RAG chat")
+			resp, err = llmService.KnowledgeBaseChat(userMessage.Query, userMessage.KbName, session.History)
+		} else {
+			// Chat over the matched reports directly; no RAG retrieval needed
+			var articlesContents string
+			articlesCounts := len(articles)
+			for i := 0; i < articlesCounts; i++ {
+				articlesContents += fmt.Sprintf("【%d】:%s\n", i+1, articles[i].TextContent)
+			}
+			prompt := fmt.Sprintf("【问题】:%s,请基于以下%d篇研报进行回答问题,以下是研报:%s", userMessage.Query, articlesCounts, articlesContents)
+			// Collapse whitespace to keep the prompt compact
+			re := regexp.MustCompile(`\s+`)
+			prompt = re.ReplaceAllString(prompt, "")
+			resp, err = llmService.CompletionChat(prompt, session.History)
+		}
+	} else {
+		resp, err = llmService.KnowledgeBaseChat(userMessage.Query, userMessage.KbName, session.History)
+	}
 	defer func() {
 		if resp != nil && resp.Body != nil && err == nil {
 			_ = resp.Body.Close()