package ws import ( "errors" "eta/eta_api/utils" "github.com/gorilla/websocket" "sync" "time" ) // Session 会话结构 type Session struct { Id string UserId int Conn *websocket.Conn LastActive time.Time Latency *LatencyMeasurer History []string CloseChan chan struct{} MessageChan chan *Message mu sync.RWMutex sessionOnce sync.Once } type Message struct { MessageType string message []byte } // readPump 处理读操作 func (s *Session) readPump() { defer manager.RemoveSession(s.Id) s.Conn.SetReadLimit(maxMessageSize) _ = s.Conn.SetReadDeadline(time.Now().Add(60 * time.Second)) for { _, message, err := s.Conn.ReadMessage() if err != nil { handleCloseError(err) return } // 更新活跃时间 s.UpdateActivity() // 处理消息 if err = manager.HandleMessage(s.UserId, s.Id, message); err != nil { //写应答 _ = s.writeWithTimeout(&Message{ MessageType: "error", message: []byte(err.Error()), }) } } } // UpdateActivity 跟新最近活跃时间 func (s *Session) UpdateActivity() { s.mu.Lock() defer s.mu.Unlock() s.LastActive = time.Now() } func (s *Session) Close() { s.sessionOnce.Do(func() { // 控制关闭顺序 close(s.CloseChan) close(s.MessageChan) s.forceClose() }) } // 带超时的安全写入 func (s *Session) writeWithTimeout(msg *Message) error { s.mu.Lock() defer s.mu.Unlock() if s.Conn == nil { return errors.New("connection closed") } // 设置写超时 if err := s.Conn.SetWriteDeadline(time.Now().Add(writeWaitTimeout)); err != nil { return err } return s.Conn.WriteMessage(websocket.TextMessage, msg.message) } // writePump 处理写操作 func (s *Session) writePump() { ticker := time.NewTicker(basePingInterval) defer func() { manager.RemoveSession(s.Id) ticker.Stop() }() for { select { case message, ok := <-s.MessageChan: if !ok { return } _ = s.writeWithTimeout(message) case <-ticker.C: _ = s.Latency.SendPing(s.Conn) ticker.Reset(s.Latency.lastLatency) case <-s.CloseChan: return } } } func handleCloseError(err error) { if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway) { var wsErr *websocket.CloseError if !errors.As(err, &wsErr) { utils.FileLog.Error("未知错误 %s", err.Error()) } else { switch wsErr.Code { case websocket.CloseNormalClosure: utils.FileLog.Info("正常关闭连接") default: utils.FileLog.Error("关闭代码:%d:%s", wsErr.Code, wsErr.Text) } } } } // 强制关闭连接 func (s *Session) forceClose() { // 添加互斥锁保护 s.mu.Lock() defer s.mu.Unlock() // 发送关闭帧 _ = s.Conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.ClosePolicyViolation, "heartbeat failed"), time.Now().Add(5*time.Second)) _ = s.Conn.Close() s.Conn = nil // 标记连接已关闭 utils.FileLog.Info("连接已强制关闭", "user", s.UserId, "session", s.Id) } func NewSession(userId int, sessionId string, conn *websocket.Conn) (session *Session) { session = &Session{ UserId: userId, Id: sessionId, Conn: conn, History: []string{}, LastActive: time.Now(), CloseChan: make(chan struct{}), MessageChan: make(chan *Message, 10), } session.Latency = SetupLatencyMeasurement(conn) go session.readPump() go session.writePump() return }