Explorar el Código

Merge remote-tracking branch 'origin/feature/deepseek_rag_1.0' into feature/deepseek_rag_1.0

# Conflicts:
#	routers/commentsRouter.go
Roc hace 2 meses
padre
commit
059b054192

+ 30 - 1
controllers/rag/chat_controller.go

@@ -2,6 +2,7 @@ package rag
 
 import (
 	"eta/eta_api/controllers"
+	"eta/eta_api/models"
 	"eta/eta_api/models/system"
 	"eta/eta_api/services/llm/facade"
 	"eta/eta_api/utils"
@@ -25,6 +26,34 @@ func (cc *ChatController) Prepare() {
 	}
 }
 
+// NewChat @Title 新建对话框
+// @Description 测试知识库问答
+// @Success 101 {object} response.ListResp
+// @router /chat/new_chat [post]
+func (kbctrl *KbController) NewChat() {
+	br := new(models.BaseResponse).Init()
+	defer func() {
+		kbctrl.Data["json"] = br
+		kbctrl.ServeJSON()
+	}()
+	sysUser := kbctrl.SysUser
+	if sysUser == nil {
+		br.Msg = "请登录"
+		br.ErrMsg = "请登录,SysUser Is Empty"
+		br.Ret = 408
+		return
+	}
+	//searchResp, err := facade.LLMKnowledgeBaseChat(req)
+	//if err != nil {
+	//	br.Msg = "知识库问答"
+	//	br.ErrMsg = "知识库问答:" + err.Error()
+	//	return
+	//}
+	br.Ret = 200
+	br.Success = true
+	br.Msg = "创建成功"
+}
+
 // ChatConnect @Title 知识库问答创建对话连接
 // @Description 知识库问答创建对话连接
 // @Success 101 {object} response.ListResp
@@ -63,7 +92,7 @@ func webSocketHandler(w http.ResponseWriter, r *http.Request) (conn *websocket.C
 	// 获取底层 TCP 连接并设置保活
 	if tcpConn, ok := conn.NetConn().(*net.TCPConn); ok {
 		_ = tcpConn.SetKeepAlive(true)
-		_ = tcpConn.SetKeepAlivePeriod(90 * time.Second)
+		_ = tcpConn.SetKeepAlivePeriod(ws.TcpTimeout)
 		utils.FileLog.Info("TCP KeepAlive 已启用")
 	}
 	_ = conn.SetReadDeadline(time.Now().Add(ws.ReadTimeout))

+ 52 - 0
controllers/rag/kb_controller.go

@@ -0,0 +1,52 @@
+package rag
+
+import (
+	"encoding/json"
+	"eta/eta_api/controllers"
+	"eta/eta_api/models"
+	"eta/eta_api/services/llm/facade"
+)
+
+type KbController struct {
+	controllers.BaseAuthController
+}
+
+// SearchDocs  @Title 搜索知识库文档
+// @Description 搜索知识库文档
+// @Success 101 {object} response.ListResp
+// @router /knowledge_base/searchDocs [post]
+func (kbctrl *KbController) SearchDocs() {
+	br := new(models.BaseResponse).Init()
+	defer func() {
+		if br.ErrMsg == "" {
+			br.IsSendEmail = false
+		}
+		kbctrl.Data["json"] = br
+		kbctrl.ServeJSON()
+	}()
+	sysUser := kbctrl.SysUser
+	if sysUser == nil {
+		br.Msg = "请登录"
+		br.ErrMsg = "请登录,SysUser Is Empty"
+		br.Ret = 408
+		return
+	}
+	var req facade.LLMKnowledgeSearch
+	err := json.Unmarshal(kbctrl.Ctx.Input.RequestBody, &req)
+	if err != nil {
+		br.Msg = "参数解析异常!"
+		br.ErrMsg = "参数解析失败,Err:" + err.Error()
+		return
+	}
+	searchResp, err := facade.LLMKnowledgeBaseSearchDocs(req)
+	if err != nil {
+		br.Msg = "搜索知识库失败"
+		br.ErrMsg = "搜索知识库失败:" + err.Error()
+		return
+	}
+	br.Data = searchResp
+	br.Ret = 200
+	br.Success = true
+	br.Msg = "获取成功"
+}
+

+ 1 - 1
main.go

@@ -12,7 +12,7 @@ import (
 	_ "eta/eta_api/routers"
 	"eta/eta_api/services"
 	"eta/eta_api/utils"
-	_ "eta/eta_api/utils/llm/deepseek"
+	_ "eta/eta_api/utils/llm/eta_llm"
 	"github.com/beego/beego/v2/adapter/logs"
 	"github.com/beego/beego/v2/server/web"
 	"github.com/beego/beego/v2/server/web/context"

+ 4 - 0
models/llm/user_llm_chat.go

@@ -0,0 +1,4 @@
+package llm
+
+type UserLlmChat struct {
+}

+ 3 - 0
routers/router.go

@@ -75,6 +75,9 @@ func init() {
 				&rag.WechatPlatformController{},
 				&rag.QuestionController{},
 			),
+			web.NSInclude(
+				&rag.KbController{},
+			),
 		),
 		web.NSNamespace("/banner",
 			web.NSInclude(

+ 11 - 0
services/llm/facade/bus_response/bus_response.go

@@ -0,0 +1,11 @@
+package bus_response
+
+import "eta/eta_api/utils/llm/eta_llm/eta_llm_http"
+
+type KnowledgeBaseChatResponse struct {
+	PageContent string                `json:"page_content"`
+	Metadata    eta_llm_http.Metadata `json:"metadata"`
+	Type        string                `json:"type"`
+	Id          string                `json:"id"`
+	Score       float32               `json:"score"`
+}

+ 8 - 0
services/llm/facade/bus_response/eta_response.go

@@ -0,0 +1,8 @@
+package bus_response
+
+import "eta/eta_api/utils/llm/eta_llm/eta_llm_http"
+
+type SearchDocsEtaResponse struct {
+	Content string
+	Docs    []eta_llm_http.SearchDocsResponse
+}

+ 22 - 1
services/llm/facade/llm_service.go

@@ -1,7 +1,9 @@
 package facade
 
 import (
+	"eta/eta_api/services/llm/facade/bus_response"
 	"eta/eta_api/utils/llm"
+	"eta/eta_api/utils/llm/eta_llm/eta_llm_http"
 	"eta/eta_api/utils/ws"
 	"fmt"
 	"github.com/gorilla/websocket"
@@ -9,15 +11,34 @@ import (
 )
 
 var (
-	deepseekService, _ = llm.GetInstance(llm.LLM_DEEPSEEK)
+	llmService, _ = llm.GetInstance(llm.ETA_LLM_CLIENT)
 )
 
 func generateSessionCode() (code string) {
 	return fmt.Sprintf("%s%s", "llm_session_", uuid.NewUUID().Hex32())
 }
 
+// AddSession 创建会话session
 func AddSession(userId int, conn *websocket.Conn) {
 	sessionId := generateSessionCode()
 	session := ws.NewSession(userId, sessionId, conn)
 	ws.Manager().AddSession(session)
 }
+
+// LLMKnowledgeBaseSearchDocs 搜索知识库
+func LLMKnowledgeBaseSearchDocs(search LLMKnowledgeSearch) (resp bus_response.SearchDocsEtaResponse, err error) {
+	docs, err := llmService.SearchKbDocs(search.Query, search.KnowledgeBaseName)
+	if err != nil {
+		return
+	}
+	for _, doc := range docs.([]eta_llm_http.SearchDocsResponse) {
+		resp.Content = resp.Content + doc.PageContent
+	}
+	resp.Docs = docs.([]eta_llm_http.SearchDocsResponse)
+	return
+}
+
+type LLMKnowledgeSearch struct {
+	Query             string `json:"Query"`
+	KnowledgeBaseName string `json:"KnowledgeBaseName"`
+}

+ 1 - 1
services/ws_service.go

@@ -19,7 +19,7 @@ func WsAuthenticate() web.FilterFunc {
 	return func(ctx *context.Context) {
 		method := ctx.Input.Method()
 		uri := ctx.Input.URI()
-		if method == "POST" || method == "GET" {
+		if method == "GET" {
 			authorization := ctx.Input.Header("authorization")
 			if authorization == "" {
 				authorization = ctx.Input.Header("Authorization")

+ 4 - 2
utils/config.go

@@ -13,7 +13,8 @@ import (
 
 // 大模型配置
 var (
-	DS_LLM_SERVER string //模型服务地址
+	LLM_SERVER string //模型服务地址
+	LLM_MODEL  string
 )
 var (
 	RunMode          string //运行模式
@@ -644,7 +645,8 @@ func init() {
 	ChromePath = config["chrome_path"]
 
 	//模型服务器地址
-	DS_LLM_SERVER = config["llm_server"]
+	LLM_SERVER = config["llm_server"]
+	LLM_MODEL = config["llm_model"]
 	// 初始化ES
 	initEs()
 

+ 0 - 38
utils/llm/deepseek/deekseek.go

@@ -1,38 +0,0 @@
-package deepseek
-
-import (
-	"eta/eta_api/utils"
-	"eta/eta_api/utils/llm"
-	"sync"
-)
-
-var (
-	dsOnce sync.Once
-
-	deepseekClient *DeepSeekClient
-)
-
-type DeepSeekClient struct {
-	*llm.LLMClient
-}
-
-func Getinstance() llm.LLMService {
-	dsOnce.Do(func() {
-		if deepseekClient == nil {
-			deepseekClient = &DeepSeekClient{
-				LLMClient: llm.NewLLMClient(utils.DS_LLM_SERVER, 10),
-			}
-		}
-	})
-	return deepseekClient
-}
-
-func (ds *DeepSeekClient) AskQuestion() string {
-	return ""
-}
-func init() {
-	err := llm.Register(llm.LLM_DEEPSEEK, Getinstance())
-	if err != nil {
-		utils.FileLog.Error("注册deepseek服务失败:", err)
-	}
-}

+ 193 - 0
utils/llm/eta_llm/eta_llm_client.go

@@ -0,0 +1,193 @@
+package eta_llm
+
+import (
+	"bufio"
+	"bytes"
+	"encoding/json"
+	"errors"
+	"eta/eta_api/utils"
+	"eta/eta_api/utils/llm"
+	"eta/eta_api/utils/llm/eta_llm/eta_llm_http"
+	"fmt"
+	"io"
+	"net/http"
+	"strings"
+	"sync"
+)
+
+var (
+	dsOnce sync.Once
+
+	etaLlmClient *ETALLMClient
+)
+
+const (
+	KNOWLEDEG_CHAT_MODE            = "local_kb"
+	DEFALUT_PROMPT_NAME            = "default"
+	CONTENT_TYPE_JSON              = "application/json"
+	KNOWLEDGE_BASE_CHAT_API        = "/chat/kb_chat"
+	KNOWLEDGE_BASE_SEARCH_DOCS_API = "/knowledge_base/search_docs"
+)
+
+type ETALLMClient struct {
+	*llm.LLMClient
+	LlmModel string
+}
+
+func GetInstance() llm.LLMService {
+	dsOnce.Do(func() {
+		if etaLlmClient == nil {
+			etaLlmClient = &ETALLMClient{
+				LLMClient: llm.NewLLMClient(utils.LLM_SERVER, 120),
+				LlmModel:  utils.LLM_MODEL,
+			}
+		}
+	})
+	return etaLlmClient
+}
+
+func (ds *ETALLMClient) KnowledgeBaseChat(query string, KnowledgeBaseName string, history []string) (llmRes *http.Response, err error) {
+	ChatHistory := make([]eta_llm_http.HistoryContent, 0)
+	for _, historyItemStr := range history {
+		str := strings.Split(historyItemStr, "-")
+		historyItem := eta_llm_http.HistoryContent{
+			Role:    str[0],
+			Content: str[1],
+		}
+		ChatHistory = append(ChatHistory, historyItem)
+	}
+	kbReq := eta_llm_http.KbChatRequest{
+		Query:          query,
+		Mode:           KNOWLEDEG_CHAT_MODE,
+		KbName:         KnowledgeBaseName,
+		History:        ChatHistory,
+		TopK:           3,
+		ScoreThreshold: 0.5,
+		Stream:         true,
+		Model:          ds.LlmModel,
+		Temperature:    0.7,
+		MaxTokens:      0,
+		PromptName:     DEFALUT_PROMPT_NAME,
+		ReturnDirect:   false,
+	}
+	fmt.Printf("%v", kbReq.History)
+	body, err := json.Marshal(kbReq)
+	if err != nil {
+		return
+	}
+	return ds.DoStreamPost(KNOWLEDGE_BASE_CHAT_API, body)
+}
+
+func (ds *ETALLMClient) SearchKbDocs(query string, KnowledgeBaseName string) (content interface{}, err error) {
+	kbReq := eta_llm_http.KbSearchDocsRequest{
+		Query:             query,
+		KnowledgeBaseName: KnowledgeBaseName,
+		TopK:              10,
+		ScoreThreshold:    0.5,
+		Metadata:          struct{}{},
+	}
+
+	body, err := json.Marshal(kbReq)
+	if err != nil {
+		return
+	}
+	resp, err := ds.DoPost(KNOWLEDGE_BASE_SEARCH_DOCS_API, body)
+	if !resp.Success {
+		err = errors.New(resp.Msg)
+		return
+	}
+	if resp.Data != nil {
+		var kbSearchRes []eta_llm_http.SearchDocsResponse
+		err = json.Unmarshal(resp.Data, &kbSearchRes)
+		if err != nil {
+			err = errors.New("搜索知识库失败")
+			return
+		}
+		content = kbSearchRes
+		return
+	}
+	err = errors.New("搜索知识库失败")
+	return
+}
+func init() {
+	err := llm.Register(llm.ETA_LLM_CLIENT, GetInstance())
+	if err != nil {
+		utils.FileLog.Error("注册eta_llm_server服务失败:", err)
+	}
+}
+
+func (ds *ETALLMClient) DoPost(apiUrl string, body []byte) (baseResp eta_llm_http.BaseResponse, err error) {
+	requestReader := bytes.NewReader(body)
+	response, err := ds.HttpClient.Post(ds.BaseURL+apiUrl, CONTENT_TYPE_JSON, requestReader)
+	if err != nil {
+		return
+	}
+	return parseResponse(response)
+}
+func (ds *ETALLMClient) DoStreamPost(apiUrl string, body []byte) (baseResp *http.Response, err error) {
+	requestReader := bytes.NewReader(body)
+	return ds.HttpClient.Post(ds.BaseURL+apiUrl, CONTENT_TYPE_JSON, requestReader)
+}
+func parseResponse(response *http.Response) (baseResp eta_llm_http.BaseResponse, err error) {
+	defer func() {
+		_ = response.Body.Close()
+	}()
+	baseResp.Ret = response.StatusCode
+	if response.StatusCode != http.StatusOK {
+		baseResp.Msg = fmt.Sprintf("请求失败,状态码:%d, 状态信息:%s", response.StatusCode, http.StatusText(response.StatusCode))
+		return
+	}
+	bodyBytes, err := io.ReadAll(response.Body)
+	if err != nil {
+		err = fmt.Errorf("读取响应体失败: %w", err)
+		return
+	}
+	baseResp.Success = true
+	baseResp.Data = bodyBytes
+	return
+}
+func ParseStreamResponse(response *http.Response) (contentChan chan string, errChan chan error, closeChan chan struct{}) {
+	contentChan = make(chan string, 10)
+	errChan = make(chan error, 10)
+	closeChan = make(chan struct{})
+	go func() {
+		defer close(contentChan)
+		defer close(errChan)
+		defer close(closeChan)
+		scanner := bufio.NewScanner(response.Body)
+		scanner.Split(bufio.ScanLines)
+		for scanner.Scan() {
+			line := scanner.Text()
+			if line == "" {
+				continue
+			}
+			// 忽略 "ping" 行
+			if strings.HasPrefix(line, ": ping") {
+				continue
+			}
+			// 去除 "data: " 前缀
+			if strings.HasPrefix(line, "data: ") {
+				line = strings.TrimPrefix(line, "data: ")
+			}
+			var chunk eta_llm_http.ChunkResponse
+			if err := json.Unmarshal([]byte(line), &chunk); err != nil {
+				fmt.Println("解析错误的line:" + line)
+				errChan <- fmt.Errorf("解析 JSON 块失败: %w", err)
+				return
+			}
+			// 处理每个 chunk
+			if chunk.Choices != nil && len(chunk.Choices) > 0 {
+				for _, choice := range chunk.Choices {
+					if choice.Delta.Content != "" {
+						contentChan <- choice.Delta.Content
+					}
+				}
+			}
+		}
+		if err := scanner.Err(); err != nil {
+			errChan <- fmt.Errorf("读取响应体失败: %w", err)
+			return
+		}
+	}()
+	return
+}

+ 30 - 0
utils/llm/eta_llm/eta_llm_http/request.go

@@ -0,0 +1,30 @@
+package eta_llm_http
+
+type KbChatRequest struct {
+	Query          string           `json:"query"`
+	Mode           string           `json:"mode"`
+	KbName         string           `json:"kb_name"`
+	TopK           int              `json:"top_k"`
+	ScoreThreshold float32          `json:"score_threshold"`
+	History        []HistoryContent `json:"history"`
+	Stream         bool             `json:"stream"`
+	Model          string           `json:"model"`
+	Temperature    float32          `json:"temperature"`
+	MaxTokens      int              `json:"max_tokens"`
+	PromptName     string           `json:"prompt_name"`
+	ReturnDirect   bool             `json:"return_direct"`
+}
+
+type HistoryContent struct {
+	Content string `json:"content"`
+	Role    string `json:"role"`
+}
+
+type KbSearchDocsRequest struct {
+	Query             string      `json:"query"`
+	KnowledgeBaseName string      `json:"knowledge_base_name"`
+	TopK              int         `json:"top_k"`
+	ScoreThreshold    float32     `json:"score_threshold"`
+	FileName          string      `json:"file_name"`
+	Metadata          interface{} `json:"metadata"`
+}

+ 63 - 0
utils/llm/eta_llm/eta_llm_http/response.go

@@ -0,0 +1,63 @@
+package eta_llm_http
+
+import "encoding/json"
+
+type BaseResponse struct {
+	Ret     int             `json:"ret"`
+	Msg     string          `json:"msg"`
+	Success bool            `json:"success"`
+	Data    json.RawMessage `json:"data"`
+}
+type SteamResponse struct {
+	Data    ChunkResponse `json:"data"`
+}
+// ChunkResponse 定义流式响应的结构体
+type ChunkResponse struct {
+	ID          string   `json:"id"`
+	Object      string   `json:"object"`
+	Model       string   `json:"model"`
+	Created     int64    `json:"created"`
+	Status      *string  `json:"status"`
+	MessageType int      `json:"message_type"`
+	MessageID   *string  `json:"message_id"`
+	IsRef       bool     `json:"is_ref"`
+	Docs        []string `json:"docs"`
+	Choices     []Choice `json:"choices"`
+}
+
+// Choice 定义选择的结构体
+type Choice struct {
+	Delta Delta  `json:"delta"`
+	Role  string `json:"role"`
+}
+
+// Delta 定义增量的结构体
+type Delta struct {
+	Content   string     `json:"content"`
+	ToolCalls []ToolCall `json:"tool_calls"`
+}
+
+// ToolCall 定义工具调用的结构体
+type ToolCall struct {
+	ID       string   `json:"id"`
+	Type     string   `json:"type"`
+	Function Function `json:"function"`
+}
+
+// Function 定义函数的结构体
+type Function struct {
+	Name      string          `json:"name"`
+	Arguments json.RawMessage `json:"arguments"`
+}
+
+type SearchDocsResponse struct {
+	PageContent string   `json:"page_content"`
+	Metadata    Metadata `json:"metadata"`
+	Type        string   `json:"type"`
+	Id          string   `json:"id"`
+	Score       float32  `json:"score"`
+}
+type Metadata struct {
+	Source string `json:"source"`
+	Id     string `json:"id"`
+}

+ 2 - 1
utils/llm/llm_client.go

@@ -20,5 +20,6 @@ func NewLLMClient(baseURL string, timeout time.Duration) *LLMClient {
 }
 
 type LLMService interface {
-	AskQuestion() string
+	KnowledgeBaseChat(query string, KnowledgeBaseName string, history []string) (llmRes *http.Response, err error)
+	SearchKbDocs(query string, KnowledgeBaseName string) (data interface{}, err error)
 }

+ 2 - 3
utils/llm/llm_factory.go

@@ -5,11 +5,11 @@ import (
 )
 
 var (
-	llmInstanceMap map[string]LLMService = make(map[string]LLMService)
+	llmInstanceMap = make(map[string]LLMService)
 )
 
 const (
-	LLM_DEEPSEEK = "deepseek"
+	ETA_LLM_CLIENT = "eta_llm"
 )
 
 func Register(name string, llmClient LLMService) (err error) {
@@ -36,4 +36,3 @@ func GetInstance(name string) (llmClient LLMService, err error) {
 	llmClient = llmInstanceMap[name]
 	return
 }
-

+ 2 - 2
utils/ws/latency_measurer.go

@@ -8,7 +8,7 @@ import (
 )
 
 const (
-	maxMessageSize   = 1024 * 1024 // 1MB
+	maxMessageSize   = 1024 * 1024 * 10 // 1MB
 	basePingInterval = 5 * time.Second
 	maxPingInterval  = 120 * time.Second
 	minPingInterval  = 15 * time.Second
@@ -39,7 +39,7 @@ func (lm *LatencyMeasurer) SendPing(conn *websocket.Conn) error {
 		return errors.New("connection closed")
 	}
 	// 发送Ping消息
-	err := conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(5*time.Second))
+	err := conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(writeWaitTimeout))
 	if err != nil {
 		return err
 	}

+ 4 - 2
utils/ws/limiter.go

@@ -21,6 +21,8 @@ const (
 	QA_LIMITER          = "qaLimiter"
 	LIMITER_KEY         = "llm_chat_key_user_%d"
 	CONNECT_LIMITER_KEY = "llm_chat_connect_key_user_%d"
+
+	RATE_LIMTER_TIME	=60*time.Second
 )
 
 type RateLimiter struct {
@@ -46,7 +48,7 @@ func (qalm *LimiterManger) GetLimiter(token string) *RateLimiter {
 
 	// 创建一个新的限流器,例如每10秒1个请求
 	limiter := &RateLimiter{
-		Limiter: rate.NewLimiter(rate.Every(10*time.Second), 1),
+		Limiter: rate.NewLimiter(rate.Every(RATE_LIMTER_TIME), 1),
 	}
 	qalm.limiterMap[token] = limiter
 	return limiter
@@ -57,7 +59,7 @@ func (qalm *LimiterManger) Allow(token string) bool {
 		limiter.LastRequest = time.Now()
 		return limiter.Allow()
 	}
-	if time.Now().Sub(limiter.LastRequest) < 10*time.Second {
+	if time.Now().Sub(limiter.LastRequest) < RATE_LIMTER_TIME {
 		return false
 	}
 	limiter.LastRequest = time.Now()

+ 22 - 14
utils/ws/session.go

@@ -3,6 +3,7 @@ package ws
 import (
 	"errors"
 	"eta/eta_api/utils"
+	"fmt"
 	"github.com/gorilla/websocket"
 	"sync"
 	"time"
@@ -17,23 +18,28 @@ type Session struct {
 	Latency     *LatencyMeasurer
 	History     []string
 	CloseChan   chan struct{}
-	MessageChan chan *Message
+	MessageChan chan string
 	mu          sync.RWMutex
 	sessionOnce sync.Once
 }
 type Message struct {
-	MessageType string
-	message     []byte
+	KbName     string   `json:"KbName"`
+	Query      string   `json:"Query"`
+	LastTopics []string `json:"LastTopics"`
 }
 
 // readPump 处理读操作
 func (s *Session) readPump() {
-	defer manager.RemoveSession(s.Id)
+	defer func() {
+		fmt.Printf("读进程session %s closed", s.Id)
+		manager.RemoveSession(s.Id)
+	}()
 	s.Conn.SetReadLimit(maxMessageSize)
-	_ = s.Conn.SetReadDeadline(time.Now().Add(60 * time.Second))
+	_ = s.Conn.SetReadDeadline(time.Now().Add(ReadTimeout))
 	for {
 		_, message, err := s.Conn.ReadMessage()
 		if err != nil {
+			fmt.Printf("websocket 错误关闭 %s closed", err.Error())
 			handleCloseError(err)
 			return
 		}
@@ -42,10 +48,9 @@ func (s *Session) readPump() {
 		// 处理消息
 		if err = manager.HandleMessage(s.UserId, s.Id, message); err != nil {
 			//写应答
-			_ = s.writeWithTimeout(&Message{
-				MessageType: "error",
-				message:     []byte(err.Error()),
-			})
+
+			_ = s.writeWithTimeout(err.Error())
+
 		}
 	}
 }
@@ -67,7 +72,7 @@ func (s *Session) Close() {
 }
 
 // 带超时的安全写入
-func (s *Session) writeWithTimeout(msg *Message) error {
+func (s *Session) writeWithTimeout(msg string) error {
 	s.mu.Lock()
 	defer s.mu.Unlock()
 	if s.Conn == nil {
@@ -77,13 +82,14 @@ func (s *Session) writeWithTimeout(msg *Message) error {
 	if err := s.Conn.SetWriteDeadline(time.Now().Add(writeWaitTimeout)); err != nil {
 		return err
 	}
-	return s.Conn.WriteMessage(websocket.TextMessage, msg.message)
+	return s.Conn.WriteMessage(websocket.TextMessage, []byte(msg))
 }
 
 // writePump 处理写操作
 func (s *Session) writePump() {
 	ticker := time.NewTicker(basePingInterval)
 	defer func() {
+		fmt.Printf("写继进程:session %s closed", s.Id)
 		manager.RemoveSession(s.Id)
 		ticker.Stop()
 	}()
@@ -106,12 +112,15 @@ func handleCloseError(err error) {
 	if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway) {
 		var wsErr *websocket.CloseError
 		if !errors.As(err, &wsErr) {
+			fmt.Printf("websocket未知错误 %s", err.Error())
 			utils.FileLog.Error("未知错误 %s", err.Error())
 		} else {
 			switch wsErr.Code {
 			case websocket.CloseNormalClosure:
+				fmt.Println("websocket正常关闭连接")
 				utils.FileLog.Info("正常关闭连接")
 			default:
+				fmt.Printf("websocket关闭代码 %d:%s", wsErr.Code, wsErr.Text)
 				utils.FileLog.Error("关闭代码:%d:%s", wsErr.Code, wsErr.Text)
 			}
 		}
@@ -126,7 +135,7 @@ func (s *Session) forceClose() {
 	// 发送关闭帧
 	_ = s.Conn.WriteControl(websocket.CloseMessage,
 		websocket.FormatCloseMessage(websocket.ClosePolicyViolation, "heartbeat failed"),
-		time.Now().Add(5*time.Second))
+		time.Now().Add(writeWaitTimeout))
 	_ = s.Conn.Close()
 	s.Conn = nil // 标记连接已关闭
 	utils.FileLog.Info("连接已强制关闭",
@@ -139,10 +148,9 @@ func NewSession(userId int, sessionId string, conn *websocket.Conn) (session *Se
 		UserId:      userId,
 		Id:          sessionId,
 		Conn:        conn,
-		History:     []string{},
 		LastActive:  time.Now(),
 		CloseChan:   make(chan struct{}),
-		MessageChan: make(chan *Message, 10),
+		MessageChan: make(chan string, 10),
 	}
 	session.Latency = SetupLatencyMeasurement(conn)
 	go session.readPump()

+ 62 - 17
utils/ws/session_manager.go

@@ -1,19 +1,28 @@
 package ws
 
 import (
+	"encoding/json"
 	"errors"
 	"eta/eta_api/utils"
+	"eta/eta_api/utils/llm"
+	"eta/eta_api/utils/llm/eta_llm"
 	"fmt"
 	"github.com/gorilla/websocket"
+	"net/http"
 	"sync"
 	"time"
 )
 
+var (
+	llmService, _ = llm.GetInstance(llm.ETA_LLM_CLIENT)
+)
+
 const (
-	defaultCheckInterval = 5 * time.Second  // 检测间隔应小于心跳超时时间
-	connectionTimeout    = 20 * time.Second // 客户端超时时间
-	ReadTimeout    = 10 * time.Second // 客户端超时时间
-	writeWaitTimeout = 5 * time.Second
+	defaultCheckInterval = 2 * time.Minute  // 检测间隔应小于心跳超时时间
+	connectionTimeout    = 10 * time.Minute // 客户端超时时间
+	TcpTimeout           = 20 * time.Minute // TCP超时时间,保底关闭,覆盖会话超时时间
+	ReadTimeout          = 15 * time.Minute // 读取超时时间,保底关闭,覆盖会话超时时间
+	writeWaitTimeout     = 60 * time.Second //写入超时时间
 )
 
 type ConnectionManager struct {
@@ -51,14 +60,54 @@ func (manager *ConnectionManager) HandleMessage(userID int, sessionID string, me
 	if !exists {
 		return errors.New("session not found")
 	}
-
+	var userMessage Message
+	err := json.Unmarshal(message, &userMessage)
+	if err != nil {
+		return errors.New("消息格式错误")
+	}
 	// 处理业务逻辑
-	session.History = append(session.History, string(message))
-	response := "Processed: " + string(message)
+	session.History = append(session.History, userMessage.LastTopics...)
+	resp, err := llmService.KnowledgeBaseChat(userMessage.Query, userMessage.KbName, session.History)
+	defer func() {
+		_ = resp.Body.Close()
+	}()
+	if err != nil {
+		err = errors.New(fmt.Sprintf("知识库问答失败: httpCode:%d,错误信息:%s", resp.StatusCode, http.StatusText(resp.StatusCode)))
+		return err
+	}
+	if resp.StatusCode != http.StatusOK {
+		err = errors.New(fmt.Sprintf("知识库问答失败: httpCode:%d,错误信息:%s", resp.StatusCode, http.StatusText(resp.StatusCode)))
+		return err
+	}
+	// 解析流式响应
+	contentChan, errChan, closeChan := eta_llm.ParseStreamResponse(resp)
+	// 处理流式数据并发送到 WebSocket
+	for {
+		select {
+		case content, ok := <-contentChan:
+			if !ok {
+				err = errors.New("未知的错误异常")
+				return err
+			}
+			session.UpdateActivity()
+			// 发送消息到 WebSocket
+			_ = session.Conn.WriteMessage(websocket.TextMessage, []byte(content))
+		case chanErr, ok := <-errChan:
+			if !ok {
+				err = errors.New("未知的错误异常")
+			} else {
+				err = errors.New(chanErr.Error())
+			}
+			// 发送错误消息到 WebSocket
+			return err
+		case <-closeChan:
+			_ = session.Conn.WriteMessage(websocket.TextMessage, []byte("<EOF>"))
+			return nil
+		}
+	}
 	// 更新最后活跃时间
-	session.LastActive = time.Now()
 	// 发送响应
-	return session.Conn.WriteMessage(websocket.TextMessage, []byte(response))
+	//return session.Conn.WriteMessage(websocket.TextMessage, []byte(response))
 }
 
 // AddSession Add 添加一个新的会话
@@ -71,6 +120,7 @@ func (manager *ConnectionManager) GetSessionId(userId int, sessionId string) (se
 
 // RemoveSession Remove 移除一个会话
 func (manager *ConnectionManager) RemoveSession(sessionCode string) {
+	fmt.Printf("移除会话: SessionID=%s, UserID=%s", sessionCode, sessionCode)
 	manager.Sessions.Delete(sessionCode)
 }
 
@@ -85,12 +135,10 @@ func (manager *ConnectionManager) GetSession(sessionCode string) (session *Sessi
 
 // CheckAll 批量检测所有连接
 func (manager *ConnectionManager) CheckAll() {
-	n := 0
 	manager.Sessions.Range(func(key, value interface{}) bool {
-		n++
 		session := value.(*Session)
 		// 判断超时
-		if time.Since(session.LastActive) > connectionTimeout {
+		if time.Since(session.LastActive) > 2*connectionTimeout {
 			fmt.Printf("连接超时关闭: SessionID=%s, UserID=%s", session.Id, session.UserId)
 			utils.FileLog.Warn("连接超时关闭: SessionID=%s, UserID=%s", session.Id, session.UserId)
 			session.Close()
@@ -99,18 +147,17 @@ func (manager *ConnectionManager) CheckAll() {
 		// 发送心跳
 		go func(s *Session) {
 			err := s.Conn.WriteControl(websocket.PingMessage,
-				nil, time.Now().Add(5*time.Second))
+				nil, time.Now().Add(writeWaitTimeout))
 			if err != nil {
 				fmt.Printf("心跳发送失败: SessionID=%s, Error=%v", s.Id, err)
 				utils.FileLog.Warn("心跳发送失败: SessionID=%s, Error=%v",
 					s.Id, err)
+				fmt.Println("心跳无响应,退出请求")
 				session.Close()
 			}
 		}(session)
-		fmt.Println("当前连接数:", n)
 		return true
 	})
-	fmt.Println("当前连接数:", n)
 }
 
 // Start 启动心跳检测
@@ -119,10 +166,8 @@ func (manager *ConnectionManager) Start() {
 	for {
 		select {
 		case <-manager.ticker.C:
-			fmt.Printf("开始检测连接超时")
 			manager.CheckAll()
 		case <-manager.stopChan:
-			fmt.Printf("退出检测")
 			return
 		}
 	}