kobe6258 2 månader sedan
förälder
incheckning
60d91eab31

+ 41 - 0
controllers/rag/chat_controller.go

@@ -1,7 +1,9 @@
 package rag
 
 import (
+	"encoding/json"
 	"eta/eta_api/controllers"
+	"eta/eta_api/models"
 	"eta/eta_api/models/system"
 	"eta/eta_api/services/llm/facade"
 	"eta/eta_api/utils"
@@ -25,6 +27,45 @@ func (cc *ChatController) Prepare() {
 	}
 }
 
+// ChatTest @Title 测试知识库问答
+// @Description 测试知识库问答
+// @Success 101 {object} response.ListResp
+// @router /chat/chat_test [post]
+func (kbctrl *KbController) ChatTest() {
+	br := new(models.BaseResponse).Init()
+	defer func() {
+		if br.ErrMsg == "" {
+			br.IsSendEmail = false
+		}
+		kbctrl.Data["json"] = br
+		kbctrl.ServeJSON()
+	}()
+	sysUser := kbctrl.SysUser
+	if sysUser == nil {
+		br.Msg = "请登录"
+		br.ErrMsg = "请登录,SysUser Is Empty"
+		br.Ret = 408
+		return
+	}
+	var req facade.LLMKnowledgeChat
+	err := json.Unmarshal(kbctrl.Ctx.Input.RequestBody, &req)
+	if err != nil {
+		br.Msg = "参数解析异常!"
+		br.ErrMsg = "参数解析失败,Err:" + err.Error()
+		return
+	}
+	searchResp, err := facade.LLMKnowledgeBaseChat(req)
+	if err != nil {
+		br.Msg = "知识库问答"
+		br.ErrMsg = "知识库问答:" + err.Error()
+		return
+	}
+	br.Data = searchResp
+	br.Ret = 200
+	br.Success = true
+	br.Msg = "知识库问答成功"
+}
+
 // ChatConnect @Title 知识库问答创建对话连接
 // @Description 知识库问答创建对话连接
 // @Success 101 {object} response.ListResp

+ 1 - 0
controllers/rag/kb_controller.go

@@ -49,3 +49,4 @@ func (kbctrl *KbController) SearchDocs() {
 	br.Success = true
 	br.Msg = "获取成功"
 }
+

+ 9 - 0
routers/commentsRouter.go

@@ -8530,6 +8530,15 @@ func init() {
             Filters: nil,
             Params: nil})
 
+    beego.GlobalControllerRouter["eta/eta_api/controllers/rag:KbController"] = append(beego.GlobalControllerRouter["eta/eta_api/controllers/rag:KbController"],
+        beego.ControllerComments{
+            Method: "ChatTest",
+            Router: `/chat/chat_test`,
+            AllowHTTPMethods: []string{"post"},
+            MethodParams: param.Make(),
+            Filters: nil,
+            Params: nil})
+
     beego.GlobalControllerRouter["eta/eta_api/controllers/rag:KbController"] = append(beego.GlobalControllerRouter["eta/eta_api/controllers/rag:KbController"],
         beego.ControllerComments{
             Method: "SearchDocs",

+ 7 - 0
services/llm/facade/bus_response/bus_response.go

@@ -7,6 +7,13 @@ type SearchDocsResponse struct {
 	Id          string   `json:"id"`
 	Score       float32  `json:"score"`
 }
+type KnowledgeBaseChatResponse struct {
+	PageContent string   `json:"page_content"`
+	Metadata    Metadata `json:"metadata"`
+	Type        string   `json:"type"`
+	Id          string   `json:"id"`
+	Score       float32  `json:"score"`
+}
 
 type Metadata struct {
 	Source string `json:"source"`

+ 15 - 0
services/llm/facade/llm_service.go

@@ -34,8 +34,23 @@ func LLMKnowledgeBaseSearchDocs(search LLMKnowledgeSearch) (resp bus_response.Se
 	resp.Docs = docs.([]bus_response.SearchDocsResponse)
 	return
 }
+func LLMKnowledgeBaseChat(chat LLMKnowledgeChat) (resp bus_response.SearchDocsEtaResponse, err error) {
+	docs, err := llmService.KnowledgeBaseChat(chat.Query, chat.KbName, nil)
+	if err != nil {
+		return
+	}
+	for _, doc := range docs.([]bus_response.SearchDocsResponse) {
+		resp.Content = resp.Content + doc.PageContent
+	}
+	resp.Docs = docs.([]bus_response.SearchDocsResponse)
+	return
+}
 
 type LLMKnowledgeSearch struct {
 	Query             string `json:"Query"`
 	KnowledgeBaseName string `json:"KnowledgeBaseName"`
 }
+type LLMKnowledgeChat struct {
+	Query  string `json:"Query"`
+	KbName string `json:"KbName"`
+}

+ 101 - 6
utils/llm/eta_llm/eta_llm_client.go

@@ -1,6 +1,7 @@
 package eta_llm
 
 import (
+	"bufio"
 	"bytes"
 	"encoding/json"
 	"errors"
@@ -21,6 +22,8 @@ var (
 )
 
 const (
+	KNOWLEDEG_CHAT_MODE            = "local_kb"
+	DEFALUT_PROMPT_NAME            = "default"
 	CONTENT_TYPE_JSON              = "application/json"
 	KNOWLEDGE_BASE_CHAT_API        = "/chat/kb_chat"
 	KNOWLEDGE_BASE_SEARCH_DOCS_API = "/knowledge_base/search_docs"
@@ -43,14 +46,58 @@ func GetInstance() llm.LLMService {
 	return etaLlmClient
 }
 
-func (ds *ETALLMClient) KnowledgeBaseChat() string {
-	ds.HttpClient.Post(ds.BaseURL+KNOWLEDGE_BASE_CHAT_API, CONTENT_TYPE_JSON, nil)
-	return ""
+func (ds *ETALLMClient) KnowledgeBaseChat(query string, KnowledgeBaseName string, history []interface{}) (content interface{}, err error) {
+	ChatHistory := make([]eta_llm_http.HistoryContent, 0)
+	ChatHistory = append(ChatHistory, eta_llm_http.HistoryContent{
+		Content: query,
+		Role:    "user",
+	})
+	for _, historyItem := range history {
+		historyItemMap := historyItem.(map[string]interface{})
+		ChatHistory = append(ChatHistory, eta_llm_http.HistoryContent{
+			Content: historyItemMap["content"].(string),
+			Role:    historyItemMap["role"].(string),
+		})
+	}
+	kbReq := eta_llm_http.KbChatRequest{
+		Query:          query,
+		Mode:           KNOWLEDEG_CHAT_MODE,
+		KbName:         KnowledgeBaseName,
+		History:        ChatHistory,
+		TopK:           3,
+		ScoreThreshold: 0.5,
+		Stream:         true,
+		Model:          ds.LlmModel,
+		Temperature:    0.7,
+		MaxTokens:      0,
+		PromptName:     DEFALUT_PROMPT_NAME,
+		ReturnDirect:   false,
+	}
+
+	body, err := json.Marshal(kbReq)
+	if err != nil {
+		return
+	}
+	resp, err := ds.DoStreamPost(KNOWLEDGE_BASE_CHAT_API, body)
+	if !resp.Success {
+		err = errors.New(resp.Msg)
+		return
+	}
+	if resp.Data != nil {
+		var kbChatRes bus_response.KnowledgeBaseChatResponse
+		err = json.Unmarshal(resp.Data, &kbChatRes)
+		if err != nil {
+			err = errors.New("搜索知识库失败")
+			return
+		}
+		content = kbChatRes
+		return
+	}
+	err = errors.New("搜索知识库失败")
+	return
 }
 
 func (ds *ETALLMClient) SearchKbDocs(query string, KnowledgeBaseName string) (content interface{}, err error) {
-	// 类型断言
-
 	kbReq := eta_llm_http.KbSearchDocsRequest{
 		Query:             query,
 		KnowledgeBaseName: KnowledgeBaseName,
@@ -96,7 +143,14 @@ func (ds *ETALLMClient) DoPost(apiUrl string, body []byte) (baseResp eta_llm_htt
 	}
 	return parseResponse(response)
 }
-
+func (ds *ETALLMClient) DoStreamPost(apiUrl string, body []byte) (baseResp eta_llm_http.BaseResponse, err error) {
+	requestReader := bytes.NewReader(body)
+	response, err := ds.HttpClient.Post(ds.BaseURL+apiUrl, CONTENT_TYPE_JSON, requestReader)
+	if err != nil {
+		return
+	}
+	return parseResponse(response)
+}
 func parseResponse(response *http.Response) (baseResp eta_llm_http.BaseResponse, err error) {
 	defer func() {
 		_ = response.Body.Close()
@@ -115,3 +169,44 @@ func parseResponse(response *http.Response) (baseResp eta_llm_http.BaseResponse,
 	baseResp.Data = bodyBytes
 	return
 }
+func parseStreamResponse(response *http.Response) (contentChan chan string, errChan chan error) {
+	defer func() {
+		_ = response.Body.Close()
+	}()
+	contentChan = make(chan string)
+	errChan = make(chan error)
+	go func() {
+		defer close(contentChan)
+		defer close(errChan)
+		scanner := bufio.NewScanner(response.Body)
+		scanner.Split(bufio.ScanLines)
+
+		for scanner.Scan() {
+			line := scanner.Text()
+			if line == "" {
+				continue
+			}
+
+			var chunk eta_llm_http.ChunkResponse
+			if err := json.Unmarshal([]byte(line), &chunk); err != nil {
+				errChan <- fmt.Errorf("解析 JSON 块失败: %w", err)
+				return
+			}
+
+			// 处理每个 chunk
+			if chunk.Choices != nil && len(chunk.Choices) > 0 {
+				for _, choice := range chunk.Choices {
+					if choice.Delta.Content != "" {
+						contentChan <- choice.Delta.Content
+					}
+				}
+			}
+		}
+
+		if err := scanner.Err(); err != nil {
+			errChan <- fmt.Errorf("读取响应体失败: %w", err)
+			return
+		}
+	}()
+	return contentChan, errChan
+}

+ 2 - 2
utils/llm/eta_llm/eta_llm_http/request.go

@@ -5,7 +5,7 @@ type KbChatRequest struct {
 	Mode           string           `json:"mode"`
 	KbName         string           `json:"kb_name"`
 	TopK           int              `json:"top_k"`
-	ScoreThreshold int              `json:"score_threshold"`
+	ScoreThreshold float32          `json:"score_threshold"`
 	History        []HistoryContent `json:"history"`
 	Stream         bool             `json:"stream"`
 	Model          string           `json:"model"`
@@ -24,7 +24,7 @@ type KbSearchDocsRequest struct {
 	Query             string      `json:"query"`
 	KnowledgeBaseName string      `json:"knowledge_base_name"`
 	TopK              int         `json:"top_k"`
-	ScoreThreshold    float32         `json:"score_threshold"`
+	ScoreThreshold    float32     `json:"score_threshold"`
 	FileName          string      `json:"file_name"`
 	Metadata          interface{} `json:"metadata"`
 }

+ 35 - 0
utils/llm/eta_llm/eta_llm_http/response.go

@@ -10,3 +10,38 @@ type BaseResponse struct {
 }
 
 
+// ChunkResponse 定义流式响应的结构体
+type ChunkResponse struct {
+	ID             string          `json:"id"`
+	Object         string          `json:"object"`
+	Model          string          `json:"model"`
+	Created        int64           `json:"created"`
+	Status         *string         `json:"status"`
+	MessageType    int             `json:"message_type"`
+	MessageID      *string         `json:"message_id"`
+	IsRef          bool            `json:"is_ref"`
+	Docs           []string        `json:"docs"`
+	Choices        []Choice        `json:"choices"`
+}
+// Choice 定义选择的结构体
+type Choice struct {
+	Delta Delta `json:"delta"`
+	Role  string `json:"role"`
+}
+// Delta 定义增量的结构体
+type Delta struct {
+	Content   string `json:"content"`
+	ToolCalls []ToolCall `json:"tool_calls"`
+}
+// ToolCall 定义工具调用的结构体
+type ToolCall struct {
+	ID      string `json:"id"`
+	Type    string `json:"type"`
+	Function Function `json:"function"`
+}
+
+// Function 定义函数的结构体
+type Function struct {
+	Name      string `json:"name"`
+	Arguments json.RawMessage `json:"arguments"`
+}

+ 1 - 1
utils/llm/llm_client.go

@@ -20,6 +20,6 @@ func NewLLMClient(baseURL string, timeout time.Duration) *LLMClient {
 }
 
 type LLMService interface {
-	KnowledgeBaseChat() string
+	KnowledgeBaseChat(query string, KnowledgeBaseName string, history []interface{}) (data interface{}, err error)
 	SearchKbDocs(query string, KnowledgeBaseName string) (data interface{}, err error)
 }

+ 2 - 2
utils/ws/session.go

@@ -1,6 +1,7 @@
 package ws
 
 import (
+	"encoding/json"
 	"errors"
 	"eta/eta_api/utils"
 	"github.com/gorilla/websocket"
@@ -15,7 +16,7 @@ type Session struct {
 	Conn        *websocket.Conn
 	LastActive  time.Time
 	Latency     *LatencyMeasurer
-	History     []string
+	History     []json.RawMessage
 	CloseChan   chan struct{}
 	MessageChan chan *Message
 	mu          sync.RWMutex
@@ -139,7 +140,6 @@ func NewSession(userId int, sessionId string, conn *websocket.Conn) (session *Se
 		UserId:      userId,
 		Id:          sessionId,
 		Conn:        conn,
-		History:     []string{},
 		LastActive:  time.Now(),
 		CloseChan:   make(chan struct{}),
 		MessageChan: make(chan *Message, 10),

+ 3 - 3
utils/ws/session_manager.go

@@ -12,8 +12,8 @@ import (
 const (
 	defaultCheckInterval = 5 * time.Second  // 检测间隔应小于心跳超时时间
 	connectionTimeout    = 20 * time.Second // 客户端超时时间
-	ReadTimeout    = 10 * time.Second // 客户端超时时间
-	writeWaitTimeout = 5 * time.Second
+	ReadTimeout          = 10 * time.Second // 客户端超时时间
+	writeWaitTimeout     = 5 * time.Second
 )
 
 type ConnectionManager struct {
@@ -53,7 +53,7 @@ func (manager *ConnectionManager) HandleMessage(userID int, sessionID string, me
 	}
 
 	// 处理业务逻辑
-	session.History = append(session.History, string(message))
+	session.History = append(session.History, message)
 	response := "Processed: " + string(message)
 	// 更新最后活跃时间
 	session.LastActive = time.Now()