session.go 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163
  1. package ws
  2. import (
  3. "encoding/json"
  4. "errors"
  5. "eta/eta_api/utils"
  6. "fmt"
  7. "github.com/gorilla/websocket"
  8. "sync"
  9. "time"
  10. )
  11. // Session 会话结构
  12. type Session struct {
  13. Id string
  14. UserId int
  15. Conn *websocket.Conn
  16. LastActive time.Time
  17. Latency *LatencyMeasurer
  18. History []json.RawMessage
  19. CloseChan chan struct{}
  20. MessageChan chan string
  21. mu sync.RWMutex
  22. sessionOnce sync.Once
  23. }
  // Message is the JSON payload a client sends over the websocket.
// The field names are used verbatim as JSON keys.
type Message struct {
	KbName string `json:"KbName"` // knowledge-base name
	Query  string `json:"Query"`  // user query text
	ChatId int    `json:"ChatId"` // chat identifier
	// LastTopics []json.RawMessage `json:"LastTopics"` // disabled in the original
}
  30. // readPump 处理读操作
  31. func (s *Session) readPump() {
  32. defer func() {
  33. fmt.Printf("读进程session %s closed", s.Id)
  34. manager.RemoveSession(s.Id)
  35. }()
  36. s.Conn.SetReadLimit(maxMessageSize)
  37. _ = s.Conn.SetReadDeadline(time.Now().Add(ReadTimeout))
  38. for {
  39. _, message, err := s.Conn.ReadMessage()
  40. if err != nil {
  41. fmt.Printf("websocket 错误关闭 %s closed", err.Error())
  42. handleCloseError(err)
  43. return
  44. }
  45. // 更新活跃时间
  46. s.UpdateActivity()
  47. // 处理消息
  48. if err = manager.HandleMessage(s.UserId, s.Id, message); err != nil {
  49. //写应答
  50. _ = s.writeWithTimeout("<think></think>")
  51. _ = s.writeWithTimeout(err.Error())
  52. _ = s.writeWithTimeout("<EOF/>")
  53. }
  54. }
  55. }
  56. // UpdateActivity 跟新最近活跃时间
  57. func (s *Session) UpdateActivity() {
  58. s.mu.Lock()
  59. defer s.mu.Unlock()
  60. s.LastActive = time.Now()
  61. }
  62. func (s *Session) Close() {
  63. s.sessionOnce.Do(func() {
  64. // 控制关闭顺序
  65. close(s.CloseChan)
  66. close(s.MessageChan)
  67. s.forceClose()
  68. })
  69. }
  70. // 带超时的安全写入
  71. func (s *Session) writeWithTimeout(msg string) error {
  72. s.mu.Lock()
  73. defer s.mu.Unlock()
  74. if s.Conn == nil {
  75. utils.FileLog.Error("写入消息失败,connection已关闭")
  76. return errors.New("connection closed")
  77. }
  78. // 设置写超时
  79. if err := s.Conn.SetWriteDeadline(time.Now().Add(writeWaitTimeout)); err != nil {
  80. return err
  81. }
  82. return s.Conn.WriteMessage(websocket.TextMessage, []byte(msg))
  83. }
  84. // writePump 处理写操作
  85. func (s *Session) writePump() {
  86. ticker := time.NewTicker(basePingInterval)
  87. defer func() {
  88. fmt.Printf("写继进程:session %s closed", s.Id)
  89. manager.RemoveSession(s.Id)
  90. ticker.Stop()
  91. }()
  92. for {
  93. select {
  94. case message, ok := <-s.MessageChan:
  95. if !ok {
  96. return
  97. }
  98. _ = s.writeWithTimeout(message)
  99. case <-ticker.C:
  100. _ = s.Latency.SendPing(s.Conn)
  101. ticker.Reset(s.Latency.lastLatency)
  102. case <-s.CloseChan:
  103. return
  104. }
  105. }
  106. }
  107. func handleCloseError(err error) {
  108. utils.FileLog.Error("websocket错误关闭 %s closed", err.Error())
  109. if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway) {
  110. var wsErr *websocket.CloseError
  111. if !errors.As(err, &wsErr) {
  112. fmt.Printf("websocket未知错误 %s", err.Error())
  113. utils.FileLog.Error("未知错误 %s", err.Error())
  114. } else {
  115. switch wsErr.Code {
  116. case websocket.CloseNormalClosure:
  117. fmt.Println("websocket正常关闭连接")
  118. utils.FileLog.Info("正常关闭连接")
  119. default:
  120. fmt.Printf("websocket关闭代码 %d:%s", wsErr.Code, wsErr.Text)
  121. utils.FileLog.Error(":%d:%s", wsErr.Code, wsErr.Text)
  122. }
  123. }
  124. }
  125. }
  126. // 强制关闭连接
  127. func (s *Session) forceClose() {
  128. // 添加互斥锁保护
  129. s.mu.Lock()
  130. defer s.mu.Unlock()
  131. // 发送关闭帧
  132. _ = s.Conn.WriteControl(websocket.CloseMessage,
  133. websocket.FormatCloseMessage(websocket.ClosePolicyViolation, "heartbeat failed"),
  134. time.Now().Add(writeWaitTimeout))
  135. _ = s.Conn.Close()
  136. s.Conn = nil // 标记连接已关闭
  137. utils.FileLog.Info("连接已强制关闭",
  138. "user", s.UserId,
  139. "session", s.Id)
  140. }
  141. func NewSession(userId int, sessionId string, conn *websocket.Conn) (session *Session) {
  142. session = &Session{
  143. UserId: userId,
  144. Id: sessionId,
  145. Conn: conn,
  146. LastActive: time.Now(),
  147. CloseChan: make(chan struct{}),
  148. MessageChan: make(chan string, 10),
  149. }
  150. session.Latency = SetupLatencyMeasurement(conn)
  151. go session.readPump()
  152. go session.writePump()
  153. return
  154. }