package ws

import (
	"errors"
	"fmt"
	"sync"
	"time"

	"eta/eta_api/utils"
	"github.com/gorilla/websocket"
)

const (
	// defaultCheckInterval is how often Start runs CheckAll. It should be
	// shorter than connectionTimeout so idle sessions are noticed promptly.
	defaultCheckInterval = 5 * time.Second
	// connectionTimeout is how long a session may go without its LastActive
	// being refreshed before CheckAll closes it.
	connectionTimeout = 20 * time.Second
	// ReadTimeout is the client read timeout.
	// NOTE(review): not referenced in this file; presumably used by the
	// connection reader elsewhere in the package — confirm.
	ReadTimeout = 10 * time.Second
	// writeWaitTimeout bounds how long a heartbeat ping write may block.
	writeWaitTimeout = 5 * time.Second
)

// ConnectionManager tracks live sessions (keyed by Session.Id) and runs a
// periodic heartbeat / idle-timeout sweep over them (see Start).
type ConnectionManager struct {
	// Sessions maps Session.Id to *Session.
	Sessions sync.Map
	ticker   *time.Ticker
	stopChan chan struct{}
	// stopOnce makes Stop safe to call more than once.
	stopOnce sync.Once
}

var (
	smOnce  sync.Once
	manager *ConnectionManager
)

// GetInstance returns the process-wide ConnectionManager, creating it on the
// first call. It does not start the heartbeat loop; call Start for that.
func GetInstance() *ConnectionManager {
	smOnce.Do(func() {
		if manager == nil {
			manager = &ConnectionManager{
				ticker:   time.NewTicker(defaultCheckInterval),
				stopChan: make(chan struct{}),
			}
		}
	})
	return manager
}

// Manager returns the singleton without creating it; the result is nil until
// GetInstance has been called.
func Manager() *ConnectionManager {
	return manager
}

// HandleMessage processes one message sent by userID on session sessionID.
// It applies the per-user QA_LIMITER check, appends the message to the
// session history, refreshes LastActive and replies with "Processed: <msg>".
//
// It returns an error when the user is throttled, when sessionID is unknown,
// or when writing the reply fails.
func (m *ConnectionManager) HandleMessage(userID int, sessionID string, message []byte) error {
	if !Allow(userID, QA_LIMITER) {
		return errors.New("您提问的太频繁了,请稍后再试")
	}
	session, exists := m.GetSession(sessionID)
	if !exists {
		return errors.New("session not found")
	}

	msg := string(message)
	session.History = append(session.History, msg)
	session.LastActive = time.Now()

	// NOTE(review): gorilla/websocket permits only one concurrent writer per
	// connection (WriteControl excepted); callers must not run other writes
	// on session.Conn concurrently with this one — confirm.
	return session.Conn.WriteMessage(websocket.TextMessage, []byte("Processed: "+msg))
}

// AddSession registers session under its Id, replacing any existing entry.
func (m *ConnectionManager) AddSession(session *Session) {
	m.Sessions.Store(session.Id, session)
}

// GetSessionId builds the composite session key "<userId>_<sessionId>".
func (m *ConnectionManager) GetSessionId(userId int, sessionId string) (sessionID string) {
	return fmt.Sprintf("%d_%s", userId, sessionId)
}

// RemoveSession drops the session stored under sessionCode, if any. It does
// not close the session's connection.
func (m *ConnectionManager) RemoveSession(sessionCode string) {
	m.Sessions.Delete(sessionCode)
}

// GetSession looks up the session stored under sessionCode. exists is false
// (and session nil) when no *Session is stored under that key.
func (m *ConnectionManager) GetSession(sessionCode string) (session *Session, exists bool) {
	data, ok := m.Sessions.Load(sessionCode)
	if !ok {
		return nil, false
	}
	session, exists = data.(*Session)
	return session, exists
}

// CheckAll sweeps every session once. Sessions whose LastActive is older than
// connectionTimeout are closed. All others are sent a ping asynchronously and
// closed if the ping cannot be written within writeWaitTimeout. The number of
// sessions visited is printed at the end.
//
// NOTE(review): LastActive is only refreshed by HandleMessage in this file;
// unless a pong handler elsewhere also updates it, quiet-but-alive clients
// are closed after connectionTimeout — confirm this is intended.
func (m *ConnectionManager) CheckAll() {
	n := 0
	m.Sessions.Range(func(_, value interface{}) bool {
		n++
		session := value.(*Session)

		if time.Since(session.LastActive) > connectionTimeout {
			// %v: UserId is used as an int elsewhere, so %s would misrender it.
			fmt.Printf("连接超时关闭: SessionID=%s, UserID=%v\n", session.Id, session.UserId)
			utils.FileLog.Warn("连接超时关闭: SessionID=%s, UserID=%v", session.Id, session.UserId)
			session.Close()
			return true
		}

		// Ping in its own goroutine so a slow client cannot stall the sweep.
		go func(s *Session) {
			err := s.Conn.WriteControl(websocket.PingMessage, nil, time.Now().Add(writeWaitTimeout))
			if err != nil {
				fmt.Printf("心跳发送失败: SessionID=%s, Error=%v\n", s.Id, err)
				utils.FileLog.Warn("心跳发送失败: SessionID=%s, Error=%v", s.Id, err)
				s.Close()
			}
		}(session)
		return true
	})
	fmt.Println("当前连接数:", n)
}

// Start runs the heartbeat loop: it calls CheckAll on every tick of the
// manager's ticker until Stop is called. It blocks, so run it in its own
// goroutine. The ticker is stopped when Start returns.
func (m *ConnectionManager) Start() {
	defer m.ticker.Stop()
	for {
		select {
		case <-m.ticker.C:
			fmt.Println("开始检测连接超时")
			m.CheckAll()
		case <-m.stopChan:
			fmt.Println("退出检测")
			return
		}
	}
}

// Stop signals Start to return and stops the ticker. It is safe to call more
// than once, and safe to call even if Start was never run.
func (m *ConnectionManager) Stop() {
	m.stopOnce.Do(func() {
		m.ticker.Stop()
		close(m.stopChan)
	})
}