session.go 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160
  1. package ws
  2. import (
  3. "errors"
  4. "eta/eta_api/utils"
  5. "fmt"
  6. "github.com/gorilla/websocket"
  7. "sync"
  8. "time"
  9. )
  10. // Session 会话结构
  11. type Session struct {
  12. Id string
  13. UserId int
  14. Conn *websocket.Conn
  15. LastActive time.Time
  16. Latency *LatencyMeasurer
  17. History []string
  18. CloseChan chan struct{}
  19. MessageChan chan string
  20. mu sync.RWMutex
  21. sessionOnce sync.Once
  22. }
  // Message is a JSON payload carrying a knowledge-base name, a query and the
  // previously discussed topics.
  // NOTE(review): it is not referenced in this file; presumably it is the
  // client request decoded by the manager — confirm against the caller.
  type Message struct {
  	// KbName is serialised under the JSON key "KbName".
  	KbName string `json:"KbName"`
  	// Query is serialised under the JSON key "Query".
  	Query string `json:"Query"`
  	// LastTopics is serialised under the JSON key "LastTopics".
  	LastTopics []string `json:"LastTopics"`
  }
  28. // readPump 处理读操作
  29. func (s *Session) readPump() {
  30. defer func() {
  31. fmt.Printf("读进程session %s closed", s.Id)
  32. manager.RemoveSession(s.Id)
  33. }()
  34. s.Conn.SetReadLimit(maxMessageSize)
  35. _ = s.Conn.SetReadDeadline(time.Now().Add(ReadTimeout))
  36. for {
  37. _, message, err := s.Conn.ReadMessage()
  38. if err != nil {
  39. fmt.Printf("websocket 错误关闭 %s closed", err.Error())
  40. handleCloseError(err)
  41. return
  42. }
  43. // 更新活跃时间
  44. s.UpdateActivity()
  45. // 处理消息
  46. if err = manager.HandleMessage(s.UserId, s.Id, message); err != nil {
  47. //写应答
  48. _ = s.writeWithTimeout(err.Error())
  49. }
  50. }
  51. }
  52. // UpdateActivity 跟新最近活跃时间
  53. func (s *Session) UpdateActivity() {
  54. s.mu.Lock()
  55. defer s.mu.Unlock()
  56. s.LastActive = time.Now()
  57. }
  58. func (s *Session) Close() {
  59. s.sessionOnce.Do(func() {
  60. // 控制关闭顺序
  61. close(s.CloseChan)
  62. close(s.MessageChan)
  63. s.forceClose()
  64. })
  65. }
  66. // 带超时的安全写入
  67. func (s *Session) writeWithTimeout(msg string) error {
  68. s.mu.Lock()
  69. defer s.mu.Unlock()
  70. if s.Conn == nil {
  71. return errors.New("connection closed")
  72. }
  73. // 设置写超时
  74. if err := s.Conn.SetWriteDeadline(time.Now().Add(writeWaitTimeout)); err != nil {
  75. return err
  76. }
  77. return s.Conn.WriteMessage(websocket.TextMessage, []byte(msg))
  78. }
  79. // writePump 处理写操作
  80. func (s *Session) writePump() {
  81. ticker := time.NewTicker(basePingInterval)
  82. defer func() {
  83. fmt.Printf("写继进程:session %s closed", s.Id)
  84. manager.RemoveSession(s.Id)
  85. ticker.Stop()
  86. }()
  87. for {
  88. select {
  89. case message, ok := <-s.MessageChan:
  90. if !ok {
  91. return
  92. }
  93. _ = s.writeWithTimeout(message)
  94. case <-ticker.C:
  95. _ = s.Latency.SendPing(s.Conn)
  96. ticker.Reset(s.Latency.lastLatency)
  97. case <-s.CloseChan:
  98. return
  99. }
  100. }
  101. }
  102. func handleCloseError(err error) {
  103. if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway) {
  104. var wsErr *websocket.CloseError
  105. if !errors.As(err, &wsErr) {
  106. fmt.Printf("websocket未知错误 %s", err.Error())
  107. utils.FileLog.Error("未知错误 %s", err.Error())
  108. } else {
  109. switch wsErr.Code {
  110. case websocket.CloseNormalClosure:
  111. fmt.Println("websocket正常关闭连接")
  112. utils.FileLog.Info("正常关闭连接")
  113. default:
  114. fmt.Printf("websocket关闭代码 %d:%s", wsErr.Code, wsErr.Text)
  115. utils.FileLog.Error("关闭代码:%d:%s", wsErr.Code, wsErr.Text)
  116. }
  117. }
  118. }
  119. }
  120. // 强制关闭连接
  121. func (s *Session) forceClose() {
  122. // 添加互斥锁保护
  123. s.mu.Lock()
  124. defer s.mu.Unlock()
  125. // 发送关闭帧
  126. _ = s.Conn.WriteControl(websocket.CloseMessage,
  127. websocket.FormatCloseMessage(websocket.ClosePolicyViolation, "heartbeat failed"),
  128. time.Now().Add(writeWaitTimeout))
  129. _ = s.Conn.Close()
  130. s.Conn = nil // 标记连接已关闭
  131. utils.FileLog.Info("连接已强制关闭",
  132. "user", s.UserId,
  133. "session", s.Id)
  134. }
  135. func NewSession(userId int, sessionId string, conn *websocket.Conn) (session *Session) {
  136. session = &Session{
  137. UserId: userId,
  138. Id: sessionId,
  139. Conn: conn,
  140. LastActive: time.Now(),
  141. CloseChan: make(chan struct{}),
  142. MessageChan: make(chan string, 10),
  143. }
  144. session.Latency = SetupLatencyMeasurement(conn)
  145. go session.readPump()
  146. go session.writePump()
  147. return
  148. }