kobe6258 hace 1 semana
padre
commit
f146b49767
Se han modificado 6 ficheros con 129 adiciones y 146 borrados
  1. 1 0
      main.go
  2. 7 0
      services/ws_service.go
  3. 0 108
      utils/ws/heart_beat_manager.go
  4. 4 0
      utils/ws/latency_measurer.go
  5. 33 5
      utils/ws/session.go
  6. 84 33
      utils/ws/session_manager.go

+ 1 - 0
main.go

@@ -26,6 +26,7 @@ func main() {
 	web.Router("/", &controllers.BaseCommonController{})
 	go services.Task()
 
+	go services.StartSessionManager()
 	// 异常处理
 	web.ErrorController(&controllers.ErrorController{})
 

+ 7 - 0
services/ws_service.go

@@ -4,6 +4,7 @@ import (
 	"eta/eta_api/models"
 	"eta/eta_api/models/system"
 	"eta/eta_api/utils"
+	"eta/eta_api/utils/ws"
 	"fmt"
 	"github.com/beego/beego/v2/server/web"
 	"github.com/beego/beego/v2/server/web/context"
@@ -12,6 +13,8 @@ import (
 	"time"
 )
 
+var ()
+
 func WsAuthenticate() web.FilterFunc {
 	return func(ctx *context.Context) {
 		method := ctx.Input.Method()
@@ -134,3 +137,7 @@ func WsAuthenticate() web.FilterFunc {
 		}
 	}
 }
+
+func StartSessionManager() {
+	ws.GetInstance().Start()
+}

+ 0 - 108
utils/ws/heart_beat_manager.go

@@ -1,108 +0,0 @@
-package ws
-
-import (
-	"eta/eta_api/utils"
-	"github.com/gorilla/websocket"
-	"sync"
-	"time"
-)
-
-const (
-	defaultCheckInterval = 25 * time.Second // 检测间隔应小于心跳超时时间
-	connectionTimeout    = 60 * time.Second // 客户端超时时间
-)
-
-var (
-	hbManager *HeartbeatManager
-	hbOnce    sync.Once
-)
-
-// HeartbeatManager 心跳管理器
-type HeartbeatManager struct {
-	ticker   *time.Ticker
-	sessions sync.Map
-	stopChan chan struct{}
-}
-
-// 获取单例心跳管理器
-func GetHeartbeatManager() *HeartbeatManager {
-	hbOnce.Do(func() {
-		hbManager = &HeartbeatManager{
-			ticker:   time.NewTicker(defaultCheckInterval),
-			stopChan: make(chan struct{}),
-		}
-		go hbManager.Start()
-	})
-	return hbManager
-}
-
-// 启动心跳检测
-func (hm *HeartbeatManager) Start() {
-	defer hm.ticker.Stop()
-
-	for {
-		select {
-		case <-hm.ticker.C:
-			hm.CheckAll()
-		case <-hm.stopChan:
-			return
-		}
-	}
-}
-
-// 停止心跳检测
-func (hm *HeartbeatManager) Stop() {
-	close(hm.stopChan)
-}
-
-// 注册会话
-func (hm *HeartbeatManager) Register(session *Session) {
-	session.LastActive = time.Now()
-	session.CloseChan = make(chan struct{})
-	hm.sessions.Store(session.Id, session)
-}
-
-// 注销会话
-func (hm *HeartbeatManager) Unregister(sessionID string) {
-	if session, ok := hm.sessions.Load(sessionID); ok {
-		if session.(*Session).Conn != nil {
-			_ = session.(*Session).Conn.Close()
-			close(session.(*Session).CloseChan)
-		}
-		hm.sessions.Delete(sessionID)
-	}
-}
-
-// 批量检测所有连接
-func (hm *HeartbeatManager) CheckAll() {
-	hm.sessions.Range(func(key, value interface{}) bool {
-		session := value.(*Session)
-		// 判断超时
-		if time.Since(session.LastActive) > connectionTimeout {
-			utils.FileLog.Warn("连接超时关闭: SessionID=%s, UserID=%s", session.Id, session.UserId)
-			_ = session.Conn.Close()
-			hm.Unregister(session.Id)
-			return true
-		}
-		// 发送心跳
-		go func(s *Session) {
-			err := s.Conn.WriteControl(websocket.PingMessage,
-				nil, time.Now().Add(5*time.Second))
-			if err != nil {
-				utils.FileLog.Warn("心跳发送失败: SessionID=%s, Error=%v",
-					s.Id, err)
-				_ = s.Conn.Close()
-				hm.Unregister(s.Id)
-			}
-		}(session)
-
-		return true
-	})
-}
-
-// UpdateActivity 跟新最近活跃时间
-func (hm *HeartbeatManager) UpdateActivity(sessionID string) {
-	if session, ok := hm.sessions.Load(sessionID); ok {
-		session.(*Session).LastActive = time.Now()
-	}
-}

+ 4 - 0
utils/ws/latency_measurer.go

@@ -1,6 +1,7 @@
 package ws
 
 import (
+	"errors"
 	"github.com/gorilla/websocket"
 	"sync"
 	"time"
@@ -34,6 +35,9 @@ func NewLatencyMeasurer(windowSize int) *LatencyMeasurer {
 func (lm *LatencyMeasurer) SendPing(conn *websocket.Conn) error {
 	lm.mu.Lock()
 	defer lm.mu.Unlock()
+	if conn == nil {
+		return errors.New("connection closed")
+	}
 	// 发送Ping消息
 	err := conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(5*time.Second))
 	if err != nil {

+ 33 - 5
utils/ws/session.go

@@ -20,6 +20,7 @@ type Session struct {
 	CloseChan   chan struct{}
 	MessageChan chan *Message
 	mu          sync.RWMutex
+	sessionOnce sync.Once
 }
 type Message struct {
 	MessageType string
@@ -50,6 +51,30 @@ func (s *Session) readPump() {
 	}
 }
 
+func (s *Session) Close() {
+	s.sessionOnce.Do(func() {
+		// 控制关闭顺序
+		close(s.CloseChan)
+		close(s.MessageChan)
+		s.forceClose()
+	})
+}
+
+// 带超时的安全写入
+func (s *Session) writeWithTimeout(msg *Message) error {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	if s.Conn == nil {
+		return errors.New("connection closed")
+	}
+	// 设置写超时
+	deadline := time.Now().Add(5 * time.Second)
+	if err := s.Conn.SetWriteDeadline(deadline); err != nil {
+		return err
+	}
+	return s.Conn.WriteMessage(websocket.TextMessage, msg.message)
+}
+
 // writePump 处理写操作
 func (s *Session) writePump() {
 	fmt.Printf("用户写入数据")
@@ -61,7 +86,7 @@ func (s *Session) writePump() {
 			if !ok {
 				return
 			}
-			_ = s.Conn.WriteMessage(websocket.TextMessage, message.message)
+			_ = s.writeWithTimeout(message)
 		case <-ticker.C:
 			_ = s.Latency.SendPing(s.Conn)
 			ticker.Reset(s.Latency.lastLatency)
@@ -88,12 +113,15 @@ func handleCloseError(err error) {
 
 // 强制关闭连接
 func (s *Session) forceClose() {
-	close(s.CloseChan)
+	// 添加互斥锁保护
+	s.mu.Lock()
+	defer s.mu.Unlock()
 	// 发送关闭帧
-	s.Conn.WriteControl(websocket.CloseMessage,
+	_ = s.Conn.WriteControl(websocket.CloseMessage,
 		websocket.FormatCloseMessage(websocket.ClosePolicyViolation, "heartbeat failed"),
 		time.Now().Add(5*time.Second))
-	s.Conn.Close()
+	_ = s.Conn.Close()
+	s.Conn = nil // 标记连接已关闭
 	utils.FileLog.Info("连接已强制关闭",
 		"user", s.UserId,
 		"session", s.Id)
@@ -106,7 +134,7 @@ func NewSession(userId int, sessionId string, conn *websocket.Conn) (session *Se
 		Conn:        conn,
 		History:     []string{},
 		CloseChan:   make(chan struct{}),
-		MessageChan: make(chan *Message),
+		MessageChan: make(chan *Message, 10),
 	}
 	session.Latency = SetupLatencyMeasurement(conn)
 	go session.readPump()

+ 84 - 33
utils/ws/session_manager.go

@@ -2,28 +2,45 @@ package ws
 
 import (
 	"errors"
+	"eta/eta_api/utils"
 	"fmt"
 	"github.com/gorilla/websocket"
 	"sync"
 	"time"
 )
 
+const (
+	defaultCheckInterval = 20 * time.Second // 检测间隔应小于心跳超时时间
+	connectionTimeout    = 60 * time.Second // 客户端超时时间
+)
+
 type ConnectionManager struct {
-	Sessions  sync.Map
-	heartbeat *HeartbeatManager
+	Sessions sync.Map
+	ticker   *time.Ticker
+	stopChan chan struct{}
 }
 
 var (
-	manager = &ConnectionManager{
-		heartbeat: GetHeartbeatManager(),
-	}
+	smOnce  sync.Once
+	manager *ConnectionManager
 )
 
+func GetInstance() *ConnectionManager {
+	smOnce.Do(func() {
+		if manager == nil {
+			manager = &ConnectionManager{
+				ticker:   time.NewTicker(defaultCheckInterval),
+				stopChan: make(chan struct{}),
+			}
+		}
+	})
+	return manager
+}
 func Manager() *ConnectionManager {
 	return manager
 }
 
-// 消息处理核心逻辑
+// HandleMessage 消息处理核心逻辑
 func (manager *ConnectionManager) HandleMessage(userID int, sessionID string, message []byte) error {
 	if !Allow(userID, QA_LIMITER) {
 		return errors.New("您提问的太频繁了,请稍后再试")
@@ -36,40 +53,15 @@ func (manager *ConnectionManager) HandleMessage(userID int, sessionID string, me
 	// 处理业务逻辑
 	session.History = append(session.History, string(message))
 	response := "Processed: " + string(message)
-
 	// 更新最后活跃时间
 	session.LastActive = time.Now()
-
 	// 发送响应
 	return session.Conn.WriteMessage(websocket.TextMessage, []byte(response))
 }
 
-// 心跳管理
-func (manager *ConnectionManager) StartHeartbeat() {
-	ticker := time.NewTicker(basePingInterval)
-	defer ticker.Stop()
-	for range ticker.C {
-		manager.checkSessions()
-	}
-}
-
-func (manager *ConnectionManager) checkSessions() {
-	manager.Sessions.Range(func(key, value interface{}) bool {
-		session := value.(*Session)
-		if time.Since(session.LastActive) > 2*basePingInterval {
-			session.Conn.Close()
-			manager.Sessions.Delete(key)
-		} else {
-			_ = session.Latency.SendPing(session.Conn)
-		}
-		return true
-	})
-}
-
 // AddSession Add 添加一个新的会话
 func (manager *ConnectionManager) AddSession(session *Session) {
 	manager.Sessions.Store(session.Id, session)
-	manager.heartbeat.Register(session)
 }
 func (manager *ConnectionManager) GetSessionId(userId int, sessionId string) (sessionID string) {
 	return fmt.Sprintf("%d_%s", userId, sessionId)
@@ -79,8 +71,7 @@ func (manager *ConnectionManager) GetSessionId(userId int, sessionId string) (se
 func (manager *ConnectionManager) RemoveSession(sessionCode string) {
 	if data, ok := manager.Sessions.LoadAndDelete(sessionCode); ok {
 		session := data.(*Session)
-		close(session.CloseChan)
-		_ = session.Conn.Close()
+		session.Close()
 	}
 }
 
@@ -92,3 +83,63 @@ func (manager *ConnectionManager) GetSession(sessionCode string) (session *Sessi
 	}
 	return
 }
+
+// CheckAll 批量检测所有连接
+func (manager *ConnectionManager) CheckAll() {
+	manager.Sessions.Range(func(key, value interface{}) bool {
+		fmt.Printf("检测接接:%s", key)
+		session := value.(*Session)
+		if session.mu.TryRLock() {
+			defer session.mu.Unlock()
+			// 判断超时
+			if time.Since(session.LastActive) > connectionTimeout {
+				fmt.Printf("连接超时关闭: SessionID=%s, UserID=%s", session.Id, session.UserId)
+				utils.FileLog.Warn("连接超时关闭: SessionID=%s, UserID=%s", session.Id, session.UserId)
+				_ = session.Conn.Close()
+				session.Close()
+				return true
+			}
+			// 发送心跳
+			go func(s *Session) {
+				err := s.Conn.WriteControl(websocket.PingMessage,
+					nil, time.Now().Add(5*time.Second))
+				fmt.Printf(s.Id)
+				if err != nil {
+					fmt.Printf("心跳发送失败: SessionID=%s, Error=%v", s.Id, err)
+					utils.FileLog.Warn("心跳发送失败: SessionID=%s, Error=%v",
+						s.Id, err)
+					_ = s.Conn.Close()
+					session.Close()
+					manager.Sessions.Delete(session.Id)
+				}
+			}(session)
+			return true
+		}
+		return true
+	})
+}
+
+// Start 启动心跳检测
+func (manager *ConnectionManager) Start() {
+	defer manager.ticker.Stop()
+	for {
+		select {
+		case <-manager.ticker.C:
+			manager.CheckAll()
+		case <-manager.stopChan:
+			return
+		}
+	}
+}
+
+// Stop 停止心跳检测
+func (manager *ConnectionManager) Stop() {
+	close(manager.stopChan)
+}
+
+// UpdateActivity 跟新最近活跃时间
+func (manager *ConnectionManager) UpdateActivity(sessionID string) {
+	if session, ok := manager.Sessions.Load(sessionID); ok {
+		session.(*Session).LastActive = time.Now()
+	}
+}