Browse Source

优化长连接的消息通知

xyxie 1 day ago
parent
commit
f87e955d9d
4 changed files with 138 additions and 43 deletions
  1. 107 43
      services/websocket_msg.go
  2. 1 0
      utils/redis.go
  3. 15 0
      utils/redis/cluster_redis.go
  4. 15 0
      utils/redis/standalone_redis.go

+ 107 - 43
services/websocket_msg.go

@@ -1,10 +1,13 @@
 package services
 
 import (
+	"context"
 	"eta/eta_api/models"
 	"eta/eta_api/services/data"
 	"eta/eta_api/utils"
 	"fmt"
+	"sync"
+	"time"
 
 	"github.com/gorilla/websocket"
 )
@@ -15,57 +18,118 @@ func DealWebSocketMsg(conn *websocket.Conn, adminId int) {
 
 // 处理巡检消息
 func DealEdbInspectionMessage(conn *websocket.Conn, adminId int) {
+	// 创建上下文用于控制 goroutine 生命周期
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	// 创建互斥锁保护 WebSocket 写操作
+	var wsWriteMutex sync.Mutex
+
+	// 创建连接关闭标志
+	done := make(chan struct{})
+	defer close(done)
+
 	cacheKey := fmt.Sprintf("%s%d", utils.CACHE_EDB_INSPECTION_MESSAGE, adminId)
+
+	// 监听连接关闭
+	go func() {
+		<-done
+		cancel()
+	}()
+
+	// 设置连接关闭处理器
+	conn.SetCloseHandler(func(code int, text string) error {
+		close(done)
+		return nil
+	})
+
 	for {
-		utils.Rc.Brpop(cacheKey, func(b []byte) {
-			messageList, err := data.GetHistoryInspectionMessages(adminId)
-			if err != nil {
-				utils.FileLog.Error("获取巡检信息历史失败,err:%s, adminId:%d", err.Error(), adminId)
-				return
-			}
-			success := make(chan int64, 10)
-			defer close(success)
-			go func() {
-				defer close(success)
-				for i, msg := range messageList {
-					if i == 0 {
-						// 多条消息仅发送最新一条
-						respData, err := data.SendInspectionMessages(adminId, msg)
-						if err != nil {
-							utils.FileLog.Error("巡检信息发送失败,err:%s, adminId:%d", err.Error(), adminId)
-						} else {
-							resp := models.WebsocketMessageResponse{
-								MessageType: 1,
-								Data: respData,
-							}
-							err = conn.WriteJSON(resp)
-							if err != nil {
-								utils.FileLog.Error("巡检信息发送失败,err:%s, adminId:%d", err.Error(), adminId)
-							} else {
+		select {
+		case <-ctx.Done():
+			return
+		default:
+			// 使用带超时的 Redis 操作
+			utils.Rc.BrpopWithTimeout(cacheKey, 30*time.Second, func(b []byte) {
+				messageList, err := data.GetHistoryInspectionMessages(adminId)
+				if err != nil {
+					utils.FileLog.Error("获取巡检信息历史失败,err:%s, adminId:%d", err.Error(), adminId)
+					return
+				}
+
+				success := make(chan int64, 10)
+				var wg sync.WaitGroup
+
+				// 消息发送 goroutine
+				wg.Add(1)
+				go func() {
+					defer wg.Done()
+					defer close(success)
+
+					for i, msg := range messageList {
+						select {
+						case <-ctx.Done():
+							return
+						default:
+							if i == 0 {
+								respData, err := data.SendInspectionMessages(adminId, msg)
+								if err != nil {
+									utils.FileLog.Error("巡检信息发送失败,err:%s, adminId:%d", err.Error(), adminId)
+									continue
+								}
+
+								resp := models.WebsocketMessageResponse{
+									MessageType: 1,
+									Data:       respData,
+								}
+
+								// 使用互斥锁保护 WebSocket 写操作
+								wsWriteMutex.Lock()
+								err = conn.WriteJSON(resp)
+								wsWriteMutex.Unlock()
+
+								if err != nil {
+									utils.FileLog.Error("巡检信息发送失败,err:%s, adminId:%d", err.Error(), adminId)
+									continue
+								}
+
 								utils.FileLog.Info("巡检信息发送成功,adminId:%d, messageId:%d", adminId, msg.MessageId)
 								success <- msg.MessageId
+							} else {
+								success <- msg.MessageId
 							}
 						}
-					} else {
-						success <- msg.MessageId
 					}
-				}
-			}()
-			go func() {
-				readList := make([]int64, 0)
-				for {
-					msgId, ok := <-success
-					if !ok {
-						break
+				}()
+
+				// 消息已读处理 goroutine
+				wg.Add(1)
+				go func() {
+					defer wg.Done()
+					readList := make([]int64, 0)
+					
+					for {
+						select {
+						case <-ctx.Done():
+							return
+						case msgId, ok := <-success:
+							if !ok {
+								// 处理已收集的消息
+								if len(readList) > 0 {
+									_, err = data.ReadEdbInspectionMessageList(readList, adminId)
+									if err != nil {
+										utils.FileLog.Error("巡检信息已读失败,err:%s, adminId:%d", err.Error(), adminId)
+									}
+								}
+								return
+							}
+							readList = append(readList, msgId)
+						}
 					}
-					readList = append(readList, msgId)
-				}
-				_, err = data.ReadEdbInspectionMessageList(readList, adminId)
-				if err != nil {
-					utils.FileLog.Error("巡检信息已读失败,err:%s, adminId:%d", err.Error(), adminId)
-				}
-			}()
+				}()
+
+				// 等待所有 goroutine 完成
+				wg.Wait()
 			})
+		}
 	}
-	
 }

+ 1 - 0
utils/redis.go

@@ -19,6 +19,7 @@ type RedisClient interface {
 	IsExist(key string) bool
 	LPush(key string, val interface{}) error
 	Brpop(key string, callback func([]byte))
+	BrpopWithTimeout(key string, timeout time.Duration, callback func([]byte))
 	GetRedisTTL(key string) time.Duration
 	Incrby(key string, num int) (interface{}, error)
 	Do(commandName string, args ...interface{}) (reply interface{}, err error)

+ 15 - 0
utils/redis/cluster_redis.go

@@ -249,6 +249,21 @@ func (rc *ClusterRedisClient) Brpop(key string, callback func([]byte)) {
 
 }
 
+// BrpopWithTimeout
+// @Description: 从list中读取
+// @receiver rc
+// @param key
+// @param timeout
+// @param callback
+func (rc *ClusterRedisClient) BrpopWithTimeout(key string, timeout time.Duration, callback func([]byte)) {
+	values, err := rc.redisClient.BRPop(context.TODO(), timeout, key).Result()
+	if err != nil {
+		return
+	}
+
+	callback([]byte(values[1]))
+}
+
 // GetRedisTTL
 // @Description: 获取key的过期时间
 // @receiver rc

+ 15 - 0
utils/redis/standalone_redis.go

@@ -237,6 +237,21 @@ func (rc *StandaloneRedisClient) Brpop(key string, callback func([]byte)) {
 
 }
 
+// BrpopWithTimeout
+// @Description: 从list中读取
+// @receiver rc
+// @param key
+// @param timeout
+// @param callback
+func (rc *StandaloneRedisClient) BrpopWithTimeout(key string, timeout time.Duration, callback func([]byte)) {
+	values, err := rc.redisClient.BRPop(context.TODO(), timeout, key).Result()
+	if err != nil {
+		return
+	}
+
+	callback([]byte(values[1]))
+}
+
 // GetRedisTTL
 // @Description: 获取key的过期时间
 // @receiver rc