session_manager.go 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262
  1. package ws
  2. import (
  3. "encoding/json"
  4. "errors"
  5. chatService "eta/eta_api/services/llm"
  6. "eta/eta_api/utils"
  7. "eta/eta_api/utils/llm"
  8. "eta/eta_api/utils/llm/eta_llm"
  9. "eta/eta_api/utils/llm/eta_llm/eta_llm_http"
  10. "fmt"
  11. "github.com/gorilla/websocket"
  12. "net/http"
  13. "strings"
  14. "sync"
  15. "time"
  16. )
  17. var (
  18. llmService, _ = llm.GetInstance(llm.ETA_LLM_CLIENT)
  19. )
  20. const (
  21. defaultCheckInterval = 2 * time.Minute // 检测间隔应小于心跳超时时间
  22. connectionTimeout = 10 * time.Minute // 客户端超时时间
  23. TcpTimeout = 20 * time.Minute // TCP超时时间,保底关闭,覆盖会话超时时间
  24. ReadTimeout = 15 * time.Minute // 读取超时时间,保底关闭,覆盖会话超时时间
  25. writeWaitTimeout = 60 * time.Second //写入超时时间
  26. )
  27. type ConnectionManager struct {
  28. Sessions sync.Map
  29. ticker *time.Ticker
  30. stopChan chan struct{}
  31. }
  32. var (
  33. smOnce sync.Once
  34. manager *ConnectionManager
  35. )
  36. func GetInstance() *ConnectionManager {
  37. smOnce.Do(func() {
  38. if manager == nil {
  39. manager = &ConnectionManager{
  40. ticker: time.NewTicker(defaultCheckInterval),
  41. stopChan: make(chan struct{}),
  42. }
  43. }
  44. })
  45. return manager
  46. }
  47. func Manager() *ConnectionManager {
  48. return manager
  49. }
  50. // HandleMessage 消息处理核心逻辑
  51. func (manager *ConnectionManager) HandleMessage(userID int, sessionID string, message []byte) {
  52. var err error
  53. session, exists := manager.GetSession(sessionID)
  54. if !exists {
  55. err = errors.New("session not found")
  56. return
  57. }
  58. if strings.ToLower(string(message)) == "pong" {
  59. session.UpdateActivity()
  60. fmt.Printf("收到心跳消息,续期长连接:%v", session.LastActive)
  61. return
  62. }
  63. defer func() {
  64. if err != nil {
  65. //写应答
  66. _ = session.writeWithTimeout("<think></think>")
  67. _ = session.writeWithTimeout(err.Error())
  68. _ = session.writeWithTimeout("<EOF/>")
  69. }
  70. }()
  71. var userMessage Message
  72. err = json.Unmarshal(message, &userMessage)
  73. if err != nil {
  74. utils.FileLog.Error(fmt.Sprintf("消息格式错误:%s", string(message)))
  75. fmt.Printf("消息格式错误:%s", string(message))
  76. err = errors.New("消息格式错误:" + err.Error())
  77. return
  78. }
  79. if userMessage.MessageType == `stop` {
  80. if session.LLMStatus == 1 {
  81. // 标记llm提问状态:暂停提问
  82. session.LLMStatus = -1
  83. }
  84. if session.CloseLlmChan != nil {
  85. *session.CloseLlmChan <- true
  86. }
  87. return
  88. }
  89. // 限流
  90. if !Allow(userID, QA_LIMITER) {
  91. _ = session.Conn.WriteMessage(websocket.TextMessage, []byte("<think></think>"))
  92. _ = session.Conn.WriteMessage(websocket.TextMessage, []byte("您提问的太频繁了,请稍后再试"))
  93. _ = session.Conn.WriteMessage(websocket.TextMessage, []byte("<EOF/>"))
  94. return
  95. }
  96. // 处理业务逻辑
  97. //session.History = append(session.History, userMessage.LastTopics...)
  98. redisHisChat, err := chatService.GetChatRecordsFromRedis(userMessage.ChatId)
  99. if err != nil {
  100. utils.FileLog.Error("获取历史对话数据失败,err:", err.Error())
  101. } else {
  102. for _, chat := range redisHisChat {
  103. his := eta_llm_http.HistoryContent{
  104. Content: chat.Content,
  105. Role: chat.ChatUserType,
  106. }
  107. hisMsg, _ := json.Marshal(&his)
  108. if len(hisMsg) != 0 {
  109. session.History = append(session.History, hisMsg)
  110. }
  111. }
  112. }
  113. resp, err := llmService.KnowledgeBaseChat(userMessage.Query, userMessage.KbName, session.History)
  114. defer func() {
  115. if resp != nil && resp.Body != nil && err == nil {
  116. _ = resp.Body.Close()
  117. }
  118. }()
  119. if resp == nil {
  120. utils.FileLog.Error("知识库问答失败: 无应答")
  121. err = errors.New("知识库问答失败: 无应答")
  122. return
  123. }
  124. if err != nil {
  125. utils.FileLog.Error(fmt.Sprintf("知识库问答失败: httpCode:%d,错误信息:%s", resp.StatusCode, http.StatusText(resp.StatusCode)))
  126. err = errors.New(fmt.Sprintf("知识库问答失败: httpCode:%d,错误信息:%s", resp.StatusCode, http.StatusText(resp.StatusCode)))
  127. return
  128. }
  129. if resp.StatusCode != http.StatusOK {
  130. utils.FileLog.Error(fmt.Sprintf("知识库问答失败: httpCode:%d,错误信息:%s", resp.StatusCode, http.StatusText(resp.StatusCode)))
  131. err = errors.New(fmt.Sprintf("知识库问答失败: httpCode:%d,错误信息:%s", resp.StatusCode, http.StatusText(resp.StatusCode)))
  132. return
  133. }
  134. // 解析流式响应
  135. contentChan, errChan, closeChan, closeLlmChan := eta_llm.ParseStreamResponse(resp)
  136. session.CloseLlmChan = &closeLlmChan
  137. // 标记llm提问状态:提问中
  138. session.LLMStatus = 1
  139. emptyContent := true
  140. // 处理流式数据并发送到 WebSocket
  141. for {
  142. select {
  143. case content, ok := <-contentChan:
  144. if !ok && session.LLMStatus != -1 {
  145. err = errors.New("未知的内容错误异常")
  146. // 标记llm提问状态:未提问
  147. session.LLMStatus = 0
  148. return
  149. }
  150. session.UpdateActivity()
  151. if emptyContent {
  152. emptyContent = false
  153. }
  154. // 发送消息到 WebSocket
  155. _ = session.Conn.WriteMessage(websocket.TextMessage, []byte(content))
  156. case chanErr, ok := <-errChan:
  157. if !ok {
  158. err = errors.New("未知的错误异常")
  159. } else {
  160. err = errors.New(chanErr.Error())
  161. }
  162. // 标记llm提问状态:未提问
  163. session.LLMStatus = 0
  164. // 发送错误消息到 WebSocket
  165. return
  166. case <-closeChan:
  167. if emptyContent {
  168. _ = session.Conn.WriteMessage(websocket.TextMessage, []byte("<think></think>"))
  169. _ = session.Conn.WriteMessage(websocket.TextMessage, []byte("暂时找不到答案"))
  170. }
  171. _ = session.Conn.WriteMessage(websocket.TextMessage, []byte("<EOF/>"))
  172. // 标记llm提问状态:未提问
  173. session.LLMStatus = 0
  174. return
  175. }
  176. }
  177. // 更新最后活跃时间
  178. // 发送响应
  179. //return session.Conn.WriteMessage(websocket.TextMessage, []byte(response))
  180. }
  181. // AddSession Add 添加一个新的会话
  182. func (manager *ConnectionManager) AddSession(session *Session) {
  183. manager.Sessions.Store(session.Id, session)
  184. }
  185. func (manager *ConnectionManager) GetSessionId(userId int, sessionId string) (sessionID string) {
  186. return fmt.Sprintf("%d_%s", userId, sessionId)
  187. }
  188. // RemoveSession Remove 移除一个会话
  189. func (manager *ConnectionManager) RemoveSession(sessionCode string) {
  190. fmt.Printf("移除会话: SessionID=%s, UserID=%s", sessionCode, sessionCode)
  191. manager.Sessions.Delete(sessionCode)
  192. }
  193. // GetSession 获取一个会话
  194. func (manager *ConnectionManager) GetSession(sessionCode string) (session *Session, exists bool) {
  195. if data, ok := manager.Sessions.Load(sessionCode); ok {
  196. session = data.(*Session)
  197. exists = ok
  198. }
  199. return
  200. }
  201. // CheckAll 批量检测所有连接
  202. func (manager *ConnectionManager) CheckAll() {
  203. manager.Sessions.Range(func(key, value interface{}) bool {
  204. session := value.(*Session)
  205. // 判断超时
  206. if time.Since(session.LastActive) > 2*connectionTimeout {
  207. fmt.Printf("连接超时关闭: SessionID=%s, UserID=%s", session.Id, session.UserId)
  208. utils.FileLog.Warn("连接超时关闭: SessionID=%s, UserID=%s", session.Id, session.UserId)
  209. session.Close()
  210. return true
  211. }
  212. // 发送心跳
  213. go func(s *Session) {
  214. err := s.Conn.WriteControl(websocket.PingMessage,
  215. nil, time.Now().Add(writeWaitTimeout))
  216. if err != nil {
  217. fmt.Printf("心跳发送失败: SessionID=%s, Error=%v", s.Id, err)
  218. utils.FileLog.Warn("心跳发送失败: SessionID=%s, Error=%v",
  219. s.Id, err)
  220. fmt.Println("心跳无响应,退出请求")
  221. session.Close()
  222. }
  223. }(session)
  224. return true
  225. })
  226. }
  227. // Start 启动心跳检测
  228. func (manager *ConnectionManager) Start() {
  229. defer manager.ticker.Stop()
  230. for {
  231. select {
  232. case <-manager.ticker.C:
  233. manager.CheckAll()
  234. case <-manager.stopChan:
  235. return
  236. }
  237. }
  238. }
  239. // Stop 停止心跳检测
  240. func (manager *ConnectionManager) Stop() {
  241. close(manager.stopChan)
  242. }