kobe6258 2 months ago
parent
commit
0700c11051
3 changed files with 33 additions and 31 deletions
  1. 1 1
      controllers/rag/chat_controller.go
  2. 6 4
      utils/ws/session.go
  3. 26 26
      utils/ws/session_manager.go

+ 1 - 1
controllers/rag/chat_controller.go

@@ -66,6 +66,6 @@ func webSocketHandler(w http.ResponseWriter, r *http.Request) (conn *websocket.C
 		_ = tcpConn.SetKeepAlivePeriod(90 * time.Second)
 		utils.FileLog.Info("TCP KeepAlive 已启用")
 	}
-	_ = conn.SetReadDeadline(time.Now().Add(time.Second * 60))
+	_ = conn.SetReadDeadline(time.Now().Add(ws.ReadTimeout))
 	return
 }

+ 6 - 4
utils/ws/session.go

@@ -74,8 +74,7 @@ func (s *Session) writeWithTimeout(msg *Message) error {
 		return errors.New("connection closed")
 	}
 	// 设置写超时
-	deadline := time.Now().Add(5 * time.Second)
-	if err := s.Conn.SetWriteDeadline(deadline); err != nil {
+	if err := s.Conn.SetWriteDeadline(time.Now().Add(writeWaitTimeout)); err != nil {
 		return err
 	}
 	return s.Conn.WriteMessage(websocket.TextMessage, msg.message)
@@ -84,7 +83,10 @@ func (s *Session) writeWithTimeout(msg *Message) error {
 // writePump 处理写操作
 func (s *Session) writePump() {
 	ticker := time.NewTicker(basePingInterval)
-	defer ticker.Stop()
+	defer func() {
+		manager.RemoveSession(s.Id)
+		ticker.Stop()
+	}()
 	for {
 		select {
 		case message, ok := <-s.MessageChan:
@@ -96,7 +98,6 @@ func (s *Session) writePump() {
 			_ = s.Latency.SendPing(s.Conn)
 			ticker.Reset(s.Latency.lastLatency)
 		case <-s.CloseChan:
-			manager.RemoveSession(s.Id)
 			return
 		}
 	}
@@ -139,6 +140,7 @@ func NewSession(userId int, sessionId string, conn *websocket.Conn) (session *Se
 		Id:          sessionId,
 		Conn:        conn,
 		History:     []string{},
+		LastActive:  time.Now(),
 		CloseChan:   make(chan struct{}),
 		MessageChan: make(chan *Message, 10),
 	}

+ 26 - 26
utils/ws/session_manager.go

@@ -10,8 +10,10 @@ import (
 )
 
 const (
-	defaultCheckInterval = 20 * time.Second // 检测间隔应小于心跳超时时间
-	connectionTimeout    = 60 * time.Second // 客户端超时时间
+	defaultCheckInterval = 5 * time.Second  // 检测间隔应小于心跳超时时间
+	connectionTimeout    = 20 * time.Second // 客户端超时时间
+	ReadTimeout    = 10 * time.Second // 客户端超时时间
+	writeWaitTimeout = 5 * time.Second
 )
 
 type ConnectionManager struct {
@@ -83,36 +85,32 @@ func (manager *ConnectionManager) GetSession(sessionCode string) (session *Sessi
 
 // CheckAll 批量检测所有连接
 func (manager *ConnectionManager) CheckAll() {
+	n := 0
 	manager.Sessions.Range(func(key, value interface{}) bool {
-		fmt.Printf("检测接接:%s", key)
+		n++
 		session := value.(*Session)
-		if session.mu.TryRLock() {
-			defer session.mu.Unlock()
-			// 判断超时
-			if time.Since(session.LastActive) > connectionTimeout {
-				fmt.Printf("连接超时关闭: SessionID=%s, UserID=%s", session.Id, session.UserId)
-				utils.FileLog.Warn("连接超时关闭: SessionID=%s, UserID=%s", session.Id, session.UserId)
-				session.Close()
-				manager.RemoveSession(session.Id)
-				return true
-			}
-			// 发送心跳
-			go func(s *Session) {
-				err := s.Conn.WriteControl(websocket.PingMessage,
-					nil, time.Now().Add(5*time.Second))
-				fmt.Printf(s.Id)
-				if err != nil {
-					fmt.Printf("心跳发送失败: SessionID=%s, Error=%v", s.Id, err)
-					utils.FileLog.Warn("心跳发送失败: SessionID=%s, Error=%v",
-						s.Id, err)
-					session.Close()
-					manager.RemoveSession(s.Id)
-				}
-			}(session)
+		// 判断超时
+		if time.Since(session.LastActive) > connectionTimeout {
+			fmt.Printf("连接超时关闭: SessionID=%s, UserID=%s", session.Id, session.UserId)
+			utils.FileLog.Warn("连接超时关闭: SessionID=%s, UserID=%s", session.Id, session.UserId)
+			session.Close()
 			return true
 		}
+		// 发送心跳
+		go func(s *Session) {
+			err := s.Conn.WriteControl(websocket.PingMessage,
+				nil, time.Now().Add(5*time.Second))
+			if err != nil {
+				fmt.Printf("心跳发送失败: SessionID=%s, Error=%v", s.Id, err)
+				utils.FileLog.Warn("心跳发送失败: SessionID=%s, Error=%v",
+					s.Id, err)
+				session.Close()
+			}
+		}(session)
+		fmt.Println("当前连接数:", n)
 		return true
 	})
+	fmt.Println("当前连接数:", n)
 }
 
 // Start 启动心跳检测
@@ -121,8 +119,10 @@ func (manager *ConnectionManager) Start() {
 	for {
 		select {
 		case <-manager.ticker.C:
+			fmt.Printf("开始检测连接超时")
 			manager.CheckAll()
 		case <-manager.stopChan:
+			fmt.Printf("退出检测")
 			return
 		}
 	}