kobe6258 1 lună în urmă
părinte
comite
f05f8eccaf

+ 38 - 40
controllers/rag/chat_controller.go

@@ -1,9 +1,7 @@
 package rag
 
 import (
-	"encoding/json"
 	"eta/eta_api/controllers"
-	"eta/eta_api/models"
 	"eta/eta_api/models/system"
 	"eta/eta_api/services/llm/facade"
 	"eta/eta_api/utils"
@@ -27,44 +25,44 @@ func (cc *ChatController) Prepare() {
 	}
 }
 
-// ChatTest @Title 测试知识库问答
-// @Description 测试知识库问答
-// @Success 101 {object} response.ListResp
-// @router /chat/chat_test [post]
-func (kbctrl *KbController) ChatTest() {
-	br := new(models.BaseResponse).Init()
-	defer func() {
-		if br.ErrMsg == "" {
-			br.IsSendEmail = false
-		}
-		kbctrl.Data["json"] = br
-		kbctrl.ServeJSON()
-	}()
-	sysUser := kbctrl.SysUser
-	if sysUser == nil {
-		br.Msg = "请登录"
-		br.ErrMsg = "请登录,SysUser Is Empty"
-		br.Ret = 408
-		return
-	}
-	var req facade.LLMKnowledgeChat
-	err := json.Unmarshal(kbctrl.Ctx.Input.RequestBody, &req)
-	if err != nil {
-		br.Msg = "参数解析异常!"
-		br.ErrMsg = "参数解析失败,Err:" + err.Error()
-		return
-	}
-	searchResp, err := facade.LLMKnowledgeBaseChat(req)
-	if err != nil {
-		br.Msg = "知识库问答"
-		br.ErrMsg = "知识库问答:" + err.Error()
-		return
-	}
-	br.Data = searchResp
-	br.Ret = 200
-	br.Success = true
-	br.Msg = "知识库问答成功"
-}
+//// ChatTest @Title 测试知识库问答
+//// @Description 测试知识库问答
+//// @Success 101 {object} response.ListResp
+//// @router /chat/chat_test [post]
+//func (kbctrl *KbController) ChatTest() {
+//	br := new(models.BaseResponse).Init()
+//	defer func() {
+//		if br.ErrMsg == "" {
+//			br.IsSendEmail = false
+//		}
+//		kbctrl.Data["json"] = br
+//		kbctrl.ServeJSON()
+//	}()
+//	sysUser := kbctrl.SysUser
+//	if sysUser == nil {
+//		br.Msg = "请登录"
+//		br.ErrMsg = "请登录,SysUser Is Empty"
+//		br.Ret = 408
+//		return
+//	}
+//	var req facade.LLMKnowledgeChat
+//	err := json.Unmarshal(kbctrl.Ctx.Input.RequestBody, &req)
+//	if err != nil {
+//		br.Msg = "参数解析异常!"
+//		br.ErrMsg = "参数解析失败,Err:" + err.Error()
+//		return
+//	}
+//	searchResp, err := facade.LLMKnowledgeBaseChat(req)
+//	if err != nil {
+//		br.Msg = "知识库问答"
+//		br.ErrMsg = "知识库问答:" + err.Error()
+//		return
+//	}
+//	br.Data = searchResp
+//	br.Ret = 200
+//	br.Success = true
+//	br.Msg = "知识库问答成功"
+//}
 
 // ChatConnect @Title 知识库问答创建对话连接
 // @Description 知识库问答创建对话连接

+ 0 - 9
routers/commentsRouter.go

@@ -8530,15 +8530,6 @@ func init() {
             Filters: nil,
             Params: nil})
 
-    beego.GlobalControllerRouter["eta/eta_api/controllers/rag:KbController"] = append(beego.GlobalControllerRouter["eta/eta_api/controllers/rag:KbController"],
-        beego.ControllerComments{
-            Method: "ChatTest",
-            Router: `/chat/chat_test`,
-            AllowHTTPMethods: []string{"post"},
-            MethodParams: param.Make(),
-            Filters: nil,
-            Params: nil})
-
     beego.GlobalControllerRouter["eta/eta_api/controllers/rag:KbController"] = append(beego.GlobalControllerRouter["eta/eta_api/controllers/rag:KbController"],
         beego.ControllerComments{
             Method: "SearchDocs",

+ 7 - 17
services/llm/facade/bus_response/bus_response.go

@@ -1,21 +1,11 @@
 package bus_response
 
-type SearchDocsResponse struct {
-	PageContent string   `json:"page_content"`
-	Metadata    Metadata `json:"metadata"`
-	Type        string   `json:"type"`
-	Id          string   `json:"id"`
-	Score       float32  `json:"score"`
-}
-type KnowledgeBaseChatResponse struct {
-	PageContent string   `json:"page_content"`
-	Metadata    Metadata `json:"metadata"`
-	Type        string   `json:"type"`
-	Id          string   `json:"id"`
-	Score       float32  `json:"score"`
-}
+import "eta/eta_api/utils/llm/eta_llm/eta_llm_http"
 
-type Metadata struct {
-	Source string `json:"source"`
-	Id     string `json:"id"`
+type KnowledgeBaseChatResponse struct {
+	PageContent string                `json:"page_content"`
+	Metadata    eta_llm_http.Metadata `json:"metadata"`
+	Type        string                `json:"type"`
+	Id          string                `json:"id"`
+	Score       float32               `json:"score"`
 }

+ 3 - 1
services/llm/facade/bus_response/eta_response.go

@@ -1,6 +1,8 @@
 package bus_response
 
+import "eta/eta_api/utils/llm/eta_llm/eta_llm_http"
+
 type SearchDocsEtaResponse struct {
 	Content string
-	Docs    []SearchDocsResponse
+	Docs    []eta_llm_http.SearchDocsResponse
 }

+ 3 - 17
services/llm/facade/llm_service.go

@@ -3,6 +3,7 @@ package facade
 import (
 	"eta/eta_api/services/llm/facade/bus_response"
 	"eta/eta_api/utils/llm"
+	"eta/eta_api/utils/llm/eta_llm/eta_llm_http"
 	"eta/eta_api/utils/ws"
 	"fmt"
 	"github.com/gorilla/websocket"
@@ -28,21 +29,10 @@ func LLMKnowledgeBaseSearchDocs(search LLMKnowledgeSearch) (resp bus_response.Se
 	if err != nil {
 		return
 	}
-	for _, doc := range docs.([]bus_response.SearchDocsResponse) {
+	for _, doc := range docs.([]eta_llm_http.SearchDocsResponse) {
 		resp.Content = resp.Content + doc.PageContent
 	}
-	resp.Docs = docs.([]bus_response.SearchDocsResponse)
-	return
-}
-func LLMKnowledgeBaseChat(chat LLMKnowledgeChat) (resp bus_response.SearchDocsEtaResponse, err error) {
-	docs, err := llmService.KnowledgeBaseChat(chat.Query, chat.KbName, nil)
-	if err != nil {
-		return
-	}
-	for _, doc := range docs.([]bus_response.SearchDocsResponse) {
-		resp.Content = resp.Content + doc.PageContent
-	}
-	resp.Docs = docs.([]bus_response.SearchDocsResponse)
+	resp.Docs = docs.([]eta_llm_http.SearchDocsResponse)
 	return
 }
 
@@ -50,7 +40,3 @@ type LLMKnowledgeSearch struct {
 	Query             string `json:"Query"`
 	KnowledgeBaseName string `json:"KnowledgeBaseName"`
 }
-type LLMKnowledgeChat struct {
-	Query  string `json:"Query"`
-	KbName string `json:"KbName"`
-}

+ 13 - 34
utils/llm/eta_llm/eta_llm_client.go

@@ -5,7 +5,6 @@ import (
 	"bytes"
 	"encoding/json"
 	"errors"
-	"eta/eta_api/services/llm/facade/bus_response"
 	"eta/eta_api/utils"
 	"eta/eta_api/utils/llm"
 	"eta/eta_api/utils/llm/eta_llm/eta_llm_http"
@@ -46,7 +45,7 @@ func GetInstance() llm.LLMService {
 	return etaLlmClient
 }
 
-func (ds *ETALLMClient) KnowledgeBaseChat(query string, KnowledgeBaseName string, history []interface{}) (content interface{}, err error) {
+func (ds *ETALLMClient) KnowledgeBaseChat(query string, KnowledgeBaseName string, history []interface{}) (llmRes *http.Response, err error) {
 	ChatHistory := make([]eta_llm_http.HistoryContent, 0)
 	ChatHistory = append(ChatHistory, eta_llm_http.HistoryContent{
 		Content: query,
@@ -78,23 +77,7 @@ func (ds *ETALLMClient) KnowledgeBaseChat(query string, KnowledgeBaseName string
 	if err != nil {
 		return
 	}
-	resp, err := ds.DoStreamPost(KNOWLEDGE_BASE_CHAT_API, body)
-	if !resp.Success {
-		err = errors.New(resp.Msg)
-		return
-	}
-	if resp.Data != nil {
-		var kbChatRes bus_response.KnowledgeBaseChatResponse
-		err = json.Unmarshal(resp.Data, &kbChatRes)
-		if err != nil {
-			err = errors.New("搜索知识库失败")
-			return
-		}
-		content = kbChatRes
-		return
-	}
-	err = errors.New("搜索知识库失败")
-	return
+	return ds.DoStreamPost(KNOWLEDGE_BASE_CHAT_API, body)
 }
 
 func (ds *ETALLMClient) SearchKbDocs(query string, KnowledgeBaseName string) (content interface{}, err error) {
@@ -116,7 +99,7 @@ func (ds *ETALLMClient) SearchKbDocs(query string, KnowledgeBaseName string) (co
 		return
 	}
 	if resp.Data != nil {
-		var kbSearchRes []bus_response.SearchDocsResponse
+		var kbSearchRes []eta_llm_http.SearchDocsResponse
 		err = json.Unmarshal(resp.Data, &kbSearchRes)
 		if err != nil {
 			err = errors.New("搜索知识库失败")
@@ -143,13 +126,9 @@ func (ds *ETALLMClient) DoPost(apiUrl string, body []byte) (baseResp eta_llm_htt
 	}
 	return parseResponse(response)
 }
-func (ds *ETALLMClient) DoStreamPost(apiUrl string, body []byte) (baseResp eta_llm_http.BaseResponse, err error) {
+func (ds *ETALLMClient) DoStreamPost(apiUrl string, body []byte) (baseResp *http.Response, err error) {
 	requestReader := bytes.NewReader(body)
-	response, err := ds.HttpClient.Post(ds.BaseURL+apiUrl, CONTENT_TYPE_JSON, requestReader)
-	if err != nil {
-		return
-	}
-	return parseResponse(response)
+	return ds.HttpClient.Post(ds.BaseURL+apiUrl, CONTENT_TYPE_JSON, requestReader)
 }
 func parseResponse(response *http.Response) (baseResp eta_llm_http.BaseResponse, err error) {
 	defer func() {
@@ -169,30 +148,31 @@ func parseResponse(response *http.Response) (baseResp eta_llm_http.BaseResponse,
 	baseResp.Data = bodyBytes
 	return
 }
-func parseStreamResponse(response *http.Response) (contentChan chan string, errChan chan error) {
+func ParseStreamResponse(response *http.Response) (contentChan chan string, errChan chan error, closeChan chan struct{}) {
 	defer func() {
 		_ = response.Body.Close()
+		defer close(contentChan)
+		defer close(errChan)
+		defer close(closeChan)
 	}()
-	contentChan = make(chan string)
-	errChan = make(chan error)
+	contentChan = make(chan string, 10)
+	errChan = make(chan error, 10)
+	closeChan = make(chan struct{})
 	go func() {
 		defer close(contentChan)
 		defer close(errChan)
 		scanner := bufio.NewScanner(response.Body)
 		scanner.Split(bufio.ScanLines)
-
 		for scanner.Scan() {
 			line := scanner.Text()
 			if line == "" {
 				continue
 			}
-
 			var chunk eta_llm_http.ChunkResponse
 			if err := json.Unmarshal([]byte(line), &chunk); err != nil {
 				errChan <- fmt.Errorf("解析 JSON 块失败: %w", err)
 				return
 			}
-
 			// 处理每个 chunk
 			if chunk.Choices != nil && len(chunk.Choices) > 0 {
 				for _, choice := range chunk.Choices {
@@ -202,11 +182,10 @@ func parseStreamResponse(response *http.Response) (contentChan chan string, errC
 				}
 			}
 		}
-
 		if err := scanner.Err(); err != nil {
 			errChan <- fmt.Errorf("读取响应体失败: %w", err)
 			return
 		}
 	}()
-	return contentChan, errChan
+	return
 }

+ 31 - 17
utils/llm/eta_llm/eta_llm_http/response.go

@@ -9,39 +9,53 @@ type BaseResponse struct {
 	Data    json.RawMessage `json:"data"`
 }
 
-
 // ChunkResponse 定义流式响应的结构体
 type ChunkResponse struct {
-	ID             string          `json:"id"`
-	Object         string          `json:"object"`
-	Model          string          `json:"model"`
-	Created        int64           `json:"created"`
-	Status         *string         `json:"status"`
-	MessageType    int             `json:"message_type"`
-	MessageID      *string         `json:"message_id"`
-	IsRef          bool            `json:"is_ref"`
-	Docs           []string        `json:"docs"`
-	Choices        []Choice        `json:"choices"`
+	ID          string   `json:"id"`
+	Object      string   `json:"object"`
+	Model       string   `json:"model"`
+	Created     int64    `json:"created"`
+	Status      *string  `json:"status"`
+	MessageType int      `json:"message_type"`
+	MessageID   *string  `json:"message_id"`
+	IsRef       bool     `json:"is_ref"`
+	Docs        []string `json:"docs"`
+	Choices     []Choice `json:"choices"`
 }
+
 // Choice 定义选择的结构体
 type Choice struct {
-	Delta Delta `json:"delta"`
+	Delta Delta  `json:"delta"`
 	Role  string `json:"role"`
 }
+
 // Delta 定义增量的结构体
 type Delta struct {
-	Content   string `json:"content"`
+	Content   string     `json:"content"`
 	ToolCalls []ToolCall `json:"tool_calls"`
 }
+
 // ToolCall 定义工具调用的结构体
 type ToolCall struct {
-	ID      string `json:"id"`
-	Type    string `json:"type"`
+	ID       string   `json:"id"`
+	Type     string   `json:"type"`
 	Function Function `json:"function"`
 }
 
 // Function 定义函数的结构体
 type Function struct {
-	Name      string `json:"name"`
+	Name      string          `json:"name"`
 	Arguments json.RawMessage `json:"arguments"`
-}
+}
+
+type SearchDocsResponse struct {
+	PageContent string   `json:"page_content"`
+	Metadata    Metadata `json:"metadata"`
+	Type        string   `json:"type"`
+	Id          string   `json:"id"`
+	Score       float32  `json:"score"`
+}
+type Metadata struct {
+	Source string `json:"source"`
+	Id     string `json:"id"`
+}

+ 1 - 1
utils/llm/llm_client.go

@@ -20,6 +20,6 @@ func NewLLMClient(baseURL string, timeout time.Duration) *LLMClient {
 }
 
 type LLMService interface {
-	KnowledgeBaseChat(query string, KnowledgeBaseName string, history []interface{}) (data interface{}, err error)
+	KnowledgeBaseChat(query string, KnowledgeBaseName string, history []interface{}) (llmRes *http.Response, err error)
 	SearchKbDocs(query string, KnowledgeBaseName string) (data interface{}, err error)
 }

+ 41 - 3
utils/ws/session_manager.go

@@ -3,12 +3,19 @@ package ws
 import (
 	"errors"
 	"eta/eta_api/utils"
+	"eta/eta_api/utils/llm"
+	"eta/eta_api/utils/llm/eta_llm"
 	"fmt"
 	"github.com/gorilla/websocket"
+	"net/http"
 	"sync"
 	"time"
 )
 
+var (
+	llmService, _ = llm.GetInstance(llm.ETA_LLM_CLIENT)
+)
+
 const (
 	defaultCheckInterval = 5 * time.Second  // 检测间隔应小于心跳超时时间
 	connectionTimeout    = 20 * time.Second // 客户端超时时间
@@ -54,11 +61,42 @@ func (manager *ConnectionManager) HandleMessage(userID int, sessionID string, me
 
 	// 处理业务逻辑
 	session.History = append(session.History, message)
-	response := "Processed: " + string(message)
+	resp, err := llmService.KnowledgeBaseChat("", "hz", nil)
+	if err != nil {
+		err = errors.New(fmt.Sprintf("知识库问答失败: httpCode:%d,错误信息:%s", resp.StatusCode, http.StatusText(resp.StatusCode)))
+		return err
+	}
+	if resp.StatusCode != http.StatusOK {
+		err = errors.New(fmt.Sprintf("知识库问答失败: httpCode:%d,错误信息:%s", resp.StatusCode, http.StatusText(resp.StatusCode)))
+		return err
+	}
+	// 解析流式响应
+	contentChan, errChan, closeChan := eta_llm.ParseStreamResponse(resp)
+	// 处理流式数据并发送到 WebSocket
+	for {
+		select {
+		case content, ok := <-contentChan:
+			if !ok {
+				err = errors.New("未知的错误异常")
+				return err
+			}
+			// 发送消息到 WebSocket
+			_ = session.Conn.WriteMessage(websocket.TextMessage, []byte(content))
+		case chanErr, ok := <-errChan:
+			if !ok {
+				err = errors.New("未知的错误异常")
+			} else {
+				err = errors.New(chanErr.Error())
+			}
+			// 发送错误消息到 WebSocket
+			return err
+		case <-closeChan:
+			return nil
+		}
+	}
 	// 更新最后活跃时间
-	session.LastActive = time.Now()
 	// 发送响应
-	return session.Conn.WriteMessage(websocket.TextMessage, []byte(response))
+	//return session.Conn.WriteMessage(websocket.TextMessage, []byte(response))
 }
 
 // AddSession Add 添加一个新的会话