瀏覽代碼

文章增加系统DFA标签

kobe6258 4 天之前
父節點
當前提交
436b880b44

+ 29 - 0
models/rag/wechat_article.go

@@ -337,3 +337,32 @@ func FetchArticleDafTagMappingList(lastId int, excludeIds []int, fetchSize int)
 	}
 	return
 }
+
+func GetArticleByTags(labels []string, daysBefore int, fetchSize int) (items []*DafWechatArticleItem, err error) {
+	var tagSql = `select article_id from article_dfa_tag_mapping where (`
+	for i := 0; i < len(labels); i++ {
+		if i == 0 {
+			tagSql += fmt.Sprintf(`find_in_set('%s', tag_name)`, labels[i])
+		} else {
+			tagSql += fmt.Sprintf(` or find_in_set('%s', tag_name)`, labels[i])
+		}
+	}
+	tagSql += `)`
+	endDate := time.Now()
+	startDate := endDate.AddDate(0, 0, -daysBefore)
+	var limitedCondition string
+	if fetchSize > 0 {
+		limitedCondition = fmt.Sprintf("limit 0,%d", fetchSize)
+	}
+	var dateCondition = fmt.Sprintf("and article_create_time between '%s' and  '%s'", startDate.Format(utils.FormatDate), endDate.Format(utils.FormatDate))
+	var sql = fmt.Sprintf(`select text_content from wechat_article where   wechat_article_id in (%s) and is_deleted = 0 %s`, tagSql, limitedCondition)
+	var sqlWithDays = fmt.Sprintf(`select text_content from wechat_article where   wechat_article_id in (%s) and is_deleted = 0 %s %s `, tagSql, dateCondition, limitedCondition)
+	err = global.DbMap[utils.DbNameAI].Raw(sqlWithDays).Find(&items).Error
+	if err != nil {
+		return
+	}
+	if len(items) == 0 {
+		err = global.DbMap[utils.DbNameAI].Raw(sql).Find(&items).Error
+	}
+	return
+}

+ 10 - 0
utils/llm/dfa_handler.go

@@ -35,8 +35,18 @@ type DAFService interface {
 	RejectCallback(p interface{})
 	// BatchSubmitTasks 批量提交任务
 	BatchSubmitTasks(tasks []*dfa.ArticleDfaTagMapping)
+	FindTextTagLabels(text string) []string
 }
 
+func (d *DAFHandler) FindTextTagLabels(text string) (labels []string) {
+	result := d.dfa.Search(text)
+	if len(result) > 0 {
+		for k, _ := range result {
+			labels = append(labels, k)
+		}
+	}
+	return
+}
 func (d *DAFHandler) BatchSubmitTasks(tasks []*dfa.ArticleDfaTagMapping) {
 	// 创建结果收集通道
 	results := make(chan *dfa.ArticleDfaTagMapping, len(tasks))

+ 38 - 0
utils/llm/eta_llm/eta_llm_client.go

@@ -33,6 +33,7 @@ const (
 	DEFALUT_PROMPT_NAME            = "default"
 	CONTENT_TYPE_JSON              = "application/json"
 	KNOWLEDGE_BASE_CHAT_API        = "/chat/kb_chat"
+	COMPLETION_CHAT_API            = "/chat/chat/completions"
 	DOCUMENT_CHAT_API              = "/chat/file_chat"
 	KNOWLEDGE_BASE_SEARCH_DOCS_API = "/knowledge_base/search_docs"
 	UPLOAD_TEMP_DOCS_API           = "/knowledge_base/upload_temp_docs"
@@ -109,6 +110,43 @@ func (ds *ETALLMClient) DocumentChat(query string, KnowledgeId string, history [
 	return ds.DoStreamPost(DOCUMENT_CHAT_API, body)
 }
 
+func (ds *ETALLMClient) CompletionChat(query string, messages []json.RawMessage) (llmRes *http.Response, err error) {
+	ChatHistory := make([]eta_llm_http.HistoryContent, 0)
+	for _, historyItemStr := range messages {
+		var historyItem eta_llm_http.HistoryContentWeb
+		parseErr := json.Unmarshal(historyItemStr, &historyItem)
+		if parseErr != nil {
+			continue
+		}
+		ChatHistory = append(ChatHistory, eta_llm_http.HistoryContent{
+			Content: historyItem.Content,
+			Role:    historyItem.Role,
+		})
+	}
+	ChatHistory = append(ChatHistory, eta_llm_http.HistoryContent{
+		Content: query,
+		Role:    "user",
+	})
+	kbReq := eta_llm_http.CompletionChatRequest{
+		Mode:           KNOWLEDEG_CHAT_MODE,
+		Messages:       ChatHistory,
+		TopK:           3,
+		ScoreThreshold: 0.5,
+		Stream:         true,
+		Model:          ds.LlmModel,
+		Temperature:    0.7,
+		MaxTokens:      0,
+		PromptName:     DEFALUT_PROMPT_NAME,
+		ReturnDirect:   false,
+	}
+	fmt.Printf("%v", kbReq.Messages)
+	body, err := json.Marshal(kbReq)
+	if err != nil {
+		return
+	}
+	return ds.DoStreamPost(COMPLETION_CHAT_API, body)
+}
+
 func (ds *ETALLMClient) KnowledgeBaseChat(query string, KnowledgeBaseName string, history []json.RawMessage) (llmRes *http.Response, err error) {
 	ChatHistory := make([]eta_llm_http.HistoryContent, 0)
 	for _, historyItemStr := range history {

+ 14 - 0
utils/llm/eta_llm/eta_llm_http/request.go

@@ -49,3 +49,17 @@ type UploadTempDocsRequest struct {
 	ChunkOverlap   string `json:"chunk_overlap"`
 	ZhTitleEnhance string `json:"zh_title_enhance"`
 }
+
+type CompletionChatRequest struct {
+	Mode           string           `json:"mode"`
+	KbName         string           `json:"kb_name"`
+	TopK           int              `json:"top_k"`
+	ScoreThreshold float32          `json:"score_threshold"`
+	Messages        []HistoryContent `json:"messages"`
+	Stream         bool             `json:"stream"`
+	Model          string           `json:"model"`
+	Temperature    float32          `json:"temperature"`
+	MaxTokens      int              `json:"max_tokens"`
+	PromptName     string           `json:"prompt_name"`
+	ReturnDirect   bool             `json:"return_direct"`
+}

+ 1 - 0
utils/llm/llm_client.go

@@ -28,4 +28,5 @@ type LLMService interface {
 	SearchKbDocs(query string, KnowledgeBaseName string) (data interface{}, err error)
 	UploadFileToTemplate(files []*os.File, param map[string]interface{}) (data interface{}, err error)
 	FileChat(query string, KnowledgeId string, llmModel string, history []json.RawMessage) (resp eta_llm_http.BaseResponse, err error)
+	CompletionChat(query string, messages []json.RawMessage) (llmRes *http.Response, err error)
 }

+ 25 - 1
utils/ws/session_manager.go

@@ -3,6 +3,7 @@ package ws
 import (
 	"encoding/json"
 	"errors"
+	"eta/eta_api/models/rag"
 	chatService "eta/eta_api/services/llm"
 	"eta/eta_api/utils"
 	"eta/eta_api/utils/llm"
@@ -11,6 +12,7 @@ import (
 	"fmt"
 	"github.com/gorilla/websocket"
 	"net/http"
+	"regexp"
 	"strings"
 	"sync"
 	"time"
@@ -120,7 +122,29 @@ func (manager *ConnectionManager) HandleMessage(userID int, sessionID string, me
 			}
 		}
 	}
-	resp, err := llmService.KnowledgeBaseChat(userMessage.Query, userMessage.KbName, session.History)
+	//修改逻辑。如果问题出现敏感词,则返回敏感词提示
+	var resp *http.Response
+	labels := llm.GetDAFHandlerInstance().FindTextTagLabels(userMessage.Query)
+	if len(labels) > 0 {
+		articles, findErr := rag.GetArticleByTags(labels, 15, 10)
+		if findErr != nil {
+			utils.FileLog.Warn("没有搜索到相关的研报内容,执行RAG对话 ")
+			resp, err = llmService.KnowledgeBaseChat(userMessage.Query, userMessage.KbName, session.History)
+		} else {
+			//直接对话,不需要走RAG
+			var articlesContents string
+			articlesCounts := len(articles)
+			for i := 0; i < articlesCounts; i++ {
+				articlesContents += fmt.Sprintf("【%d】:%s\n", i+1, articles[i].TextContent)
+			}
+			promote := fmt.Sprintf("【问题】:%s,请基于以下%d篇研报进行回答问题,以下是研报:%s", userMessage.Query, articlesCounts, articlesContents)
+			re := regexp.MustCompile(`\s+`)
+			promote = re.ReplaceAllString(promote, "")
+			resp, err = llmService.CompletionChat(promote, session.History)
+		}
+	} else {
+		resp, err = llmService.KnowledgeBaseChat(userMessage.Query, userMessage.KbName, session.History)
+	}
 	defer func() {
 		if resp != nil && resp.Body != nil && err == nil {
 			_ = resp.Body.Close()