Browse Source

修改会话逻辑

kobe6258 1 month ago
parent
commit
2a587de1a7
3 changed files with 75 additions and 13 deletions
  1. 69 7
      utils/lock/distrubtLock.go
  2. 0 2
      utils/ws/latency_measurer.go
  3. 6 4
      utils/ws/session_manager.go

+ 69 - 7
utils/lock/distrubtLock.go

@@ -2,9 +2,11 @@ package lock
 
 import (
 	"context"
+	"errors"
 	"eta/eta_api/utils"
 	"fmt"
 	"github.com/go-redis/redis/v8"
+	"time"
 )
 
 const (
@@ -37,15 +39,75 @@ func AcquireLock(key string, expiration int, Holder string) bool {
 	return false
 }
 
-//func TryLock(key string, expiration int, Holder string, wait bool, timeout time.Duration)error{
-//
-//}
-func Lock() error {
-	if !AcquireLock("test", 10, "test") {
-		return fmt.Errorf("加锁失败")
+func TryLock(key string, expiration int, Holder string, timeout time.Duration) error {
+	script := redis.NewScript(`
+        local session = ARGV[1]
+        local current = redis.call("hget", KEYS[1], session)
+        if current then
+            local count = tonumber(current)
+            if count > 0 then
+                redis.call("hincrby", KEYS[1], session, 1)
+                return 1
+            else
+                return 0
+            end
+        else
+            redis.call("hset", KEYS[1], session, 1)
+            redis.call("expire", KEYS[1], 10)
+            return 1
+        end
+    `)
+	start := time.Now()
+	for {
+		result, err := script.Run(context.Background(), utils.Rc.RedisClient(), []string{key}, Holder).Result()
+		if err != nil {
+			return err
+		}
+		if result.(int64) == 1 {
+			//go renewLock()
+			return nil
+		}
+
+		if time.Since(start) >= timeout {
+			return errors.New("获取锁超时")
+		}
+		time.Sleep(200 * time.Millisecond)
 	}
-	return nil
 }
+//
+//func renewLock() {
+//	for {
+//		time.Sleep(5 * time.Second)
+//
+//		script := redis.NewScript(1, `
+//            local session = ARGV[1]
+//            local current = redis.call("hget", KEYS[1], session)
+//            if current then
+//                local count = tonumber(current)
+//                if count > 0 then
+//                    redis.call("expire", KEYS[1], 10)
+//                    return 1
+//                else
+//                    return 0
+//                end
+//            else
+//                return 0
+//            end
+//        `)
+//
+//		result, err := script.Run(r.ctx, r.client, []string{r.key}, r.session).Result()
+//		if err != nil {
+//			fmt.Println("Failed to renew lock:", err)
+//			return
+//		}
+//
+//		if result.(int64) == 0 {
+//			fmt.Println("Lock not held by current client, stopping renewal")
+//			return
+//		}
+//	}
+//}
+
 func ReleaseLock(key string, holder string) bool {
 	script := redis.NewScript(`
 	   if redis.call("get", KEYS[1]) == ARGV[1] then

+ 0 - 2
utils/ws/latency_measurer.go

@@ -2,7 +2,6 @@ package ws
 
 import (
 	"errors"
-	"fmt"
 	"github.com/gorilla/websocket"
 	"sync"
 	"time"
@@ -87,7 +86,6 @@ func (lm *LatencyMeasurer) GetLatency() time.Duration {
 func SetupLatencyMeasurement(conn *websocket.Conn) *LatencyMeasurer {
 	lm := NewLatencyMeasurer(5) // 使用最近5次测量的滑动窗口
 	conn.SetPongHandler(func(appData string) error {
-		fmt.Println("Pong received")
 		lm.CalculateLatency()
 		return nil
 	})

+ 6 - 4
utils/ws/session_manager.go

@@ -56,9 +56,7 @@ func Manager() *ConnectionManager {
 
 // HandleMessage 消息处理核心逻辑
 func (manager *ConnectionManager) HandleMessage(userID int, sessionID string, message []byte) error {
-	if !Allow(userID, QA_LIMITER) {
-		return errors.New("您提问的太频繁了,请稍后再试")
-	}
+
 	session, exists := manager.GetSession(sessionID)
 	if !exists {
 		return errors.New("session not found")
@@ -68,10 +66,14 @@ func (manager *ConnectionManager) HandleMessage(userID int, sessionID string, me
 		fmt.Printf("收到心跳消息,续期长连接:%v", session.LastActive)
 		return nil
 	}
+	if !Allow(userID, QA_LIMITER) {
+		_ = session.Conn.WriteMessage(websocket.TextMessage, []byte("您提问的太频繁了,请稍后再试"))
+		_ = session.Conn.WriteMessage(websocket.TextMessage, []byte("<EOF/>"))
+		return nil
+	}
 	var userMessage Message
 	err := json.Unmarshal(message, &userMessage)
 	if err != nil {
-		fmt.Printf("消息格式错误:%s", string(message))
 		return errors.New("消息格式错误:" + err.Error())
 	}
 	// 处理业务逻辑