session_manager.go 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187
  1. package ws
  2. import (
  3. "encoding/json"
  4. "errors"
  5. "eta/eta_api/utils"
  6. "eta/eta_api/utils/llm"
  7. "eta/eta_api/utils/llm/eta_llm"
  8. "fmt"
  9. "github.com/gorilla/websocket"
  10. "net/http"
  11. "sync"
  12. "time"
  13. )
  14. var (
  15. llmService, _ = llm.GetInstance(llm.ETA_LLM_CLIENT)
  16. )
  17. const (
  18. defaultCheckInterval = 2 * time.Minute // 检测间隔应小于心跳超时时间
  19. connectionTimeout = 10 * time.Minute // 客户端超时时间
  20. TcpTimeout = 20 * time.Minute // TCP超时时间,保底关闭,覆盖会话超时时间
  21. ReadTimeout = 15 * time.Minute // 读取超时时间,保底关闭,覆盖会话超时时间
  22. writeWaitTimeout = 60 * time.Second //写入超时时间
  23. )
  24. type ConnectionManager struct {
  25. Sessions sync.Map
  26. ticker *time.Ticker
  27. stopChan chan struct{}
  28. }
  29. var (
  30. smOnce sync.Once
  31. manager *ConnectionManager
  32. )
  33. func GetInstance() *ConnectionManager {
  34. smOnce.Do(func() {
  35. if manager == nil {
  36. manager = &ConnectionManager{
  37. ticker: time.NewTicker(defaultCheckInterval),
  38. stopChan: make(chan struct{}),
  39. }
  40. }
  41. })
  42. return manager
  43. }
  44. func Manager() *ConnectionManager {
  45. return manager
  46. }
  47. // HandleMessage 消息处理核心逻辑
  48. func (manager *ConnectionManager) HandleMessage(userID int, sessionID string, message []byte) error {
  49. if !Allow(userID, QA_LIMITER) {
  50. return errors.New("您提问的太频繁了,请稍后再试")
  51. }
  52. session, exists := manager.GetSession(sessionID)
  53. if !exists {
  54. return errors.New("session not found")
  55. }
  56. var userMessage Message
  57. err := json.Unmarshal(message, &userMessage)
  58. if err != nil {
  59. fmt.Printf("消息格式错误:%s", string(message))
  60. return errors.New("消息格式错误:" + err.Error())
  61. }
  62. // 处理业务逻辑
  63. session.History = append(session.History, userMessage.LastTopics...)
  64. resp, err := llmService.KnowledgeBaseChat(userMessage.Query, userMessage.KbName, session.History)
  65. defer func() {
  66. if resp != nil && resp.Body != nil && err == nil {
  67. _ = resp.Body.Close()
  68. }
  69. }()
  70. if resp == nil {
  71. return errors.New("知识库问答失败: 无应答")
  72. }
  73. if err != nil {
  74. err = errors.New(fmt.Sprintf("知识库问答失败: httpCode:%d,错误信息:%s", resp.StatusCode, http.StatusText(resp.StatusCode)))
  75. return err
  76. }
  77. if resp.StatusCode != http.StatusOK {
  78. err = errors.New(fmt.Sprintf("知识库问答失败: httpCode:%d,错误信息:%s", resp.StatusCode, http.StatusText(resp.StatusCode)))
  79. return err
  80. }
  81. // 解析流式响应
  82. contentChan, errChan, closeChan := eta_llm.ParseStreamResponse(resp)
  83. // 处理流式数据并发送到 WebSocket
  84. for {
  85. select {
  86. case content, ok := <-contentChan:
  87. if !ok {
  88. err = errors.New("未知的错误异常")
  89. return err
  90. }
  91. session.UpdateActivity()
  92. // 发送消息到 WebSocket
  93. _ = session.Conn.WriteMessage(websocket.TextMessage, []byte(content))
  94. case chanErr, ok := <-errChan:
  95. if !ok {
  96. err = errors.New("未知的错误异常")
  97. } else {
  98. err = errors.New(chanErr.Error())
  99. }
  100. // 发送错误消息到 WebSocket
  101. return err
  102. case <-closeChan:
  103. _ = session.Conn.WriteMessage(websocket.TextMessage, []byte("<EOF/>"))
  104. return nil
  105. }
  106. }
  107. // 更新最后活跃时间
  108. // 发送响应
  109. //return session.Conn.WriteMessage(websocket.TextMessage, []byte(response))
  110. }
  111. // AddSession Add 添加一个新的会话
  112. func (manager *ConnectionManager) AddSession(session *Session) {
  113. manager.Sessions.Store(session.Id, session)
  114. }
  115. func (manager *ConnectionManager) GetSessionId(userId int, sessionId string) (sessionID string) {
  116. return fmt.Sprintf("%d_%s", userId, sessionId)
  117. }
  118. // RemoveSession Remove 移除一个会话
  119. func (manager *ConnectionManager) RemoveSession(sessionCode string) {
  120. fmt.Printf("移除会话: SessionID=%s, UserID=%s", sessionCode, sessionCode)
  121. manager.Sessions.Delete(sessionCode)
  122. }
  123. // GetSession 获取一个会话
  124. func (manager *ConnectionManager) GetSession(sessionCode string) (session *Session, exists bool) {
  125. if data, ok := manager.Sessions.Load(sessionCode); ok {
  126. session = data.(*Session)
  127. exists = ok
  128. }
  129. return
  130. }
  131. // CheckAll 批量检测所有连接
  132. func (manager *ConnectionManager) CheckAll() {
  133. manager.Sessions.Range(func(key, value interface{}) bool {
  134. session := value.(*Session)
  135. // 判断超时
  136. if time.Since(session.LastActive) > 2*connectionTimeout {
  137. fmt.Printf("连接超时关闭: SessionID=%s, UserID=%s", session.Id, session.UserId)
  138. utils.FileLog.Warn("连接超时关闭: SessionID=%s, UserID=%s", session.Id, session.UserId)
  139. session.Close()
  140. return true
  141. }
  142. // 发送心跳
  143. go func(s *Session) {
  144. err := s.Conn.WriteControl(websocket.PingMessage,
  145. nil, time.Now().Add(writeWaitTimeout))
  146. if err != nil {
  147. fmt.Printf("心跳发送失败: SessionID=%s, Error=%v", s.Id, err)
  148. utils.FileLog.Warn("心跳发送失败: SessionID=%s, Error=%v",
  149. s.Id, err)
  150. fmt.Println("心跳无响应,退出请求")
  151. session.Close()
  152. }
  153. }(session)
  154. return true
  155. })
  156. }
  157. // Start 启动心跳检测
  158. func (manager *ConnectionManager) Start() {
  159. defer manager.ticker.Stop()
  160. for {
  161. select {
  162. case <-manager.ticker.C:
  163. manager.CheckAll()
  164. case <-manager.stopChan:
  165. return
  166. }
  167. }
  168. }
  169. // Stop 停止心跳检测
  170. func (manager *ConnectionManager) Stop() {
  171. close(manager.stopChan)
  172. }