kobe6258 1 周之前
父節點
當前提交
7ebeb4704b

+ 12 - 7
controllers/rag/chat_controller.go

@@ -3,11 +3,11 @@ package rag
 import (
 	"eta/eta_api/controllers"
 	"eta/eta_api/models/system"
-	"eta/eta_api/services"
 	"eta/eta_api/services/llm/facade"
 	"eta/eta_api/utils"
 	"eta/eta_api/utils/ws"
 	"github.com/gorilla/websocket"
+	"net"
 	"net/http"
 	"time"
 )
@@ -31,11 +31,11 @@ func (cc *ChatController) Prepare() {
 // @Success 200 {object} response.ListResp
 // @router /chat/connect [get]
 func (cc *ChatController) ChatConnect() {
-	//if !ws.Allow(cc.SysUser.AdminId) {
-	//	utils.FileLog.Error("WebSocket连接太频繁,主动拒绝链接")
-	//	cc.Ctx.ResponseWriter.WriteHeader(http.StatusTooManyRequests)
-	//	return
-	//}
+	if !ws.Allow(cc.SysUser.AdminId, ws.CONNECT_LIMITER) {
+		utils.FileLog.Error("WebSocket连接太频繁,主动拒绝链接")
+		cc.Ctx.ResponseWriter.WriteHeader(http.StatusTooManyRequests)
+		return
+	}
 	wsCon, err := webSocketHandler(cc.Ctx.ResponseWriter, cc.Ctx.Request)
 	if err != nil {
 		utils.FileLog.Error("WebSocket连接失败:", err)
@@ -66,7 +66,12 @@ func webSocketHandler(w http.ResponseWriter, r *http.Request) (conn *websocket.C
 		utils.FileLog.Error("升级协议失败:WebSocket:%s", err.Error())
 		return
 	}
+	// 获取底层 TCP 连接并设置保活
+	if tcpConn, ok := conn.NetConn().(*net.TCPConn); ok {
+		_ = tcpConn.SetKeepAlive(true)
+		_ = tcpConn.SetKeepAlivePeriod(90 * time.Second)
+		utils.FileLog.Info("TCP KeepAlive 已启用")
+	}
 	_ = conn.SetReadDeadline(time.Now().Add(time.Second * 60))
-	services.HandleWebSocketConnection(conn)
 	return
 }

+ 0 - 4
services/ws_service.go

@@ -48,10 +48,6 @@ func WsAuthenticate() web.FilterFunc {
 			tokenArr := strings.Split(tokenStr, "=")
 			token := tokenArr[1]
 
-			//accountStr := authorizationArr[1]
-			//accountArr := strings.Split(accountStr, "=")
-			//account := accountArr[1]
-
 			session, err := system.GetSysSessionByToken(token)
 			if err != nil {
 				if utils.IsErrNoRow(err) {

+ 106 - 0
utils/ws/heart_beat_manager.go

@@ -0,0 +1,106 @@
+package ws
+
+import (
+	"eta/eta_api/utils"
+	"github.com/gorilla/websocket"
+	"sync"
+	"time"
+)
+
+const (
+	defaultCheckInterval = 25 * time.Second // 检测间隔应小于心跳超时时间
+	connectionTimeout    = 60 * time.Second // 客户端超时时间
+)
+
+var (
+	hbManager *HeartbeatManager
+	hbOnce    sync.Once
+)
+
+// HeartbeatManager 心跳管理器
+type HeartbeatManager struct {
+	ticker   *time.Ticker
+	sessions sync.Map
+	stopChan chan struct{}
+}
+
+// 获取单例心跳管理器
+func GetHeartbeatManager() *HeartbeatManager {
+	hbOnce.Do(func() {
+		hbManager = &HeartbeatManager{
+			ticker:   time.NewTicker(defaultCheckInterval),
+			stopChan: make(chan struct{}),
+		}
+		go hbManager.Start()
+	})
+	return hbManager
+}
+
+// 启动心跳检测
+func (hm *HeartbeatManager) Start() {
+	defer hm.ticker.Stop()
+
+	for {
+		select {
+		case <-hm.ticker.C:
+			hm.CheckAll()
+		case <-hm.stopChan:
+			return
+		}
+	}
+}
+
+// 停止心跳检测
+func (hm *HeartbeatManager) Stop() {
+	close(hm.stopChan)
+}
+
+// 注册会话
+func (hm *HeartbeatManager) Register(session *Session) {
+	session.LastActive = time.Now()
+	session.CloseChan = make(chan struct{})
+	hm.sessions.Store(session.ID, session)
+}
+
+// 注销会话
+func (hm *HeartbeatManager) Unregister(sessionID string) {
+	if session, ok := hm.sessions.LoadAndDelete(sessionID); ok {
+		close(session.(*Session).CloseChan)
+		hm.sessions.Delete(sessionID)
+	}
+}
+
+// 批量检测所有连接
+func (hm *HeartbeatManager) CheckAll() {
+	hm.sessions.Range(func(key, value interface{}) bool {
+		session := value.(*Session)
+		// 判断超时
+		if time.Since(session.LastActive) > connectionTimeout {
+			utils.FileLog.Warn("连接超时关闭: SessionID=%s, UserID=%s", session.ID, session.UserID)
+			_ = session.Conn.Close()
+			hm.Unregister(session.ID)
+			return true
+		}
+
+		// 发送心跳
+		go func(s *Session) {
+			err := s.Conn.WriteControl(websocket.PingMessage,
+				nil, time.Now().Add(5*time.Second))
+			if err != nil {
+				utils.FileLog.Warn("心跳发送失败: SessionID=%s, Error=%v",
+					s.ID, err)
+				_ = s.Conn.Close()
+				hm.Unregister(s.ID)
+			}
+		}(session)
+
+		return true
+	})
+}
+
+// UpdateActivity 跟新最近活跃时间
+func (hm *HeartbeatManager) UpdateActivity(sessionID string) {
+	if session, ok := hm.sessions.Load(sessionID); ok {
+		session.(*Session).LastActive = time.Now()
+	}
+}

+ 74 - 0
utils/ws/latency_measurer.go

@@ -0,0 +1,74 @@
+package ws
+
+import (
+	"github.com/gorilla/websocket"
+	"sync"
+	"time"
+)
+
+// LatencyMeasurer 延迟测量器
+type LatencyMeasurer struct {
+	measurements    []time.Duration
+	lastLatency     time.Duration
+	mu              sync.Mutex
+	lastPingTime    time.Time // 最后一次发送Ping的时间
+	maxMeasurements int       // 保留的最大测量次数
+}
+
+func NewLatencyMeasurer(windowSize int) *LatencyMeasurer {
+	return &LatencyMeasurer{
+		maxMeasurements: windowSize,
+	}
+}
+
+// 发送Ping并记录时间戳
+func (lm *LatencyMeasurer) SendPing(conn *websocket.Conn) error {
+	lm.mu.Lock()
+	defer lm.mu.Unlock()
+	// 发送Ping消息
+	err := conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(5*time.Second))
+	if err != nil {
+		return err
+	}
+	lm.lastPingTime = time.Now()
+	return nil
+}
+
+// 处理Pong响应
+func (lm *LatencyMeasurer) CalculateLatency() {
+	lm.mu.Lock()
+	defer lm.mu.Unlock()
+	if lm.lastPingTime.IsZero() {
+		return
+	}
+	// 计算往返时间
+	rtt := time.Since(lm.lastPingTime)
+	// 维护滑动窗口
+	if len(lm.measurements) >= lm.maxMeasurements {
+		lm.measurements = lm.measurements[1:]
+	}
+	lm.measurements = append(lm.measurements, rtt)
+	// 计算平均延迟(可根据需求改为中位数等)
+	sum := time.Duration(0)
+	for _, d := range lm.measurements {
+		sum += d
+	}
+	lm.lastLatency = sum / time.Duration(len(lm.measurements))
+}
+
+// 获取当前网络延迟估值
+func (lm *LatencyMeasurer) GetLatency() time.Duration {
+	lm.mu.Lock()
+	defer lm.mu.Unlock()
+	return lm.lastLatency
+}
+
+// 在连接初始化时设置Pong处理器
+func SetupLatencyMeasurement(conn *websocket.Conn) *LatencyMeasurer {
+	lm := NewLatencyMeasurer(5) // 使用最近5次测量的滑动窗口
+	conn.SetPongHandler(func(appData string) error {
+		lm.CalculateLatency()
+		return nil
+	})
+	return lm
+}

+ 37 - 22
utils/ws/limiter.go

@@ -8,46 +8,50 @@ import (
 )
 
 var (
-	limiterManager *QALimiterManger
-	limiterOnce    sync.Once
+	limiterManagers map[string]*LimiterManger
+	limiterOnce     sync.Once
+	limters         = map[string]string{
+		CONNECT_LIMITER: LIMITER_KEY,
+		QA_LIMITER:      CONNECT_LIMITER_KEY,
+	}
 )
 
 const (
-	LIMITER_KEY = "llm_chat_key_user_%d"
+	CONNECT_LIMITER     = "connetLimiter"
+	QA_LIMITER          = "qaLimiter"
+	LIMITER_KEY         = "llm_chat_key_user_%d"
+	CONNECT_LIMITER_KEY = "llm_chat_connect_key_user_%d"
 )
 
-type QALimiterManger struct {
-	sync.RWMutex
-	limiterMap map[string]*QALimiter
-}
-
-type QALimiter struct {
+type RateLimiter struct {
 	LastRequest time.Time
 	*rate.Limiter
 }
+type LimiterManger struct {
+	sync.RWMutex
+	limiterMap map[string]*RateLimiter
+}
 
 //func (qaLimiter *QALimiter) Allow() bool {
 //	return qaLimiter.Limiter.Allow()
 //}
 
 // GetLimiter 获取或创建用户的限流器
-func (qalm *QALimiterManger) GetLimiter(token string) *QALimiter {
+func (qalm *LimiterManger) GetLimiter(token string) *RateLimiter {
 	qalm.Lock()
 	defer qalm.Unlock()
-
 	if limiter, exists := qalm.limiterMap[token]; exists {
 		return limiter
 	}
 
 	// 创建一个新的限流器,例如每10秒1个请求
-	limiter := &QALimiter{
+	limiter := &RateLimiter{
 		Limiter: rate.NewLimiter(rate.Every(10*time.Second), 1),
 	}
 	qalm.limiterMap[token] = limiter
 	return limiter
 }
-
-func (qalm *QALimiterManger) Allow(token string) bool {
+func (qalm *LimiterManger) Allow(token string) bool {
 
 	limiter := qalm.GetLimiter(token)
 	if limiter.LastRequest.IsZero() {
@@ -60,18 +64,29 @@ func (qalm *QALimiterManger) Allow(token string) bool {
 	limiter.LastRequest = time.Now()
 	return limiter.Allow()
 }
-func getInstance() *QALimiterManger {
+func getInstance(key string) *LimiterManger {
 	limiterOnce.Do(func() {
-		if limiterManager == nil {
-			limiterManager = &QALimiterManger{
-				limiterMap: make(map[string]*QALimiter),
+		if limiterManagers == nil {
+			limiterManagers = make(map[string]*LimiterManger, len(limters))
+		}
+		for key, _ := range limters {
+			limiterManagers[key] = &LimiterManger{
+				limiterMap: make(map[string]*RateLimiter),
 			}
 		}
 	})
-	return limiterManager
+	return limiterManagers[key]
 }
 
-func Allow(userId int) bool {
-	token := fmt.Sprintf(LIMITER_KEY, userId)
-	return getInstance().Allow(token)
+func Allow(userId int, limiter string) bool {
+	tokenKey := limters[limiter]
+	if tokenKey == "" {
+		return false
+	}
+	token := fmt.Sprintf(tokenKey, userId)
+	handler := getInstance(limiter)
+	if handler == nil {
+		return false
+	}
+	return handler.Allow(token)
 }

+ 70 - 18
utils/ws/session.go

@@ -1,6 +1,8 @@
 package ws
 
 import (
+	"errors"
+	"eta/eta_api/utils"
 	"github.com/gorilla/websocket"
 	"sync"
 	"time"
@@ -12,31 +14,81 @@ type Session struct {
 	UserID      int
 	Conn        *websocket.Conn
 	LastActive  time.Time
-	qaLimiter   *QALimiter
 	Latency     *LatencyMeasurer
+	History     []string
 	CloseChan   chan struct{}
 	MessageChan chan []byte
 	mu          sync.RWMutex
 }
 
-// HeartbeatManager 心跳管理器
-type HeartbeatManager struct {
-	interval  time.Duration
-	sessions  sync.Map
-	closeChan chan struct{}
-}
+// readPump 处理读操作
+func (s *Session) readPump() {
+	defer manager.RemoveSession(s.ID)
+	s.Conn.SetReadLimit(maxMessageSize)
+	_ = s.Conn.SetReadDeadline(time.Now().Add(60 * time.Second))
+	for {
+		_, message, err := s.Conn.ReadMessage()
+		if err != nil {
+			handleCloseError(err)
+			return
+		}
+
+		// 更新活跃时间
+		s.mu.Lock()
+		s.LastActive = time.Now()
+		s.mu.Unlock()
+		// 处理消息
+		if err = manager.HandleMessage(s.ID, message); err != nil {
 
-// LatencyMeasurer 延迟测量器
-type LatencyMeasurer struct {
-	measurements []time.Duration
-	lastLatency  time.Duration
-	mu           sync.Mutex
+		}
+	}
 }
 
-// NewHeartbeatManager 创建心跳管理器
-func NewHeartbeatManager(interval time.Duration) *HeartbeatManager {
-	return &HeartbeatManager{
-		interval:  interval,
-		closeChan: make(chan struct{}),
+// writePump 处理写操作
+func (s *Session) writePump() {
+	ticker := time.NewTicker(basePingInterval)
+	defer ticker.Stop()
+	for {
+		select {
+		case message, ok := <-s.MessageChan:
+			if !ok {
+				return
+			}
+			_ = s.Conn.WriteMessage(websocket.TextMessage, message)
+		case <-ticker.C:
+			_ = s.Latency.SendPing(s.Conn)
+			ticker.Reset(s.Latency.lastLatency)
+		case <-s.CloseChan:
+			return
+		}
+	}
+}
+func handleCloseError(err error) {
+	if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway) {
+		var wsErr *websocket.CloseError
+		if !errors.As(err, &wsErr) {
+			utils.FileLog.Error("未知错误 %s", err.Error())
+		} else {
+			switch wsErr.Code {
+			case websocket.CloseNormalClosure:
+				utils.FileLog.Info("正常关闭连接")
+			default:
+				utils.FileLog.Error("关闭代码:%d:%s", wsErr.Code, wsErr.Text)
+			}
+		}
 	}
-}
+}
+
+// 强制关闭连接
+func (s *Session) forceClose() {
+	close(s.CloseChan)
+	// 发送关闭帧
+	s.Conn.WriteControl(websocket.CloseMessage,
+		websocket.FormatCloseMessage(websocket.ClosePolicyViolation, "heartbeat failed"),
+		time.Now().Add(5*time.Second))
+
+	s.Conn.Close()
+	utils.FileLog.Info("连接已强制关闭",
+		"user", s.UserID,
+		"session", s.ID)
+}

+ 63 - 104
utils/ws/session_manager.go

@@ -1,28 +1,28 @@
 package ws
 
 import (
-	"eta/eta_api/utils"
+	"errors"
 	"fmt"
 	"github.com/gorilla/websocket"
-	"math/rand"
-	"net"
+	"sync"
 	"time"
 )
+
 const (
-	maxMessageSize  = 1024 * 1024 // 1MB
+	maxMessageSize   = 1024 * 1024 // 1MB
 	basePingInterval = 30 * time.Second
 	maxPingInterval  = 120 * time.Second
 	minPingInterval  = 15 * time.Second
 )
+
 type ConnectionManager struct {
-	Sessions    map[string]*Session
-	heartbeat   *HeartbeatManager
+	Sessions  sync.Map
+	heartbeat *HeartbeatManager
 }
 
 var (
 	manager = &ConnectionManager{
-		Sessions:  make(map[string]*Session),
-		heartbeat: NewHeartbeatManager(30 * time.Second),
+		heartbeat: GetHeartbeatManager(),
 	}
 )
 
@@ -30,112 +30,71 @@ func Manager() *ConnectionManager {
 	return manager
 }
 
-// Add 添加一个新的会话
-func (manager *ConnectionManager) Add(session *Session) {
-	manager.Lock()
-	defer manager.Unlock()
-	manager.Sessions[manager.GetSessionId(session.UserId, session.SessionId)] = session
-}
-func (manager *ConnectionManager) GetSessionId(userId int, sessionId string) (sessionID string) {
-	return fmt.Sprintf("%d_%s", userId, sessionId)
-}
+// 消息处理核心逻辑
+func (manager *ConnectionManager) HandleMessage(userID int, sessionID string, message []byte) error {
+	if !Allow(userID, QA_LIMITER) {
+		return errors.New("request too frequent")
+	}
+	session, exists := manager.GetSession(sessionID)
+	if !exists {
+		return errors.New("session not found")
+	}
 
-// Remove 移除一个会话
-func (manager *ConnectionManager) Remove(sessionCode string) {
-	delete(manager.Sessions, sessionCode)
-}
+	// 处理业务逻辑
+	session.History = append(session.History, string(message))
+	response := "Processed: " + string(message)
 
-func (manager *ConnectionManager) Get(sessionID string) (session *Session, ok bool) {
-	session, ok = manager.Sessions[sessionID]
-	return
+	// 更新最后活跃时间
+	session.LastActive = time.Now()
+
+	// 发送响应
+	return session.Conn.WriteMessage(websocket.TextMessage, []byte(response))
 }
-func (manager *ConnectionManager) HeartBeat(session *Session) {
-	fmt.Println("执行心跳")
-	if err := session.Conn.WriteMessage(websocket.PingMessage, nil); err != nil {
-		err = session.Conn.Close()
-		if err != nil {
-			utils.FileLog.Error("关闭长连接失败: %v", err)
-			return
-		}
-		delete(manager.Sessions, manager.GetSessionId(session.UserId, session.SessionId))
+
+// 心跳管理
+func (manager *ConnectionManager) StartHeartbeat() {
+	ticker := time.NewTicker(basePingInterval)
+	defer ticker.Stop()
+	for range ticker.C {
+		manager.checkSessions()
 	}
 }
-func (manager *ConnectionManager) HandleWebSocketConnection(conn *websocket.Conn) {
-	defer func() {
-		if err := conn.Close(); err != nil {
-			handleClose(err)
-		}
-	}()
-	// 获取底层 TCP 连接并设置保活
-	if tcpConn, ok := conn.NetConn().(*net.TCPConn); ok {
-		_ = tcpConn.SetKeepAlive(true)
-		_ = tcpConn.SetKeepAlivePeriod(90 * time.Second)
-		utils.FileLog.Info("TCP KeepAlive 已启用")
-	}
-	// 初始化心跳间隔(基础值)
-	baseInterval := 30 * time.Second
-	adjustHeartbeatInterval(conn, baseInterval)
-	// 设置心跳检测
-	conn.SetPongHandler(func(string) error {
-		err := conn.SetReadDeadline(time.Now().Add(60 * time.Second))
-		if err != nil {
-			utils.FileLog.Error("设置读取超时失败:WebSocket:", err)
+
+func (manager *ConnectionManager) checkSessions() {
+	manager.Sessions.Range(func(key, value interface{}) bool {
+		session := value.(*Session)
+		if time.Since(session.LastActive) > 2*basePingInterval {
+			session.Conn.Close()
+			manager.Sessions.Delete(key)
+		} else {
+			_ = session.Latency.SendPing(session.Conn)
 		}
-		return nil
+		return true
 	})
-	// 消息处理循环
-	for {
-		messageType, message, err := conn.ReadMessage()
-		if err != nil {
-			utils.FileLog.Error("Read error:", err)
-			return
-		}
-		// 业务处理逻辑
-		response := processMessage(message)
-		// 返回响应
-		if err = conn.WriteMessage(messageType, response); err != nil {
-			utils.FileLog.Error("Write error:", err)
-			return
-		}
-	}
 }
-func handleClose(err error) {
-	if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway) {
-		if wsErr, ok := err.(*websocket.CloseError); !ok {
-			utils.FileLog.Error("未知错误 %s", err.Error())
-		} else {
-			switch wsErr.Code {
-			case websocket.CloseNormalClosure:
-				utils.FileLog.Info("正常关闭连接")
-			default:
-				utils.FileLog.Error("关闭代码:%d:%s", wsErr.Code, wsErr.Text)
-			}
-		}
 
-	}
+// AddSession Add 添加一个新的会话
+func (manager *ConnectionManager) AddSession(session *Session) {
+	manager.Sessions.Store(session.ID, session)
+	manager.heartbeat.Register(session)
+}
+func (manager *ConnectionManager) GetSessionId(userId int, sessionId string) (sessionID string) {
+	return fmt.Sprintf("%d_%s", userId, sessionId)
 }
 
-// 动态调整心跳间隔(需配合业务逻辑调用)
-func adjustHeartbeatInterval(conn *websocket.Conn, baseInterval time.Duration) {
-	// 模拟网络延迟计算(实际应通过Ping-Pong测量)
-	latency := time.Duration(rand.Intn(100)) * time.Millisecond
-	newInterval := baseInterval + latency*2
-
-	// 创建新的心跳定时器
-	ticker := time.NewTicker(newInterval)
-	defer ticker.Stop()
-
-	go func() {
-		for range ticker.C {
-			if err := conn.WriteControl(websocket.PingMessage, nil, time.Now().Add(5*time.Second)); err != nil {
-				utils.FileLog.Error("发送心跳包失败:", err)
-				return
-			}
-		}
-	}()
-	utils.FileLog.Info("心跳间隔调整为: %v", newInterval)
+// RemoveSession Remove 移除一个会话
+func (manager *ConnectionManager) RemoveSession(sessionCode string) {
+	if data, ok := manager.Sessions.LoadAndDelete(sessionCode); ok {
+		session := data.(*Session)
+		close(session.CloseChan)
+		_ = session.Conn.Close()
+	}
 }
-func processMessage(msg []byte) []byte {
-	// 实现具体的业务逻辑
-	return []byte("Received: " + string(msg))
+
+// GetSession 获取一个会话
+func (manager *ConnectionManager) GetSession(sessionCode string) (session *Session, ok bool) {
+	if data, ok := manager.Sessions.Load(sessionCode); ok {
+		session = data.(*Session)
+	}
+	return
 }