session.go 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151
  1. package ws
  2. import (
  3. "encoding/json"
  4. "errors"
  5. "eta/eta_api/utils"
  6. "github.com/gorilla/websocket"
  7. "sync"
  8. "time"
  9. )
  10. // Session 会话结构
  11. type Session struct {
  12. Id string
  13. UserId int
  14. Conn *websocket.Conn
  15. LastActive time.Time
  16. Latency *LatencyMeasurer
  17. History []json.RawMessage
  18. CloseChan chan struct{}
  19. MessageChan chan *Message
  20. mu sync.RWMutex
  21. sessionOnce sync.Once
  22. }
  // Message is an outbound payload queued on Session.MessageChan or written
  // directly via writeWithTimeout.
  type Message struct {
  	// MessageType labels the payload (readPump uses "error"). It is not
  	// consulted when writing: writeWithTimeout always sends a text frame.
  	MessageType string
  	// message is the raw payload bytes written to the connection.
  	message []byte
  }
  27. // readPump 处理读操作
  28. func (s *Session) readPump() {
  29. defer manager.RemoveSession(s.Id)
  30. s.Conn.SetReadLimit(maxMessageSize)
  31. _ = s.Conn.SetReadDeadline(time.Now().Add(60 * time.Second))
  32. for {
  33. _, message, err := s.Conn.ReadMessage()
  34. if err != nil {
  35. handleCloseError(err)
  36. return
  37. }
  38. // 更新活跃时间
  39. s.UpdateActivity()
  40. // 处理消息
  41. if err = manager.HandleMessage(s.UserId, s.Id, message); err != nil {
  42. //写应答
  43. _ = s.writeWithTimeout(&Message{
  44. MessageType: "error",
  45. message: []byte(err.Error()),
  46. })
  47. }
  48. }
  49. }
  50. // UpdateActivity 跟新最近活跃时间
  51. func (s *Session) UpdateActivity() {
  52. s.mu.Lock()
  53. defer s.mu.Unlock()
  54. s.LastActive = time.Now()
  55. }
  56. func (s *Session) Close() {
  57. s.sessionOnce.Do(func() {
  58. // 控制关闭顺序
  59. close(s.CloseChan)
  60. close(s.MessageChan)
  61. s.forceClose()
  62. })
  63. }
  64. // 带超时的安全写入
  65. func (s *Session) writeWithTimeout(msg *Message) error {
  66. s.mu.Lock()
  67. defer s.mu.Unlock()
  68. if s.Conn == nil {
  69. return errors.New("connection closed")
  70. }
  71. // 设置写超时
  72. if err := s.Conn.SetWriteDeadline(time.Now().Add(writeWaitTimeout)); err != nil {
  73. return err
  74. }
  75. return s.Conn.WriteMessage(websocket.TextMessage, msg.message)
  76. }
  77. // writePump 处理写操作
  78. func (s *Session) writePump() {
  79. ticker := time.NewTicker(basePingInterval)
  80. defer func() {
  81. manager.RemoveSession(s.Id)
  82. ticker.Stop()
  83. }()
  84. for {
  85. select {
  86. case message, ok := <-s.MessageChan:
  87. if !ok {
  88. return
  89. }
  90. _ = s.writeWithTimeout(message)
  91. case <-ticker.C:
  92. _ = s.Latency.SendPing(s.Conn)
  93. ticker.Reset(s.Latency.lastLatency)
  94. case <-s.CloseChan:
  95. return
  96. }
  97. }
  98. }
  99. func handleCloseError(err error) {
  100. if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway) {
  101. var wsErr *websocket.CloseError
  102. if !errors.As(err, &wsErr) {
  103. utils.FileLog.Error("未知错误 %s", err.Error())
  104. } else {
  105. switch wsErr.Code {
  106. case websocket.CloseNormalClosure:
  107. utils.FileLog.Info("正常关闭连接")
  108. default:
  109. utils.FileLog.Error("关闭代码:%d:%s", wsErr.Code, wsErr.Text)
  110. }
  111. }
  112. }
  113. }
  114. // 强制关闭连接
  115. func (s *Session) forceClose() {
  116. // 添加互斥锁保护
  117. s.mu.Lock()
  118. defer s.mu.Unlock()
  119. // 发送关闭帧
  120. _ = s.Conn.WriteControl(websocket.CloseMessage,
  121. websocket.FormatCloseMessage(websocket.ClosePolicyViolation, "heartbeat failed"),
  122. time.Now().Add(5*time.Second))
  123. _ = s.Conn.Close()
  124. s.Conn = nil // 标记连接已关闭
  125. utils.FileLog.Info("连接已强制关闭",
  126. "user", s.UserId,
  127. "session", s.Id)
  128. }
  129. func NewSession(userId int, sessionId string, conn *websocket.Conn) (session *Session) {
  130. session = &Session{
  131. UserId: userId,
  132. Id: sessionId,
  133. Conn: conn,
  134. LastActive: time.Now(),
  135. CloseChan: make(chan struct{}),
  136. MessageChan: make(chan *Message, 10),
  137. }
  138. session.Latency = SetupLatencyMeasurement(conn)
  139. go session.readPump()
  140. go session.writePump()
  141. return
  142. }