session_manager.go 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172
  1. package ws
  2. import (
  3. "errors"
  4. "eta/eta_api/utils"
  5. "eta/eta_api/utils/llm"
  6. "eta/eta_api/utils/llm/eta_llm"
  7. "fmt"
  8. "github.com/gorilla/websocket"
  9. "net/http"
  10. "sync"
  11. "time"
  12. )
  13. var (
  14. llmService, _ = llm.GetInstance(llm.ETA_LLM_CLIENT)
  15. )
  16. const (
  17. defaultCheckInterval = 5 * time.Second // 检测间隔应小于心跳超时时间
  18. connectionTimeout = 20 * time.Second // 客户端超时时间
  19. ReadTimeout = 10 * time.Second // 客户端超时时间
  20. writeWaitTimeout = 5 * time.Second
  21. )
  22. type ConnectionManager struct {
  23. Sessions sync.Map
  24. ticker *time.Ticker
  25. stopChan chan struct{}
  26. }
  27. var (
  28. smOnce sync.Once
  29. manager *ConnectionManager
  30. )
  31. func GetInstance() *ConnectionManager {
  32. smOnce.Do(func() {
  33. if manager == nil {
  34. manager = &ConnectionManager{
  35. ticker: time.NewTicker(defaultCheckInterval),
  36. stopChan: make(chan struct{}),
  37. }
  38. }
  39. })
  40. return manager
  41. }
  42. func Manager() *ConnectionManager {
  43. return manager
  44. }
  45. // HandleMessage 消息处理核心逻辑
  46. func (manager *ConnectionManager) HandleMessage(userID int, sessionID string, message []byte) error {
  47. if !Allow(userID, QA_LIMITER) {
  48. return errors.New("您提问的太频繁了,请稍后再试")
  49. }
  50. session, exists := manager.GetSession(sessionID)
  51. if !exists {
  52. return errors.New("session not found")
  53. }
  54. // 处理业务逻辑
  55. session.History = append(session.History, message)
  56. resp, err := llmService.KnowledgeBaseChat("", "hz", nil)
  57. if err != nil {
  58. err = errors.New(fmt.Sprintf("知识库问答失败: httpCode:%d,错误信息:%s", resp.StatusCode, http.StatusText(resp.StatusCode)))
  59. return err
  60. }
  61. if resp.StatusCode != http.StatusOK {
  62. err = errors.New(fmt.Sprintf("知识库问答失败: httpCode:%d,错误信息:%s", resp.StatusCode, http.StatusText(resp.StatusCode)))
  63. return err
  64. }
  65. // 解析流式响应
  66. contentChan, errChan, closeChan := eta_llm.ParseStreamResponse(resp)
  67. // 处理流式数据并发送到 WebSocket
  68. for {
  69. select {
  70. case content, ok := <-contentChan:
  71. if !ok {
  72. err = errors.New("未知的错误异常")
  73. return err
  74. }
  75. // 发送消息到 WebSocket
  76. _ = session.Conn.WriteMessage(websocket.TextMessage, []byte(content))
  77. case chanErr, ok := <-errChan:
  78. if !ok {
  79. err = errors.New("未知的错误异常")
  80. } else {
  81. err = errors.New(chanErr.Error())
  82. }
  83. // 发送错误消息到 WebSocket
  84. return err
  85. case <-closeChan:
  86. return nil
  87. }
  88. }
  89. // 更新最后活跃时间
  90. // 发送响应
  91. //return session.Conn.WriteMessage(websocket.TextMessage, []byte(response))
  92. }
  93. // AddSession Add 添加一个新的会话
  94. func (manager *ConnectionManager) AddSession(session *Session) {
  95. manager.Sessions.Store(session.Id, session)
  96. }
  97. func (manager *ConnectionManager) GetSessionId(userId int, sessionId string) (sessionID string) {
  98. return fmt.Sprintf("%d_%s", userId, sessionId)
  99. }
  100. // RemoveSession Remove 移除一个会话
  101. func (manager *ConnectionManager) RemoveSession(sessionCode string) {
  102. manager.Sessions.Delete(sessionCode)
  103. }
  104. // GetSession 获取一个会话
  105. func (manager *ConnectionManager) GetSession(sessionCode string) (session *Session, exists bool) {
  106. if data, ok := manager.Sessions.Load(sessionCode); ok {
  107. session = data.(*Session)
  108. exists = ok
  109. }
  110. return
  111. }
  112. // CheckAll 批量检测所有连接
  113. func (manager *ConnectionManager) CheckAll() {
  114. n := 0
  115. manager.Sessions.Range(func(key, value interface{}) bool {
  116. n++
  117. session := value.(*Session)
  118. // 判断超时
  119. if time.Since(session.LastActive) > connectionTimeout {
  120. fmt.Printf("连接超时关闭: SessionID=%s, UserID=%s", session.Id, session.UserId)
  121. utils.FileLog.Warn("连接超时关闭: SessionID=%s, UserID=%s", session.Id, session.UserId)
  122. session.Close()
  123. return true
  124. }
  125. // 发送心跳
  126. go func(s *Session) {
  127. err := s.Conn.WriteControl(websocket.PingMessage,
  128. nil, time.Now().Add(5*time.Second))
  129. if err != nil {
  130. fmt.Printf("心跳发送失败: SessionID=%s, Error=%v", s.Id, err)
  131. utils.FileLog.Warn("心跳发送失败: SessionID=%s, Error=%v",
  132. s.Id, err)
  133. session.Close()
  134. }
  135. }(session)
  136. fmt.Println("当前连接数:", n)
  137. return true
  138. })
  139. fmt.Println("当前连接数:", n)
  140. }
  141. // Start 启动心跳检测
  142. func (manager *ConnectionManager) Start() {
  143. defer manager.ticker.Stop()
  144. for {
  145. select {
  146. case <-manager.ticker.C:
  147. fmt.Printf("开始检测连接超时")
  148. manager.CheckAll()
  149. case <-manager.stopChan:
  150. fmt.Printf("退出检测")
  151. return
  152. }
  153. }
  154. }
  155. // Stop 停止心跳检测
  156. func (manager *ConnectionManager) Stop() {
  157. close(manager.stopChan)
  158. }