// session_manager.go (8.6 KB)
  1. package ws
  2. import (
  3. "encoding/json"
  4. "errors"
  5. "eta/eta_api/models/rag"
  6. chatService "eta/eta_api/services/llm"
  7. "eta/eta_api/utils"
  8. "eta/eta_api/utils/llm"
  9. "eta/eta_api/utils/llm/eta_llm"
  10. "eta/eta_api/utils/llm/eta_llm/eta_llm_http"
  11. "fmt"
  12. "github.com/gorilla/websocket"
  13. "net/http"
  14. "regexp"
  15. "strings"
  16. "sync"
  17. "time"
  18. )
  19. var (
  20. llmService, _ = llm.GetInstance(llm.ETA_LLM_CLIENT)
  21. )
  22. const (
  23. defaultCheckInterval = 2 * time.Minute // 检测间隔应小于心跳超时时间
  24. connectionTimeout = 10 * time.Minute // 客户端超时时间
  25. TcpTimeout = 20 * time.Minute // TCP超时时间,保底关闭,覆盖会话超时时间
  26. ReadTimeout = 15 * time.Minute // 读取超时时间,保底关闭,覆盖会话超时时间
  27. writeWaitTimeout = 60 * time.Second //写入超时时间
  28. )
  29. type ConnectionManager struct {
  30. Sessions sync.Map
  31. ticker *time.Ticker
  32. stopChan chan struct{}
  33. }
  34. var (
  35. smOnce sync.Once
  36. manager *ConnectionManager
  37. )
  38. func GetInstance() *ConnectionManager {
  39. smOnce.Do(func() {
  40. if manager == nil {
  41. manager = &ConnectionManager{
  42. ticker: time.NewTicker(defaultCheckInterval),
  43. stopChan: make(chan struct{}),
  44. }
  45. }
  46. })
  47. return manager
  48. }
  49. func Manager() *ConnectionManager {
  50. return manager
  51. }
  52. // HandleMessage 消息处理核心逻辑
  53. func (manager *ConnectionManager) HandleMessage(userID int, sessionID string, message []byte) {
  54. var err error
  55. session, exists := manager.GetSession(sessionID)
  56. if !exists {
  57. err = errors.New("session not found")
  58. return
  59. }
  60. if strings.ToLower(string(message)) == "pong" {
  61. session.UpdateActivity()
  62. fmt.Printf("收到心跳消息,续期长连接:%v", session.LastActive)
  63. return
  64. }
  65. defer func() {
  66. if err != nil {
  67. //写应答
  68. _ = session.writeWithTimeout("<think></think>")
  69. _ = session.writeWithTimeout(err.Error())
  70. _ = session.writeWithTimeout("<EOF/>")
  71. }
  72. }()
  73. var userMessage Message
  74. err = json.Unmarshal(message, &userMessage)
  75. if err != nil {
  76. utils.FileLog.Error(fmt.Sprintf("消息格式错误:%s", string(message)))
  77. fmt.Printf("消息格式错误:%s", string(message))
  78. err = errors.New("消息格式错误:" + err.Error())
  79. return
  80. }
  81. if userMessage.MessageType == `stop` {
  82. if session.LLMStatus == 1 {
  83. // 标记llm提问状态:暂停提问
  84. session.LLMStatus = -1
  85. }
  86. if session.CloseLlmChan != nil {
  87. *session.CloseLlmChan <- true
  88. }
  89. return
  90. }
  91. // 限流
  92. if !Allow(userID, QA_LIMITER) {
  93. _ = session.Conn.WriteMessage(websocket.TextMessage, []byte("<think></think>"))
  94. _ = session.Conn.WriteMessage(websocket.TextMessage, []byte("您提问的太频繁了,请稍后再试"))
  95. _ = session.Conn.WriteMessage(websocket.TextMessage, []byte("<EOF/>"))
  96. return
  97. }
  98. // 处理业务逻辑
  99. //session.History = append(session.History, userMessage.LastTopics...)
  100. redisHisChat, err := chatService.GetChatRecordsFromRedis(userMessage.ChatId)
  101. if err != nil {
  102. utils.FileLog.Error("获取历史对话数据失败,err:", err.Error())
  103. } else {
  104. for _, chat := range redisHisChat {
  105. his := eta_llm_http.HistoryContent{
  106. Content: chat.Content,
  107. Role: chat.ChatUserType,
  108. }
  109. hisMsg, _ := json.Marshal(&his)
  110. if len(hisMsg) != 0 {
  111. session.History = append(session.History, hisMsg)
  112. }
  113. }
  114. }
  115. //修改逻辑。如果问题出现敏感词,则返回敏感词提示
  116. var resp *http.Response
  117. labels := llm.GetDAFHandlerInstance().FindTextTagLabels(userMessage.Query)
  118. if len(labels) > 0 {
  119. articles, findErr := rag.GetArticleByTags(labels, 15, 10)
  120. if findErr != nil {
  121. utils.FileLog.Warn("没有搜索到相关的研报内容,执行RAG对话 ")
  122. resp, err = llmService.KnowledgeBaseChat(userMessage.Query, userMessage.KbName, session.History)
  123. } else {
  124. utils.FileLog.Info("搜索到相关的研报内容,执行completion对话 ")
  125. //直接对话,不需要走RAG
  126. var articlesContents string
  127. articlesCounts := len(articles)
  128. for i := 0; i < articlesCounts; i++ {
  129. articlesContents += fmt.Sprintf("【%d】:%s\n", i+1, articles[i].TextContent)
  130. }
  131. promote := fmt.Sprintf("【问题】:%s,请基于以下%d篇研报进行回答问题,以下是研报:%s", userMessage.Query, articlesCounts, articlesContents)
  132. re := regexp.MustCompile(`\s+`)
  133. promote = re.ReplaceAllString(promote, "")
  134. resp, err = llmService.CompletionChat(promote, session.History)
  135. }
  136. } else {
  137. resp, err = llmService.KnowledgeBaseChat(userMessage.Query, userMessage.KbName, session.History)
  138. }
  139. defer func() {
  140. if resp != nil && resp.Body != nil && err == nil {
  141. _ = resp.Body.Close()
  142. }
  143. }()
  144. if resp == nil {
  145. utils.FileLog.Error("知识库问答失败: 无应答")
  146. err = errors.New("知识库问答失败: 无应答")
  147. return
  148. }
  149. if err != nil {
  150. utils.FileLog.Error(fmt.Sprintf("知识库问答失败: httpCode:%d,错误信息:%s", resp.StatusCode, http.StatusText(resp.StatusCode)))
  151. err = errors.New(fmt.Sprintf("知识库问答失败: httpCode:%d,错误信息:%s", resp.StatusCode, http.StatusText(resp.StatusCode)))
  152. return
  153. }
  154. if resp.StatusCode != http.StatusOK {
  155. utils.FileLog.Error(fmt.Sprintf("知识库问答失败: httpCode:%d,错误信息:%s", resp.StatusCode, http.StatusText(resp.StatusCode)))
  156. err = errors.New(fmt.Sprintf("知识库问答失败: httpCode:%d,错误信息:%s", resp.StatusCode, http.StatusText(resp.StatusCode)))
  157. return
  158. }
  159. // 解析流式响应
  160. contentChan, errChan, closeChan, closeLlmChan := eta_llm.ParseStreamResponse(resp)
  161. session.CloseLlmChan = &closeLlmChan
  162. // 标记llm提问状态:提问中
  163. session.LLMStatus = 1
  164. emptyContent := true
  165. // 处理流式数据并发送到 WebSocket
  166. for {
  167. select {
  168. case content, ok := <-contentChan:
  169. if !ok && session.LLMStatus != -1 {
  170. err = errors.New("未知的内容错误异常")
  171. // 标记llm提问状态:未提问
  172. session.LLMStatus = 0
  173. return
  174. }
  175. session.UpdateActivity()
  176. if emptyContent {
  177. emptyContent = false
  178. }
  179. // 发送消息到 WebSocket
  180. _ = session.Conn.WriteMessage(websocket.TextMessage, []byte(content))
  181. case chanErr, ok := <-errChan:
  182. if !ok && session.LLMStatus != -1 {
  183. err = errors.New("未知的错误异常")
  184. } else if chanErr != nil {
  185. err = errors.New(chanErr.Error())
  186. }
  187. // 标记llm提问状态:未提问
  188. session.LLMStatus = 0
  189. // 发送错误消息到 WebSocket
  190. return
  191. case <-closeChan:
  192. if emptyContent {
  193. _ = session.Conn.WriteMessage(websocket.TextMessage, []byte("<think></think>"))
  194. _ = session.Conn.WriteMessage(websocket.TextMessage, []byte("暂时找不到答案"))
  195. }
  196. _ = session.Conn.WriteMessage(websocket.TextMessage, []byte("<EOF/>"))
  197. // 标记llm提问状态:未提问
  198. session.LLMStatus = 0
  199. return
  200. }
  201. }
  202. // 更新最后活跃时间
  203. // 发送响应
  204. //return session.Conn.WriteMessage(websocket.TextMessage, []byte(response))
  205. }
  206. // AddSession Add 添加一个新的会话
  207. func (manager *ConnectionManager) AddSession(session *Session) {
  208. manager.Sessions.Store(session.Id, session)
  209. }
  210. func (manager *ConnectionManager) GetSessionId(userId int, sessionId string) (sessionID string) {
  211. return fmt.Sprintf("%d_%s", userId, sessionId)
  212. }
  213. // RemoveSession Remove 移除一个会话
  214. func (manager *ConnectionManager) RemoveSession(sessionCode string) {
  215. fmt.Printf("移除会话: SessionID=%s, UserID=%s", sessionCode, sessionCode)
  216. manager.Sessions.Delete(sessionCode)
  217. }
  218. // GetSession 获取一个会话
  219. func (manager *ConnectionManager) GetSession(sessionCode string) (session *Session, exists bool) {
  220. if data, ok := manager.Sessions.Load(sessionCode); ok {
  221. session = data.(*Session)
  222. exists = ok
  223. }
  224. return
  225. }
  226. // CheckAll 批量检测所有连接
  227. func (manager *ConnectionManager) CheckAll() {
  228. manager.Sessions.Range(func(key, value interface{}) bool {
  229. session := value.(*Session)
  230. // 判断超时
  231. if time.Since(session.LastActive) > 2*connectionTimeout {
  232. fmt.Printf("连接超时关闭: SessionID=%s, UserID=%s", session.Id, session.UserId)
  233. utils.FileLog.Warn("连接超时关闭: SessionID=%s, UserID=%s", session.Id, session.UserId)
  234. session.Close()
  235. return true
  236. }
  237. // 发送心跳
  238. go func(s *Session) {
  239. err := s.Conn.WriteControl(websocket.PingMessage,
  240. nil, time.Now().Add(writeWaitTimeout))
  241. if err != nil {
  242. fmt.Printf("心跳发送失败: SessionID=%s, Error=%v", s.Id, err)
  243. utils.FileLog.Warn("心跳发送失败: SessionID=%s, Error=%v",
  244. s.Id, err)
  245. fmt.Println("心跳无响应,退出请求")
  246. session.Close()
  247. }
  248. }(session)
  249. return true
  250. })
  251. }
  252. // Start 启动心跳检测
  253. func (manager *ConnectionManager) Start() {
  254. defer manager.ticker.Stop()
  255. for {
  256. select {
  257. case <-manager.ticker.C:
  258. manager.CheckAll()
  259. case <-manager.stopChan:
  260. return
  261. }
  262. }
  263. }
  264. // Stop 停止心跳检测
  265. func (manager *ConnectionManager) Stop() {
  266. close(manager.stopChan)
  267. }