123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224 |
- package ws
- import (
- "encoding/json"
- "errors"
- chatService "eta/eta_api/services/llm"
- "eta/eta_api/utils"
- "eta/eta_api/utils/llm"
- "eta/eta_api/utils/llm/eta_llm"
- "eta/eta_api/utils/llm/eta_llm/eta_llm_http"
- "fmt"
- "github.com/gorilla/websocket"
- "net/http"
- "strings"
- "sync"
- "time"
- )
- var (
- llmService, _ = llm.GetInstance(llm.ETA_LLM_CLIENT)
- )
// defaultCheckInterval is the period of the ticker that drives CheckAll.
// The check interval should be shorter than the heartbeat timeout.
const defaultCheckInterval = 2 * time.Minute

// connectionTimeout is the client inactivity timeout.
const connectionTimeout = 10 * time.Minute

// TcpTimeout is the TCP timeout: a last-resort close that outlasts the
// session timeout.
const TcpTimeout = 20 * time.Minute

// ReadTimeout is the read timeout: a last-resort close that outlasts the
// session timeout.
const ReadTimeout = 15 * time.Minute

// writeWaitTimeout is the write deadline (used for heartbeat pings).
const writeWaitTimeout = time.Minute
// ConnectionManager keeps the registry of live WebSocket sessions and owns
// the ticker that drives the periodic heartbeat / timeout sweep.
type ConnectionManager struct {
	// Sessions maps a session key (session.Id, as stored by AddSession)
	// to its *Session.
	Sessions sync.Map
	// ticker fires every defaultCheckInterval (created in GetInstance);
	// Start runs CheckAll on each tick.
	ticker *time.Ticker
	// stopChan is closed by Stop to make Start return.
	stopChan chan struct{}
}
- var (
- smOnce sync.Once
- manager *ConnectionManager
- )
- func GetInstance() *ConnectionManager {
- smOnce.Do(func() {
- if manager == nil {
- manager = &ConnectionManager{
- ticker: time.NewTicker(defaultCheckInterval),
- stopChan: make(chan struct{}),
- }
- }
- })
- return manager
- }
- func Manager() *ConnectionManager {
- return manager
- }
- // HandleMessage 消息处理核心逻辑
- func (manager *ConnectionManager) HandleMessage(userID int, sessionID string, message []byte) error {
- session, exists := manager.GetSession(sessionID)
- if !exists {
- return errors.New("session not found")
- }
- if strings.ToLower(string(message)) == "pong" {
- session.UpdateActivity()
- fmt.Printf("收到心跳消息,续期长连接:%v", session.LastActive)
- return nil
- }
- if !Allow(userID, QA_LIMITER) {
- _ = session.Conn.WriteMessage(websocket.TextMessage, []byte("<think></think>"))
- _ = session.Conn.WriteMessage(websocket.TextMessage, []byte("您提问的太频繁了,请稍后再试"))
- _ = session.Conn.WriteMessage(websocket.TextMessage, []byte("<EOF/>"))
- return nil
- }
- var userMessage Message
- err := json.Unmarshal(message, &userMessage)
- if err != nil {
- utils.FileLog.Error(fmt.Sprintf("消息格式错误:%s", string(message)))
- fmt.Printf("消息格式错误:%s", string(message))
- return errors.New("消息格式错误:" + err.Error())
- }
- // 处理业务逻辑
- //session.History = append(session.History, userMessage.LastTopics...)
- redisHisChat, err := chatService.GetChatRecordsFromRedis(userMessage.ChatId)
- if err != nil {
- utils.FileLog.Error("获取历史对话数据失败,err:", err.Error())
- } else {
- for _, chat := range redisHisChat {
- his := eta_llm_http.HistoryContent{
- Content: chat.Content,
- Role: chat.ChatUserType,
- }
- hisMsg, _ := json.Marshal(&his)
- if len(hisMsg) != 0 {
- session.History = append(session.History, hisMsg)
- }
- }
- }
- resp, err := llmService.KnowledgeBaseChat(userMessage.Query, userMessage.KbName, session.History)
- defer func() {
- if resp != nil && resp.Body != nil && err == nil {
- _ = resp.Body.Close()
- }
- }()
- if resp == nil {
- utils.FileLog.Error("知识库问答失败: 无应答")
- return errors.New("知识库问答失败: 无应答")
- }
- if err != nil {
- utils.FileLog.Error(fmt.Sprintf("知识库问答失败: httpCode:%d,错误信息:%s", resp.StatusCode, http.StatusText(resp.StatusCode)))
- err = errors.New(fmt.Sprintf("知识库问答失败: httpCode:%d,错误信息:%s", resp.StatusCode, http.StatusText(resp.StatusCode)))
- return err
- }
- if resp.StatusCode != http.StatusOK {
- utils.FileLog.Error(fmt.Sprintf("知识库问答失败: httpCode:%d,错误信息:%s", resp.StatusCode, http.StatusText(resp.StatusCode)))
- err = errors.New(fmt.Sprintf("知识库问答失败: httpCode:%d,错误信息:%s", resp.StatusCode, http.StatusText(resp.StatusCode)))
- return err
- }
- // 解析流式响应
- contentChan, errChan, closeChan := eta_llm.ParseStreamResponse(resp)
- emptyContent := true
- // 处理流式数据并发送到 WebSocket
- for {
- select {
- case content, ok := <-contentChan:
- if !ok {
- err = errors.New("未知的错误异常")
- return err
- }
- session.UpdateActivity()
- if emptyContent {
- emptyContent = false
- }
- // 发送消息到 WebSocket
- _ = session.Conn.WriteMessage(websocket.TextMessage, []byte(content))
- case chanErr, ok := <-errChan:
- if !ok {
- err = errors.New("未知的错误异常")
- } else {
- err = errors.New(chanErr.Error())
- }
- // 发送错误消息到 WebSocket
- return err
- case <-closeChan:
- if emptyContent {
- _ = session.Conn.WriteMessage(websocket.TextMessage, []byte("<think></think>"))
- }
- _ = session.Conn.WriteMessage(websocket.TextMessage, []byte("<EOF/>"))
- return nil
- }
- }
- // 更新最后活跃时间
- // 发送响应
- //return session.Conn.WriteMessage(websocket.TextMessage, []byte(response))
- }
- // AddSession Add 添加一个新的会话
- func (manager *ConnectionManager) AddSession(session *Session) {
- manager.Sessions.Store(session.Id, session)
- }
- func (manager *ConnectionManager) GetSessionId(userId int, sessionId string) (sessionID string) {
- return fmt.Sprintf("%d_%s", userId, sessionId)
- }
- // RemoveSession Remove 移除一个会话
- func (manager *ConnectionManager) RemoveSession(sessionCode string) {
- fmt.Printf("移除会话: SessionID=%s, UserID=%s", sessionCode, sessionCode)
- manager.Sessions.Delete(sessionCode)
- }
- // GetSession 获取一个会话
- func (manager *ConnectionManager) GetSession(sessionCode string) (session *Session, exists bool) {
- if data, ok := manager.Sessions.Load(sessionCode); ok {
- session = data.(*Session)
- exists = ok
- }
- return
- }
- // CheckAll 批量检测所有连接
- func (manager *ConnectionManager) CheckAll() {
- manager.Sessions.Range(func(key, value interface{}) bool {
- session := value.(*Session)
- // 判断超时
- if time.Since(session.LastActive) > 2*connectionTimeout {
- fmt.Printf("连接超时关闭: SessionID=%s, UserID=%s", session.Id, session.UserId)
- utils.FileLog.Warn("连接超时关闭: SessionID=%s, UserID=%s", session.Id, session.UserId)
- session.Close()
- return true
- }
- // 发送心跳
- go func(s *Session) {
- err := s.Conn.WriteControl(websocket.PingMessage,
- nil, time.Now().Add(writeWaitTimeout))
- if err != nil {
- fmt.Printf("心跳发送失败: SessionID=%s, Error=%v", s.Id, err)
- utils.FileLog.Warn("心跳发送失败: SessionID=%s, Error=%v",
- s.Id, err)
- fmt.Println("心跳无响应,退出请求")
- session.Close()
- }
- }(session)
- return true
- })
- }
- // Start 启动心跳检测
- func (manager *ConnectionManager) Start() {
- defer manager.ticker.Stop()
- for {
- select {
- case <-manager.ticker.C:
- manager.CheckAll()
- case <-manager.stopChan:
- return
- }
- }
- }
- // Stop 停止心跳检测
- func (manager *ConnectionManager) Stop() {
- close(manager.stopChan)
- }
|