session.go 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168
  1. package ws
  2. import (
  3. "encoding/json"
  4. "errors"
  5. "eta/eta_api/utils"
  6. "fmt"
  7. "github.com/gorilla/websocket"
  8. "sync"
  9. "time"
  10. )
  // Session holds the state of one websocket connection for one user.
// NewSession creates it and starts its readPump and writePump goroutines.
type Session struct {
	Id           string            // session identifier; passed to manager.RemoveSession when a pump exits
	UserId       int               // owning user; passed to manager.HandleMessage with each inbound message
	Conn         *websocket.Conn   // underlying socket; set to nil by forceClose once it has been closed
	LastActive   time.Time         // refreshed by UpdateActivity (readPump calls it on every inbound message)
	Latency      *LatencyMeasurer  // created in NewSession; writePump uses it to send pings and pick the next ping period
	History      []json.RawMessage // not used in this file; presumably prior chat turns (TODO confirm)
	CloseChan    chan struct{}     // closed exactly once by Close; makes writePump return
	MessageChan  chan string       // outbound text messages drained by writePump (buffered, size 10 in NewSession)
	mu           sync.RWMutex      // guards Conn and LastActive, and serializes writes in writeWithTimeout
	sessionOnce  sync.Once         // ensures Close runs its teardown only once
	CloseLlmChan *chan bool        // not used in this file; presumably signals an in-flight LLM request to stop (TODO confirm)
	LLMStatus    int8              // LLM query state: 0 = not asked, 1 = asking in progress, -1 = paused
}
  // Message describes a JSON payload with KbName, Query, ChatId and MessageType keys.
// It is not referenced elsewhere in this file; presumably inbound frames are
// decoded into it by manager.HandleMessage (TODO confirm).
type Message struct {
	KbName      string `json:"KbName"`      // presumably the knowledge-base name (TODO confirm)
	Query       string `json:"Query"`       // presumably the user's question text (TODO confirm)
	ChatId      int    `json:"ChatId"`      // chat identifier
	MessageType string `json:"MessageType"` // message kind; accepted values are not defined in this file
	// Disabled field, kept for reference:
	//LastTopics []json.RawMessage `json:"LastTopics"`
}
  33. // readPump 处理读操作
  34. func (s *Session) readPump() {
  35. defer func() {
  36. fmt.Printf("读进程session %s closed", s.Id)
  37. manager.RemoveSession(s.Id)
  38. }()
  39. s.Conn.SetReadLimit(maxMessageSize)
  40. _ = s.Conn.SetReadDeadline(time.Now().Add(ReadTimeout))
  41. for {
  42. _, message, err := s.Conn.ReadMessage()
  43. if err != nil {
  44. fmt.Printf("websocket 错误关闭 %s closed", err.Error())
  45. handleCloseError(err)
  46. return
  47. }
  48. // 更新活跃时间
  49. s.UpdateActivity()
  50. // 处理消息
  51. //if err = manager.HandleMessage(s.UserId, s.Id, message); err != nil {
  52. // //写应答
  53. // _ = s.writeWithTimeout("<think></think>")
  54. // _ = s.writeWithTimeout(err.Error())
  55. // _ = s.writeWithTimeout("<EOF/>")
  56. //}
  57. go manager.HandleMessage(s.UserId, s.Id, message)
  58. }
  59. }
  60. // UpdateActivity 跟新最近活跃时间
  61. func (s *Session) UpdateActivity() {
  62. s.mu.Lock()
  63. defer s.mu.Unlock()
  64. s.LastActive = time.Now()
  65. }
  66. func (s *Session) Close() {
  67. s.sessionOnce.Do(func() {
  68. // 控制关闭顺序
  69. close(s.CloseChan)
  70. close(s.MessageChan)
  71. s.forceClose()
  72. })
  73. }
  74. // 带超时的安全写入
  75. func (s *Session) writeWithTimeout(msg string) error {
  76. s.mu.Lock()
  77. defer s.mu.Unlock()
  78. if s.Conn == nil {
  79. utils.FileLog.Error("写入消息失败,connection已关闭")
  80. return errors.New("connection closed")
  81. }
  82. // 设置写超时
  83. if err := s.Conn.SetWriteDeadline(time.Now().Add(writeWaitTimeout)); err != nil {
  84. return err
  85. }
  86. return s.Conn.WriteMessage(websocket.TextMessage, []byte(msg))
  87. }
  88. // writePump 处理写操作
  89. func (s *Session) writePump() {
  90. ticker := time.NewTicker(basePingInterval)
  91. defer func() {
  92. fmt.Printf("写继进程:session %s closed", s.Id)
  93. manager.RemoveSession(s.Id)
  94. ticker.Stop()
  95. }()
  96. for {
  97. select {
  98. case message, ok := <-s.MessageChan:
  99. if !ok {
  100. return
  101. }
  102. _ = s.writeWithTimeout(message)
  103. case <-ticker.C:
  104. _ = s.Latency.SendPing(s.Conn)
  105. ticker.Reset(s.Latency.lastLatency)
  106. case <-s.CloseChan:
  107. return
  108. }
  109. }
  110. }
  111. func handleCloseError(err error) {
  112. utils.FileLog.Error("websocket错误关闭 %s closed", err.Error())
  113. if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway) {
  114. var wsErr *websocket.CloseError
  115. if !errors.As(err, &wsErr) {
  116. fmt.Printf("websocket未知错误 %s", err.Error())
  117. utils.FileLog.Error("未知错误 %s", err.Error())
  118. } else {
  119. switch wsErr.Code {
  120. case websocket.CloseNormalClosure:
  121. fmt.Println("websocket正常关闭连接")
  122. utils.FileLog.Info("正常关闭连接")
  123. default:
  124. fmt.Printf("websocket关闭代码 %d:%s", wsErr.Code, wsErr.Text)
  125. utils.FileLog.Error(":%d:%s", wsErr.Code, wsErr.Text)
  126. }
  127. }
  128. }
  129. }
  130. // 强制关闭连接
  131. func (s *Session) forceClose() {
  132. // 添加互斥锁保护
  133. s.mu.Lock()
  134. defer s.mu.Unlock()
  135. // 发送关闭帧
  136. _ = s.Conn.WriteControl(websocket.CloseMessage,
  137. websocket.FormatCloseMessage(websocket.ClosePolicyViolation, "heartbeat failed"),
  138. time.Now().Add(writeWaitTimeout))
  139. _ = s.Conn.Close()
  140. s.Conn = nil // 标记连接已关闭
  141. utils.FileLog.Info("连接已强制关闭",
  142. "user", s.UserId,
  143. "session", s.Id)
  144. }
  145. func NewSession(userId int, sessionId string, conn *websocket.Conn) (session *Session) {
  146. session = &Session{
  147. UserId: userId,
  148. Id: sessionId,
  149. Conn: conn,
  150. LastActive: time.Now(),
  151. CloseChan: make(chan struct{}),
  152. MessageChan: make(chan string, 10),
  153. }
  154. session.Latency = SetupLatencyMeasurement(conn)
  155. go session.readPump()
  156. go session.writePump()
  157. return
  158. }