package ws
import (
"encoding/json"
"errors"
chatService "eta/eta_api/services/llm"
"eta/eta_api/utils"
"eta/eta_api/utils/llm"
"eta/eta_api/utils/llm/eta_llm"
"eta/eta_api/utils/llm/eta_llm/eta_llm_http"
"fmt"
"github.com/gorilla/websocket"
"net/http"
"strings"
"sync"
"time"
)
var (
// llmService is the LLM client HandleMessage uses for knowledge-base chat.
// NOTE(review): the error returned by llm.GetInstance is discarded here, so a
// failed lookup presumably leaves llmService unusable and only surfaces when
// HandleMessage calls it — confirm and consider handling the error.
llmService, _ = llm.GetInstance(llm.ETA_LLM_CLIENT)
)
// Timing parameters for the WebSocket connection manager.
const (
defaultCheckInterval = 2 * time.Minute // Period of the health-check ticker created in GetInstance; should be shorter than the heartbeat timeout.
connectionTimeout = 10 * time.Minute // Client timeout; CheckAll closes a session once it has been idle for more than twice this value.
TcpTimeout = 20 * time.Minute // TCP timeout: last-resort close, meant to exceed the session timeout. Not referenced in this file.
ReadTimeout = 15 * time.Minute // Read timeout: last-resort close, meant to exceed the session timeout. Not referenced in this file.
writeWaitTimeout = 60 * time.Second // Write timeout; used as the deadline for the heartbeat ping in CheckAll.
)
// ConnectionManager keeps the registry of live WebSocket sessions and runs the
// periodic health check over them.
type ConnectionManager struct {
// Sessions maps a session ID to its *Session (AddSession stores under
// session.Id; GetSession and CheckAll assert values as *Session).
Sessions sync.Map
// ticker triggers a CheckAll pass each time it fires inside Start.
ticker *time.Ticker
// stopChan is closed by Stop to make the Start loop return.
stopChan chan struct{}
}
var (
// smOnce guards the one-time construction of manager in GetInstance.
smOnce sync.Once
// manager is the package-level ConnectionManager; it stays nil until
// GetInstance has run (no other assignment exists in this file).
manager *ConnectionManager
)
// GetInstance returns the package-level ConnectionManager, building it on the
// first call. The new manager gets a ticker firing every defaultCheckInterval
// and an open stop channel. An already-assigned manager is left untouched.
func GetInstance() *ConnectionManager {
	smOnce.Do(func() {
		// Keep whatever instance is already in place.
		if manager != nil {
			return
		}
		manager = &ConnectionManager{
			stopChan: make(chan struct{}),
			ticker:   time.NewTicker(defaultCheckInterval),
		}
	})
	return manager
}
// Manager returns the package-level ConnectionManager without initialising it.
// Within this file the variable is only assigned by GetInstance, so the result
// is nil until GetInstance has been called.
func Manager() *ConnectionManager {
return manager
}
// HandleMessage handles one inbound WebSocket message for the session
// registered under sessionID.
//
// Processing order:
//  1. A payload equal to "pong" (case-insensitive) is a heartbeat reply: the
//     session's activity time is refreshed and nothing else happens.
//  2. The user is rate limited through Allow(userID, QA_LIMITER); when
//     rejected, a notice followed by an empty frame is written and nil is
//     returned.
//  3. The payload is decoded as a JSON Message, the chat records stored in
//     Redis for its ChatId are appended to session.History, and
//     llmService.KnowledgeBaseChat is called.
//  4. The streamed answer is forwarded to the client chunk by chunk; an empty
//     frame is written once the stream signals completion.
//
// It returns an error when the session is unknown, the payload is not valid
// JSON, the knowledge-base request gets no response or a non-200 status, the
// stream reports an error, or the content channel closes before completion is
// signalled. Writes to the WebSocket are best effort: their errors are ignored.
func (manager *ConnectionManager) HandleMessage(userID int, sessionID string, message []byte) error {
	session, exists := manager.GetSession(sessionID)
	if !exists {
		return errors.New("session not found")
	}

	// Heartbeat reply: renew the connection and stop here.
	if strings.EqualFold(string(message), "pong") {
		session.UpdateActivity()
		fmt.Printf("收到心跳消息,续期长连接:%v", session.LastActive)
		return nil
	}

	if !Allow(userID, QA_LIMITER) {
		_ = session.Conn.WriteMessage(websocket.TextMessage, []byte("您提问的太频繁了,请稍后再试"))
		// The same empty frame is sent when a normal answer ends; presumably it
		// is the end-of-answer marker the client waits for.
		_ = session.Conn.WriteMessage(websocket.TextMessage, []byte(""))
		return nil
	}

	var userMessage Message
	if err := json.Unmarshal(message, &userMessage); err != nil {
		utils.FileLog.Error(fmt.Sprintf("消息格式错误:%s", string(message)))
		fmt.Printf("消息格式错误:%s", string(message))
		return errors.New("消息格式错误:" + err.Error())
	}

	// Load the stored conversation as context for the model. A failure here is
	// logged and the question is still asked without the extra history.
	// NOTE(review): session.History is appended to on every message; if Redis
	// already holds the whole conversation for ChatId, earlier turns are
	// presumably duplicated across calls — confirm.
	redisHisChat, err := chatService.GetChatRecordsFromRedis(userMessage.ChatId)
	if err != nil {
		utils.FileLog.Error("获取历史对话数据失败,err:", err.Error())
	} else {
		for _, chat := range redisHisChat {
			hisMsg, marshalErr := json.Marshal(&eta_llm_http.HistoryContent{
				Content: chat.Content,
				Role:    chat.ChatUserType,
			})
			if marshalErr != nil {
				utils.FileLog.Error(fmt.Sprintf("序列化历史对话失败: %v", marshalErr))
				continue
			}
			session.History = append(session.History, hisMsg)
		}
	}

	resp, err := llmService.KnowledgeBaseChat(userMessage.Query, userMessage.KbName, session.History)
	// Release the response body on every exit path, including the non-200 and
	// stream-error returns below; otherwise the HTTP connection leaks.
	defer func() {
		if resp != nil && resp.Body != nil {
			_ = resp.Body.Close()
		}
	}()
	if resp == nil {
		if err != nil {
			// Keep the transport error in the log; the returned message stays
			// the same for callers.
			utils.FileLog.Error(fmt.Sprintf("知识库问答失败: 无应答, err:%v", err))
		} else {
			utils.FileLog.Error("知识库问答失败: 无应答")
		}
		return errors.New("知识库问答失败: 无应答")
	}
	if err != nil || resp.StatusCode != http.StatusOK {
		failure := fmt.Sprintf("知识库问答失败: httpCode:%d,错误信息:%s", resp.StatusCode, http.StatusText(resp.StatusCode))
		utils.FileLog.Error(failure)
		return errors.New(failure)
	}

	// Parse the streamed response and relay it to the WebSocket.
	contentChan, errChan, closeChan := eta_llm.ParseStreamResponse(resp)
	for {
		select {
		case content, ok := <-contentChan:
			if !ok {
				// Content channel closed before completion was signalled.
				return errors.New("未知的错误异常")
			}
			session.UpdateActivity()
			_ = session.Conn.WriteMessage(websocket.TextMessage, []byte(content))
		case chanErr, ok := <-errChan:
			if !ok {
				return errors.New("未知的错误异常")
			}
			return errors.New(chanErr.Error())
		case <-closeChan:
			// Stream finished: send the empty end-of-answer frame.
			_ = session.Conn.WriteMessage(websocket.TextMessage, []byte(""))
			return nil
		}
	}
}
// AddSession registers session in the manager, keyed by session.Id. A session
// already stored under the same ID is replaced.
func (manager *ConnectionManager) AddSession(session *Session) {
	key := session.Id
	manager.Sessions.Store(key, session)
}
// GetSessionId builds the composite session key "<userId>_<sessionId>".
func (manager *ConnectionManager) GetSessionId(userId int, sessionId string) (sessionID string) {
	sessionID = fmt.Sprintf("%d_%s", userId, sessionId)
	return sessionID
}
// RemoveSession logs the removal to stdout and deletes the session stored
// under sessionCode. Removing an unknown code is a no-op.
// NOTE(review): the log line fills both SessionID and UserID with sessionCode;
// no separate user ID is available here — confirm whether that is intended.
func (manager *ConnectionManager) RemoveSession(sessionCode string) {
	const removalLog = "移除会话: SessionID=%s, UserID=%s"
	fmt.Printf(removalLog, sessionCode, sessionCode)
	manager.Sessions.Delete(sessionCode)
}
// GetSession looks up the session stored under sessionCode. It returns
// (nil, false) when no entry exists; a stored value that is not a *Session
// makes the type assertion panic.
func (manager *ConnectionManager) GetSession(sessionCode string) (session *Session, exists bool) {
	stored, found := manager.Sessions.Load(sessionCode)
	if !found {
		return nil, false
	}
	return stored.(*Session), true
}
// CheckAll walks every registered session once. A session whose LastActive is
// older than twice connectionTimeout is closed; every other session gets a
// WebSocket ping sent from its own goroutine, and is closed if that ping cannot
// be written within writeWaitTimeout. CheckAll does not wait for the ping
// goroutines to finish.
// NOTE(review): the log calls format session.UserId with %s; if that field is
// not a string the output will contain a %!s(...) marker — confirm its type.
func (manager *ConnectionManager) CheckAll() {
	visit := func(_, value interface{}) bool {
		s := value.(*Session)

		// Idle for too long: close it and move on to the next session.
		if idle := time.Since(s.LastActive); idle > 2*connectionTimeout {
			fmt.Printf("连接超时关闭: SessionID=%s, UserID=%s", s.Id, s.UserId)
			utils.FileLog.Warn("连接超时关闭: SessionID=%s, UserID=%s", s.Id, s.UserId)
			s.Close()
			return true
		}

		// Still within the idle limit: send a heartbeat ping without blocking
		// the iteration.
		go func() {
			if err := s.Conn.WriteControl(websocket.PingMessage,
				nil, time.Now().Add(writeWaitTimeout)); err != nil {
				fmt.Printf("心跳发送失败: SessionID=%s, Error=%v", s.Id, err)
				utils.FileLog.Warn("心跳发送失败: SessionID=%s, Error=%v",
					s.Id, err)
				fmt.Println("心跳无响应,退出请求")
				s.Close()
			}
		}()
		return true
	}
	manager.Sessions.Range(visit)
}
// Start runs the heartbeat loop: each ticker tick triggers one CheckAll pass.
// It blocks until the stop channel is closed (see Stop) and stops the ticker
// before returning.
func (manager *ConnectionManager) Start() {
	ticks := manager.ticker
	defer ticks.Stop()
	for {
		select {
		case <-manager.stopChan:
			return
		case <-ticks.C:
			manager.CheckAll()
		}
	}
}
// Stop ends the heartbeat loop by closing the stop channel that Start selects
// on. Repeated sequential calls are safe: once the channel is closed, later
// calls do nothing instead of panicking with "close of closed channel".
// Concurrent calls to Stop are still not synchronised.
func (manager *ConnectionManager) Stop() {
	select {
	case <-manager.stopChan:
		// Nothing in this file sends on stopChan, so a successful receive
		// means it was already closed by an earlier Stop.
	default:
		close(manager.stopChan)
	}
}