package ws import ( "encoding/json" "errors" "eta/eta_api/utils" "fmt" "github.com/gorilla/websocket" "sync" "time" ) // Session 会话结构 type Session struct { Id string UserId int Conn *websocket.Conn LastActive time.Time Latency *LatencyMeasurer History []json.RawMessage CloseChan chan struct{} MessageChan chan string mu sync.RWMutex sessionOnce sync.Once } type Message struct { KbName string `json:"KbName"` Query string `json:"Query"` ChatId int `json:"ChatId"` //LastTopics []json.RawMessage `json:"LastTopics"` } // readPump 处理读操作 func (s *Session) readPump() { defer func() { fmt.Printf("读进程session %s closed", s.Id) manager.RemoveSession(s.Id) }() s.Conn.SetReadLimit(maxMessageSize) _ = s.Conn.SetReadDeadline(time.Now().Add(ReadTimeout)) for { _, message, err := s.Conn.ReadMessage() if err != nil { fmt.Printf("websocket 错误关闭 %s closed", err.Error()) handleCloseError(err) return } // 更新活跃时间 s.UpdateActivity() // 处理消息 if err = manager.HandleMessage(s.UserId, s.Id, message); err != nil { //写应答 _ = s.writeWithTimeout("") _ = s.writeWithTimeout(err.Error()) _ = s.writeWithTimeout("") } } } // UpdateActivity 跟新最近活跃时间 func (s *Session) UpdateActivity() { s.mu.Lock() defer s.mu.Unlock() s.LastActive = time.Now() } func (s *Session) Close() { s.sessionOnce.Do(func() { // 控制关闭顺序 close(s.CloseChan) close(s.MessageChan) s.forceClose() }) } // 带超时的安全写入 func (s *Session) writeWithTimeout(msg string) error { s.mu.Lock() defer s.mu.Unlock() if s.Conn == nil { utils.FileLog.Error("写入消息失败,connection已关闭") return errors.New("connection closed") } // 设置写超时 if err := s.Conn.SetWriteDeadline(time.Now().Add(writeWaitTimeout)); err != nil { return err } return s.Conn.WriteMessage(websocket.TextMessage, []byte(msg)) } // writePump 处理写操作 func (s *Session) writePump() { ticker := time.NewTicker(basePingInterval) defer func() { fmt.Printf("写继进程:session %s closed", s.Id) manager.RemoveSession(s.Id) ticker.Stop() }() for { select { case message, ok := <-s.MessageChan: if !ok { return } _ = s.writeWithTimeout(message) case <-ticker.C: _ = s.Latency.SendPing(s.Conn) ticker.Reset(s.Latency.lastLatency) case <-s.CloseChan: return } } } func handleCloseError(err error) { utils.FileLog.Error("websocket错误关闭 %s closed", err.Error()) if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway) { var wsErr *websocket.CloseError if !errors.As(err, &wsErr) { fmt.Printf("websocket未知错误 %s", err.Error()) utils.FileLog.Error("未知错误 %s", err.Error()) } else { switch wsErr.Code { case websocket.CloseNormalClosure: fmt.Println("websocket正常关闭连接") utils.FileLog.Info("正常关闭连接") default: fmt.Printf("websocket关闭代码 %d:%s", wsErr.Code, wsErr.Text) utils.FileLog.Error(":%d:%s", wsErr.Code, wsErr.Text) } } } } // 强制关闭连接 func (s *Session) forceClose() { // 添加互斥锁保护 s.mu.Lock() defer s.mu.Unlock() // 发送关闭帧 _ = s.Conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.ClosePolicyViolation, "heartbeat failed"), time.Now().Add(writeWaitTimeout)) _ = s.Conn.Close() s.Conn = nil // 标记连接已关闭 utils.FileLog.Info("连接已强制关闭", "user", s.UserId, "session", s.Id) } func NewSession(userId int, sessionId string, conn *websocket.Conn) (session *Session) { session = &Session{ UserId: userId, Id: sessionId, Conn: conn, LastActive: time.Now(), CloseChan: make(chan struct{}), MessageChan: make(chan string, 10), } session.Latency = SetupLatencyMeasurement(conn) go session.readPump() go session.writePump() return }