kobe6258 1 week ago
parent
commit
1749c2781f

+ 20 - 17
utils/llm/eta_llm/eta_llm_client.go

@@ -11,6 +11,7 @@ import (
 	"fmt"
 	"io"
 	"net/http"
+	"strings"
 	"sync"
 )
 
@@ -37,7 +38,7 @@ func GetInstance() llm.LLMService {
 	dsOnce.Do(func() {
 		if etaLlmClient == nil {
 			etaLlmClient = &ETALLMClient{
-				LLMClient: llm.NewLLMClient(utils.LLM_SERVER, 10),
+				LLMClient: llm.NewLLMClient(utils.LLM_SERVER, 120),
 				LlmModel:  utils.LLM_MODEL,
 			}
 		}
@@ -45,18 +46,15 @@ func GetInstance() llm.LLMService {
 	return etaLlmClient
 }
 
-func (ds *ETALLMClient) KnowledgeBaseChat(query string, KnowledgeBaseName string, history []interface{}) (llmRes *http.Response, err error) {
+func (ds *ETALLMClient) KnowledgeBaseChat(query string, KnowledgeBaseName string, history []string) (llmRes *http.Response, err error) {
 	ChatHistory := make([]eta_llm_http.HistoryContent, 0)
-	ChatHistory = append(ChatHistory, eta_llm_http.HistoryContent{
-		Content: query,
-		Role:    "user",
-	})
-	for _, historyItem := range history {
-		historyItemMap := historyItem.(map[string]interface{})
-		ChatHistory = append(ChatHistory, eta_llm_http.HistoryContent{
-			Content: historyItemMap["content"].(string),
-			Role:    historyItemMap["role"].(string),
-		})
+	for _, historyItemStr := range history {
+		str := strings.Split(historyItemStr, "-")
+		historyItem := eta_llm_http.HistoryContent{
+			Role:    str[0],
+			Content: str[1],
+		}
+		ChatHistory = append(ChatHistory, historyItem)
 	}
 	kbReq := eta_llm_http.KbChatRequest{
 		Query:          query,
@@ -72,7 +70,7 @@ func (ds *ETALLMClient) KnowledgeBaseChat(query string, KnowledgeBaseName string
 		PromptName:     DEFALUT_PROMPT_NAME,
 		ReturnDirect:   false,
 	}
-
+	fmt.Printf("%v", kbReq.History)
 	body, err := json.Marshal(kbReq)
 	if err != nil {
 		return
@@ -149,10 +147,6 @@ func parseResponse(response *http.Response) (baseResp eta_llm_http.BaseResponse,
 	return
 }
 func ParseStreamResponse(response *http.Response) (contentChan chan string, errChan chan error, closeChan chan struct{}) {
-	defer func() {
-		_ = response.Body.Close()
-
-	}()
 	contentChan = make(chan string, 10)
 	errChan = make(chan error, 10)
 	closeChan = make(chan struct{})
@@ -167,8 +161,17 @@ func ParseStreamResponse(response *http.Response) (contentChan chan string, errC
 			if line == "" {
 				continue
 			}
+			// 忽略 "ping" 行
+			if strings.HasPrefix(line, ": ping") {
+				continue
+			}
+			// 去除 "data: " 前缀
+			if strings.HasPrefix(line, "data: ") {
+				line = strings.TrimPrefix(line, "data: ")
+			}
 			var chunk eta_llm_http.ChunkResponse
 			if err := json.Unmarshal([]byte(line), &chunk); err != nil {
+				fmt.Println("解析错误的line:" + line)
 				errChan <- fmt.Errorf("解析 JSON 块失败: %w", err)
 				return
 			}

+ 3 - 1
utils/llm/eta_llm/eta_llm_http/response.go

@@ -8,7 +8,9 @@ type BaseResponse struct {
 	Success bool            `json:"success"`
 	Data    json.RawMessage `json:"data"`
 }
-
+type SteamResponse struct {
+	Data    ChunkResponse `json:"data"`
+}
 // ChunkResponse 定义流式响应的结构体
 type ChunkResponse struct {
 	ID          string   `json:"id"`

+ 1 - 1
utils/llm/llm_client.go

@@ -20,6 +20,6 @@ func NewLLMClient(baseURL string, timeout time.Duration) *LLMClient {
 }
 
 type LLMService interface {
-	KnowledgeBaseChat(query string, KnowledgeBaseName string, history []interface{}) (llmRes *http.Response, err error)
+	KnowledgeBaseChat(query string, KnowledgeBaseName string, history []string) (llmRes *http.Response, err error)
 	SearchKbDocs(query string, KnowledgeBaseName string) (data interface{}, err error)
 }

+ 10 - 12
utils/ws/session.go

@@ -1,7 +1,6 @@
 package ws
 
 import (
-	"encoding/json"
 	"errors"
 	"eta/eta_api/utils"
 	"github.com/gorilla/websocket"
@@ -16,15 +15,16 @@ type Session struct {
 	Conn        *websocket.Conn
 	LastActive  time.Time
 	Latency     *LatencyMeasurer
-	History     []json.RawMessage
+	History     []string
 	CloseChan   chan struct{}
-	MessageChan chan *Message
+	MessageChan chan string
 	mu          sync.RWMutex
 	sessionOnce sync.Once
 }
 type Message struct {
-	MessageType string
-	message     []byte
+	KbName     string   `json:"KbName"`
+	Query      string   `json:"Query"`
+	LastTopics []string `json:"LastTopics"`
 }
 
 // readPump 处理读操作
@@ -43,10 +43,8 @@ func (s *Session) readPump() {
 		// 处理消息
 		if err = manager.HandleMessage(s.UserId, s.Id, message); err != nil {
 			//写应答
-			_ = s.writeWithTimeout(&Message{
-				MessageType: "error",
-				message:     []byte(err.Error()),
-			})
+			_ = s.writeWithTimeout(err.Error())
+
 		}
 	}
 }
@@ -68,7 +66,7 @@ func (s *Session) Close() {
 }
 
 // 带超时的安全写入
-func (s *Session) writeWithTimeout(msg *Message) error {
+func (s *Session) writeWithTimeout(msg string) error {
 	s.mu.Lock()
 	defer s.mu.Unlock()
 	if s.Conn == nil {
@@ -78,7 +76,7 @@ func (s *Session) writeWithTimeout(msg *Message) error {
 	if err := s.Conn.SetWriteDeadline(time.Now().Add(writeWaitTimeout)); err != nil {
 		return err
 	}
-	return s.Conn.WriteMessage(websocket.TextMessage, msg.message)
+	return s.Conn.WriteMessage(websocket.TextMessage, []byte(msg))
 }
 
 // writePump 处理写操作
@@ -142,7 +140,7 @@ func NewSession(userId int, sessionId string, conn *websocket.Conn) (session *Se
 		Conn:        conn,
 		LastActive:  time.Now(),
 		CloseChan:   make(chan struct{}),
-		MessageChan: make(chan *Message, 10),
+		MessageChan: make(chan string, 10),
 	}
 	session.Latency = SetupLatencyMeasurement(conn)
 	go session.readPump()

+ 18 - 14
utils/ws/session_manager.go

@@ -1,6 +1,7 @@
 package ws
 
 import (
+	"encoding/json"
 	"errors"
 	"eta/eta_api/utils"
 	"eta/eta_api/utils/llm"
@@ -17,10 +18,10 @@ var (
 )
 
 const (
-	defaultCheckInterval = 5 * time.Second  // 检测间隔应小于心跳超时时间
-	connectionTimeout    = 20 * time.Second // 客户端超时时间
-	ReadTimeout          = 10 * time.Second // 客户端超时时间
-	writeWaitTimeout     = 5 * time.Second
+	defaultCheckInterval = 2 * time.Minute  // 检测间隔应小于心跳超时时间
+	connectionTimeout    = 10 * time.Minute // 客户端超时时间
+	ReadTimeout          = 60 * time.Second // 读取超时时间
+	writeWaitTimeout     = 60 * time.Second //写入超时时间
 )
 
 type ConnectionManager struct {
@@ -58,10 +59,17 @@ func (manager *ConnectionManager) HandleMessage(userID int, sessionID string, me
 	if !exists {
 		return errors.New("session not found")
 	}
-
+	var userMessage Message
+	err := json.Unmarshal(message, &userMessage)
+	if err != nil {
+		return errors.New("消息格式错误")
+	}
 	// 处理业务逻辑
-	session.History = append(session.History, message)
-	resp, err := llmService.KnowledgeBaseChat("", "hz", nil)
+	session.History = append(session.History, userMessage.LastTopics...)
+	resp, err := llmService.KnowledgeBaseChat(userMessage.Query, userMessage.KbName, session.History)
+	defer func() {
+		_ = resp.Body.Close()
+	}()
 	if err != nil {
 		err = errors.New(fmt.Sprintf("知识库问答失败: httpCode:%d,错误信息:%s", resp.StatusCode, http.StatusText(resp.StatusCode)))
 		return err
@@ -80,6 +88,7 @@ func (manager *ConnectionManager) HandleMessage(userID int, sessionID string, me
 				err = errors.New("未知的错误异常")
 				return err
 			}
+			session.UpdateActivity()
 			// 发送消息到 WebSocket
 			_ = session.Conn.WriteMessage(websocket.TextMessage, []byte(content))
 		case chanErr, ok := <-errChan:
@@ -123,12 +132,10 @@ func (manager *ConnectionManager) GetSession(sessionCode string) (session *Sessi
 
 // CheckAll 批量检测所有连接
 func (manager *ConnectionManager) CheckAll() {
-	n := 0
 	manager.Sessions.Range(func(key, value interface{}) bool {
-		n++
 		session := value.(*Session)
 		// 判断超时
-		if time.Since(session.LastActive) > connectionTimeout {
+		if time.Since(session.LastActive) > 2*connectionTimeout {
 			fmt.Printf("连接超时关闭: SessionID=%s, UserID=%s", session.Id, session.UserId)
 			utils.FileLog.Warn("连接超时关闭: SessionID=%s, UserID=%s", session.Id, session.UserId)
 			session.Close()
@@ -142,13 +149,12 @@ func (manager *ConnectionManager) CheckAll() {
 				fmt.Printf("心跳发送失败: SessionID=%s, Error=%v", s.Id, err)
 				utils.FileLog.Warn("心跳发送失败: SessionID=%s, Error=%v",
 					s.Id, err)
+				fmt.Println("心跳无响应,退出请求")
 				session.Close()
 			}
 		}(session)
-		fmt.Println("当前连接数:", n)
 		return true
 	})
-	fmt.Println("当前连接数:", n)
 }
 
 // Start 启动心跳检测
@@ -157,10 +163,8 @@ func (manager *ConnectionManager) Start() {
 	for {
 		select {
 		case <-manager.ticker.C:
-			fmt.Printf("开始检测连接超时")
 			manager.CheckAll()
 		case <-manager.stopChan:
-			fmt.Printf("退出检测")
 			return
 		}
 	}