package ws

import (
	"encoding/json"
	"errors"
	"fmt"
	"net/http"
	"sync"
	"time"

	"eta/eta_api/utils"
	"eta/eta_api/utils/llm"
	"eta/eta_api/utils/llm/eta_llm"
	"github.com/gorilla/websocket"
)

var (
	// llmService is the LLM client used by HandleMessage for knowledge-base chat.
	// NOTE(review): the initialization error from llm.GetInstance is discarded;
	// confirm that a failed initialization cannot leave llmService unusable.
	llmService, _ = llm.GetInstance(llm.ETA_LLM_CLIENT)
)

const (
	defaultCheckInterval = 2 * time.Minute  // Heartbeat check interval; should be shorter than the heartbeat timeout.
	connectionTimeout    = 10 * time.Minute // Client inactivity timeout.
	ReadTimeout          = 60 * time.Second // Read timeout.
	writeWaitTimeout     = 60 * time.Second // Write timeout.
)

// ConnectionManager tracks active websocket sessions and periodically
// health-checks them via Start/CheckAll.
type ConnectionManager struct {
	// Sessions holds *Session values keyed by session ID.
	Sessions sync.Map
	ticker   *time.Ticker
	stopChan chan struct{}
	// stopOnce guards stopChan so that Stop may be called more than once
	// without panicking on a double close.
	stopOnce sync.Once
}

var (
	smOnce  sync.Once
	manager *ConnectionManager
)

// GetInstance returns the shared ConnectionManager, creating it (with its
// heartbeat ticker) on first use. It does not start the heartbeat loop; call
// Start for that.
func GetInstance() *ConnectionManager {
	smOnce.Do(func() {
		if manager == nil {
			manager = &ConnectionManager{
				ticker:   time.NewTicker(defaultCheckInterval),
				stopChan: make(chan struct{}),
			}
		}
	})
	return manager
}

// Manager returns the shared ConnectionManager without initializing it; the
// result may be nil if GetInstance has not been called yet.
func Manager() *ConnectionManager {
	return manager
}

// HandleMessage is the core message-processing routine. It rate-limits the
// user, decodes message into a Message, sends the query to the knowledge-base
// chat service and forwards every streamed chunk to the session's websocket.
//
// It returns nil when the stream signals completion on its close channel, and
// an error when the user is rate limited, the session is unknown, the message
// cannot be decoded, the upstream call fails or answers with a non-200 status,
// or the stream reports an error or its content channel closes unexpectedly.
func (cm *ConnectionManager) HandleMessage(userID int, sessionID string, message []byte) error {
	if !Allow(userID, QA_LIMITER) {
		return errors.New("您提问的太频繁了,请稍后再试")
	}
	session, exists := cm.GetSession(sessionID)
	if !exists {
		return errors.New("session not found")
	}
	var userMessage Message
	if err := json.Unmarshal(message, &userMessage); err != nil {
		return errors.New("消息格式错误")
	}

	// Append the client-supplied LastTopics to the session history, which is
	// then sent along with the query.
	session.History = append(session.History, userMessage.LastTopics...)
	resp, err := llmService.KnowledgeBaseChat(userMessage.Query, userMessage.KbName, session.History)
	// Only close the body when a response exists: on failure resp may be nil,
	// and dereferencing it would panic.
	if resp != nil {
		defer func() { _ = resp.Body.Close() }()
	}
	if err != nil {
		if resp != nil {
			return fmt.Errorf("知识库问答失败: httpCode:%d,错误信息:%s: %w", resp.StatusCode, http.StatusText(resp.StatusCode), err)
		}
		return fmt.Errorf("知识库问答失败: %w", err)
	}
	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("知识库问答失败: httpCode:%d,错误信息:%s", resp.StatusCode, http.StatusText(resp.StatusCode))
	}

	// Parse the streamed response and relay it to the websocket.
	contentChan, errChan, closeChan := eta_llm.ParseStreamResponse(resp)
	for {
		select {
		case content, ok := <-contentChan:
			if !ok {
				return errors.New("未知的错误异常")
			}
			session.UpdateActivity()
			// NOTE(review): write failures are ignored and the stream keeps
			// being consumed; confirm this is intended.
			_ = session.Conn.WriteMessage(websocket.TextMessage, []byte(content))
		case chanErr, ok := <-errChan:
			if !ok {
				return errors.New("未知的错误异常")
			}
			return errors.New(chanErr.Error())
		case <-closeChan:
			return nil
		}
	}
}

// AddSession registers session under its Id, replacing any existing entry
// with the same key.
func (cm *ConnectionManager) AddSession(session *Session) {
	cm.Sessions.Store(session.Id, session)
}

// GetSessionId builds a composite session key of the form "<userId>_<sessionId>".
func (cm *ConnectionManager) GetSessionId(userId int, sessionId string) (sessionID string) {
	return fmt.Sprintf("%d_%s", userId, sessionId)
}

// RemoveSession removes the session stored under sessionCode, if any.
func (cm *ConnectionManager) RemoveSession(sessionCode string) {
	cm.Sessions.Delete(sessionCode)
}

// GetSession looks up the session stored under sessionCode. exists reports
// whether it was found; session is nil otherwise.
func (cm *ConnectionManager) GetSession(sessionCode string) (session *Session, exists bool) {
	if data, ok := cm.Sessions.Load(sessionCode); ok {
		session = data.(*Session)
		exists = ok
	}
	return
}

// CheckAll runs one health-check pass over every session. Sessions whose last
// activity is older than 2*connectionTimeout are closed. Every other session
// is sent a websocket ping asynchronously and is closed if writing the ping
// fails (the write deadline is 5 seconds).
func (cm *ConnectionManager) CheckAll() {
	cm.Sessions.Range(func(key, value interface{}) bool {
		session := value.(*Session)
		// Inactivity timeout.
		if time.Since(session.LastActive) > 2*connectionTimeout {
			fmt.Printf("连接超时关闭: SessionID=%s, UserID=%v\n", session.Id, session.UserId)
			utils.FileLog.Warn("连接超时关闭: SessionID=%s, UserID=%v", session.Id, session.UserId)
			session.Close()
			return true
		}
		// Send the heartbeat in its own goroutine so a slow connection does not
		// stall the rest of the pass.
		go func(s *Session) {
			err := s.Conn.WriteControl(websocket.PingMessage, nil, time.Now().Add(5*time.Second))
			if err != nil {
				fmt.Printf("心跳发送失败: SessionID=%s, Error=%v\n", s.Id, err)
				utils.FileLog.Warn("心跳发送失败: SessionID=%s, Error=%v", s.Id, err)
				fmt.Println("心跳无响应,退出请求")
				s.Close()
			}
		}(session)
		return true
	})
}

// Start runs the heartbeat loop, calling CheckAll on every ticker tick. It
// blocks until Stop is called, then stops the ticker.
func (cm *ConnectionManager) Start() {
	defer cm.ticker.Stop()
	for {
		select {
		case <-cm.ticker.C:
			cm.CheckAll()
		case <-cm.stopChan:
			return
		}
	}
}

// Stop signals the loop running in Start to exit. It is safe to call more
// than once.
func (cm *ConnectionManager) Stop() {
	cm.stopOnce.Do(func() { close(cm.stopChan) })
}