kobe6258 5 days ago
parent
commit
444243e827
2 changed files with 41 additions and 21 deletions
  1. 40 20
      utils/ws/limiter.go
  2. 1 1
      utils/ws/session_manager.go

+ 40 - 20
utils/ws/limiter.go

@@ -10,23 +10,38 @@ import (
 var (
 	limiterManagers map[string]*LimiterManger
 	limiterOnce     sync.Once
-	limters         = map[string]string{
-		CONNECT_LIMITER: LIMITER_KEY,
-		QA_LIMITER:      CONNECT_LIMITER_KEY,
+	limiters        = map[string]LimiterConfig{
+		CONNECT_LIMITER: {
+			LimiterKey: LIMITER_KEY,
+			Duration:   RATE_LIMTER_TIME,
+		},
+		QA_LIMITER: {
+			LimiterKey: CONNECT_LIMITER_KEY,
+			Duration:   CONNECT_LIMITER_TIME,
+		},
 	}
 )
 
+type LimiterConfig struct {
+	LimiterKey string
+	Duration   time.Duration
+}
+
 const (
 	CONNECT_LIMITER     = "connetLimiter"
 	QA_LIMITER          = "qaLimiter"
 	LIMITER_KEY         = "llm_chat_key_user_%d"
 	CONNECT_LIMITER_KEY = "llm_chat_connect_key_user_%d"
 
-	RATE_LIMTER_TIME	=30*time.Second
+	RATE_LIMTER_TIME     = 30 * time.Second
+	CONNECT_LIMITER_TIME = 5 * time.Second
 )
 
+var ()
+
 type RateLimiter struct {
 	LastRequest time.Time
+	Duration    time.Duration
 	*rate.Limiter
 }
 type LimiterManger struct {
@@ -39,27 +54,32 @@ type LimiterManger struct {
 //}
 
 // GetLimiter 获取或创建用户的限流器
-func (qalm *LimiterManger) GetLimiter(token string) *RateLimiter {
+func (qalm *LimiterManger) GetLimiter(token string, limiterKey string) (limiter *RateLimiter, duration time.Duration) {
 	qalm.Lock()
 	defer qalm.Unlock()
-	if limiter, exists := qalm.limiterMap[token]; exists {
-		return limiter
+	if config, ok := limiters[limiterKey]; !ok {
+		duration = 0 * time.Second
+	} else {
+		duration = config.Duration
+	}
+	if target, exists := qalm.limiterMap[token]; exists {
+		limiter = target
+		return
 	}
-
 	// 创建一个新的限流器,例如每10秒1个请求
-	limiter := &RateLimiter{
-		Limiter: rate.NewLimiter(rate.Every(RATE_LIMTER_TIME), 1),
+	limiter = &RateLimiter{
+		Limiter: rate.NewLimiter(rate.Every(duration), 1),
 	}
 	qalm.limiterMap[token] = limiter
-	return limiter
+	return
 }
-func (qalm *LimiterManger) Allow(token string) bool {
-	limiter := qalm.GetLimiter(token)
+func (qalm *LimiterManger) Allow(token string, limiterKey string) bool {
+	limiter, duration := qalm.GetLimiter(token, limiterKey)
 	if limiter.LastRequest.IsZero() {
 		limiter.LastRequest = time.Now()
 		return limiter.Allow()
 	}
-	if time.Now().Sub(limiter.LastRequest) < RATE_LIMTER_TIME {
+	if time.Now().Sub(limiter.LastRequest) < duration {
 		return false
 	}
 	limiter.LastRequest = time.Now()
@@ -68,9 +88,9 @@ func (qalm *LimiterManger) Allow(token string) bool {
 func getInstance(key string) *LimiterManger {
 	limiterOnce.Do(func() {
 		if limiterManagers == nil {
-			limiterManagers = make(map[string]*LimiterManger, len(limters))
+			limiterManagers = make(map[string]*LimiterManger, len(limiters))
 		}
-		for key, _ := range limters {
+		for key = range limiters {
 			limiterManagers[key] = &LimiterManger{
 				limiterMap: make(map[string]*RateLimiter),
 			}
@@ -80,14 +100,14 @@ func getInstance(key string) *LimiterManger {
 }
 
 func Allow(userId int, limiter string) bool {
-	tokenKey := limters[limiter]
-	if tokenKey == "" {
+	config := limiters[limiter]
+	if config.LimiterKey == "" {
 		return false
 	}
-	token := fmt.Sprintf(tokenKey, userId)
+	token := fmt.Sprintf(config.LimiterKey, userId)
 	handler := getInstance(limiter)
 	if handler == nil {
 		return false
 	}
-	return handler.Allow(token)
+	return handler.Allow(token,limiter)
 }

+ 1 - 1
utils/ws/session_manager.go

@@ -103,7 +103,7 @@ func (manager *ConnectionManager) HandleMessage(userID int, sessionID string, me
 			// 发送错误消息到 WebSocket
 			return err
 		case <-closeChan:
-			_ = session.Conn.WriteMessage(websocket.TextMessage, []byte("<EOF>"))
+			_ = session.Conn.WriteMessage(websocket.TextMessage, []byte("<EOF/>"))
 			return nil
 		}
 	}