kobe6258 6 days ago
parent
commit
b39b2f765e

+ 16 - 1
models/llm/user_chat_record.go

@@ -1,6 +1,12 @@
 package llm
 
-import "time"
+import (
+	"eta/eta_api/global"
+	"eta/eta_api/utils"
+	"gorm.io/gorm"
+	"gorm.io/gorm/clause"
+	"time"
+)
 
 // UserChatRecord 定义用户聊天记录结构体
 type UserChatRecord struct {
@@ -23,3 +29,12 @@ type UserChatRecordRedis struct {
 func (u *UserChatRecord) TableName() string {
 	return "user_chat_record"
 }
+
+func BatchInsertRecords(list []*UserChatRecord) (err error) {
+	o := global.DbMap[utils.DbNameAI]
+	err = o.Clauses(clause.OnConflict{
+		Columns:   []clause.Column{{Name: "chat_id"}, {Name: "chat_user_type"}, {Name: "send_time"}},
+		DoUpdates: clause.Assignments(map[string]interface{}{"update_time": gorm.Expr("VALUES(update_time)")}),
+	}).CreateInBatches(list, utils.MultiAddNum).Error
+	return
+}

+ 102 - 82
services/llm/chat_service.go

@@ -5,37 +5,50 @@ import (
 	"eta/eta_api/global"
 	"eta/eta_api/models/llm"
 	"eta/eta_api/utils"
+	"eta/eta_api/utils/lock"
 	"eta/eta_api/utils/redis"
 	"fmt"
+	"github.com/google/uuid"
+	"strconv"
+	"strings"
 	"time"
 )
 
 const (
 	redisChatPrefix = "chat:zet:"
+	RecordLock      = "lock:chat:record:"
 	redisTTL        = 24 * time.Hour // Redis 缓存过期时间
 )
 
 // AddChatRecord 添加聊天记录到 Redis
 func AddChatRecord(record *llm.UserChatRecordRedis) error {
 	key := fmt.Sprintf("%s%d", redisChatPrefix, record.ChatId)
-	data, err := json.Marshal(record)
-	if err != nil {
-		return fmt.Errorf("序列化聊天记录失败: %w", err)
-	}
-	zSet, _ := utils.Rc.ZRangeWithScores(key)
-	if len(zSet) == 0 {
-		// 设置过期时间
-		_ = utils.Rc.Expire(key, 24*time.Hour)
-	}
-	zSet = append(zSet, &redis.Zset{
-		Member: data,
-		Score:  float64(time.Now().Unix()),
-	})
-	err = utils.Rc.ZAdd(key, zSet...)
-	if err != nil {
-		return fmt.Errorf("保存聊天记录到 Redis 失败: %w", err)
+	lockKey := fmt.Sprintf("%s%s", RecordLock, key)
+	fmt.Printf("2%s", lockKey)
+	holder, _ := uuid.NewUUID()
+	if lock.AcquireLock(lockKey, 10, holder.String()) {
+		defer lock.ReleaseLock(key, holder.String())
+		data, err := json.Marshal(record)
+		if err != nil {
+			return fmt.Errorf("序列化聊天记录失败: %w", err)
+		}
+		zSet, _ := utils.Rc.ZRangeWithScores(key)
+		if len(zSet) == 0 {
+			// 设置过期时间
+			_ = utils.Rc.Expire(key, 24*time.Hour)
+		}
+		zSet = append(zSet, &redis.Zset{
+			Member: data,
+			Score:  float64(time.Now().Unix()),
+		})
+		fmt.Println(strconv.Itoa(len(zSet)))
+		err = utils.Rc.ZAdd(key, zSet...)
+		if err != nil {
+			return fmt.Errorf("保存聊天记录到 Redis 失败: %w", err)
+		}
+		return nil
 	}
-	return nil
+	return fmt.Errorf("获取锁失败,请稍后重试")
 }
 
 // GetChatRecordsFromRedis 从 Redis 获取聊天记录
@@ -52,16 +65,13 @@ func GetChatRecordsFromRedis(chatId int) (redisList []*llm.UserChatRecordRedis,
 		// 将数据保存到 Redis
 		for _, record := range records {
 			redisRecord := &llm.UserChatRecordRedis{
+				Id:           record.Id,
 				ChatId:       chatId,
 				ChatUserType: record.ChatUserType,
 				Content:      record.Content,
 				SendTime:     record.SendTime.Format(utils.FormatDateTime),
 			}
 			redisList = append(redisList, redisRecord)
-			if err = AddChatRecord(redisRecord); err != nil {
-				err = fmt.Errorf("保存聊天记录到 Redis 失败: %w", err)
-				return
-			}
 		}
 		return
 	}
@@ -75,67 +85,77 @@ func GetChatRecordsFromRedis(chatId int) (redisList []*llm.UserChatRecordRedis,
 	return
 }
 
-//
-//// SaveChatRecordsToDB 将 Redis 中的聊天记录保存到数据库
-//func SaveChatRecordsToDB(chatID int) error {
-//	key := fmt.Sprintf("%s%d", redisChatPrefix, chatID)
-//	val, err := global.Redis.Get(global.Context, key).Result()
-//	if err == redis.Nil {
-//		return nil // 缓存不存在,无需保存
-//	} else if err != nil {
-//		return fmt.Errorf("从 Redis 获取聊天记录失败: %w", err)
-//	}
-//
-//	var records []*llm.UserChatRecord
-//	if err := json.Unmarshal([]byte(val), &records); err != nil {
-//		return fmt.Errorf("解析聊天记录失败: %w", err)
-//	}
-//
-//	o := global.DbMap[utils.DbNameAI]
-//	for _, record := range records {
-//		if err := o.Create(record).Error; err != nil {
-//			return fmt.Errorf("保存聊天记录到数据库失败: %w", err)
-//		}
-//	}
-//
-//	// 删除 Redis 缓存
-//	if err := global.Redis.Del(global.Context, key).Err(); err != nil {
-//		return fmt.Errorf("删除 Redis 缓存失败: %w", err)
-//	}
-//
-//	return nil
-//}
-//
-//// SaveAllChatRecordsToDB 定时任务保存所有 Redis 中的聊天记录到数据库
-//func SaveAllChatRecordsToDB() {
-//	keys, err := global.Redis.Keys(global.Context, redisChatPrefix+"*").Result()
-//	if err != nil {
-//		log.Printf("获取 Redis 键失败: %v", err)
-//		return
-//	}
-//
-//	for _, key := range keys {
-//		chatIDStr := strings.TrimPrefix(key, redisChatPrefix)
-//		chatID, err := strconv.Atoi(chatIDStr)
-//		if err != nil {
-//			log.Printf("解析聊天ID失败: %v", err)
-//			continue
-//		}
-//
-//		if err := SaveChatRecordsToDB(chatID); err != nil {
-//			log.Printf("保存聊天记录到数据库失败: %v", err)
-//		}
-//	}
-//}
-//
-//// RemoveChatRecord 从 Redis 删除聊天记录
-//func RemoveChatRecord(chatID int) error {
-//	key := fmt.Sprintf("%s%d", redisChatPrefix, chatID)
-//	if err := global.Redis.Del(global.Context, key).Err(); err != nil {
-//		return fmt.Errorf("删除 Redis 缓存失败: %w", err)
-//	}
-//	return nil
-//}
+// SaveChatRecordsToDB 将 Redis 中的聊天记录保存到数据库
+func SaveChatRecordsToDB(chatId int) error {
+	list, err := GetChatRecordsFromRedis(chatId)
+	if err != nil {
+		return err
+	}
+	var newRecords []*llm.UserChatRecord
+	for _, record := range list {
+		//if record.Id == 0 {
+		sendTime, parseErr := time.Parse(utils.FormatDateTime, record.SendTime)
+		if parseErr != nil {
+			sendTime = time.Now()
+		}
+		newRecords = append(newRecords, &llm.UserChatRecord{
+			ChatId:       record.ChatId,
+			ChatUserType: record.ChatUserType,
+			Content:      record.Content,
+			SendTime:     sendTime,
+			CreatedTime:  time.Now(),
+		})
+		//}
+	}
+	//先删除redis中的缓存
+	//_ = RemoveChatRecord(chatId)
+	err = llm.BatchInsertRecords(newRecords)
+	if err != nil {
+		utils.FileLog.Error("批量插入记录失败:", err.Error())
+		return fmt.Errorf("批量插入记录失败: %w", err)
+	}
+	//_ = RemoveChatRecord(chatId)
+	return nil
+}
+
+// SaveAllChatRecordsToDB 定时任务保存所有 Redis 中的聊天记录到数据库
+func SaveAllChatRecordsToDB() {
+	for {
+		fmt.Println("开始保存聊天记录到数据库...")
+		keys, err := utils.Rc.Keys(redisChatPrefix + "*")
+		if err != nil {
+			utils.FileLog.Error("获取 Redis 键失败: %v", err)
+			return
+		}
+		for _, key := range keys {
+			lockKey := fmt.Sprintf("%s%s", RecordLock, key)
+			fmt.Printf("1%s", lockKey)
+			chatIdStr := strings.TrimPrefix(key, redisChatPrefix)
+			chatId, parseErr := strconv.Atoi(chatIdStr)
+			if parseErr != nil {
+				utils.FileLog.Error("解析聊天ID失败: %v", err)
+				continue
+			}
+			if lock.AcquireLock(lockKey, 10, "system_task") {
+				if err = SaveChatRecordsToDB(chatId); err != nil {
+					utils.FileLog.Error("解析聊天ID失败: %v", err)
+				}
+				lock.ReleaseLock(key, "system_task")
+			}
+		}
+		time.Sleep(10 * time.Second)
+	}
+}
+
+// RemoveChatRecord 从 Redis 删除聊天记录
+func RemoveChatRecord(chatId int) error {
+	key := fmt.Sprintf("%s%d", redisChatPrefix, chatId)
+	err := utils.Rc.Delete(key)
+	if err != nil {
+		return fmt.Errorf("删除 Redis 缓存失败: %w", err)
+	}
+	return nil
+}
 
 func GetChatRecordsFromDB(chatID int) ([]*llm.UserChatRecord, error) {
 	o := global.DbMap[utils.DbNameAI]

+ 2 - 0
services/task.go

@@ -5,6 +5,7 @@ import (
 	"eta/eta_api/services/binlog"
 	"eta/eta_api/services/data"
 	edbmonitor "eta/eta_api/services/edb_monitor"
+	"eta/eta_api/services/llm"
 	"eta/eta_api/utils"
 	"fmt"
 	"strings"
@@ -61,6 +62,7 @@ func Task() {
 		go binlog.HandleDataSourceChange2Es()
 	}
 	go StartSessionManager()
+	go llm.SaveAllChatRecordsToDB()
 	// TODO:数据修复
 	//FixNewEs()
 	fmt.Println("task end")

+ 56 - 0
utils/lock/distrubtLock.go

@@ -0,0 +1,56 @@
+package lock
+
+import (
+	"context"
+	"eta/eta_api/utils"
+	"fmt"
+	"github.com/go-redis/redis/v8"
+)
+
+const (
+	lockName = "lock:"
+)
+
+var (
+	ctx = context.Background()
+)
+
+func AcquireLock(key string, expiration int, Holder string) bool {
+	script := redis.NewScript(`local key = KEYS[1]
+			local clientId = ARGV[1]
+			local expiration = tonumber(ARGV[2])
+			if redis.call("EXISTS", key) == 0 then
+				redis.call("SET", key, clientId, "EX", expiration)
+				return 1
+			else
+				return 0
+			end`)
+	lockey := fmt.Sprintf("%s%s", lockName, key)
+	result, err := script.Run(ctx, utils.Rc.RedisClient(), []string{lockey}, Holder, expiration).Int()
+	if err != nil {
+		return false
+	}
+	if result == 1 {
+		return true
+	}
+	return false
+}
+
+func ReleaseLock(key string, holder string) bool {
+	script := redis.NewScript(`
+	   if redis.call("get", KEYS[1]) == ARGV[1] then
+	       return redis.call("del", KEYS[1])
+	   else
+	       return 0
+	   end
+	`)
+	lockey := fmt.Sprintf("%s%s", lockName, key)
+	result, err := script.Run(ctx, utils.Rc.RedisClient(), []string{lockey}, holder).Int()
+	if err != nil {
+		return false
+	}
+	if result == 1 {
+		return true
+	}
+	return false
+}

+ 3 - 0
utils/redis.go

@@ -2,6 +2,7 @@ package utils
 
 import (
 	"eta/eta_api/utils/redis"
+	client "github.com/go-redis/redis/v8"
 	"time"
 )
 
@@ -27,6 +28,8 @@ type RedisClient interface {
 	ZAdd(key string, members ...*redis.Zset) error
 	ZRangeWithScores(key string) ([]*redis.Zset, error)
 	Expire(key string, duration time.Duration) error
+	RedisClient() client.UniversalClient
+	Keys(pattern string) ([]string, error)
 }
 
 func initRedis(redisType string, conf string) (redisClient RedisClient, err error) {

+ 8 - 0
utils/redis/cluster_redis.go

@@ -20,6 +20,7 @@ type Zset struct {
 	Score  float64
 	Member interface{}
 }
+
 var DefaultKey = "zcmRedis"
 
 // InitClusterRedis
@@ -346,3 +347,10 @@ func (rc *ClusterRedisClient) ZRangeWithScores(key string) (result []*Zset, err
 func (rc *ClusterRedisClient) Expire(key string, duration time.Duration) error {
 	return rc.redisClient.Expire(context.Background(), key, duration).Err()
 }
+
+func (rc *ClusterRedisClient) Keys(pattern string) (keys []string, err error) {
+	return rc.redisClient.Keys(context.Background(), pattern).Result()
+}
+func (rc *ClusterRedisClient) RedisClient() redis.UniversalClient {
+	return rc.redisClient
+}

+ 6 - 0
utils/redis/standalone_redis.go

@@ -337,3 +337,9 @@ func (rc *StandaloneRedisClient) ZRangeWithScores(key string) (result []*Zset, e
 func (rc *StandaloneRedisClient) Expire(key string, duration time.Duration) error {
 	return rc.redisClient.Expire(context.Background(), key, duration).Err()
 }
+func (rc *StandaloneRedisClient) Keys(pattern string) (keys []string, err error) {
+	return rc.redisClient.Keys(context.Background(), pattern).Result()
+}
+func (rc *StandaloneRedisClient) RedisClient() redis.UniversalClient {
+	return rc.redisClient
+}

+ 1 - 0
utils/ws/session.go

@@ -22,6 +22,7 @@ type Session struct {
 	mu          sync.RWMutex
 	sessionOnce sync.Once
 }
+
 type Message struct {
 	KbName     string   `json:"KbName"`
 	Query      string   `json:"Query"`

+ 2 - 0
utils/ws/session_manager.go

@@ -67,6 +67,8 @@ func (manager *ConnectionManager) HandleMessage(userID int, sessionID string, me
 	}
 	// 处理业务逻辑
 	session.History = append(session.History, userMessage.LastTopics...)
+
+	//TODO
 	resp, err := llmService.KnowledgeBaseChat(userMessage.Query, userMessage.KbName, session.History)
 	defer func() {
 		_ = resp.Body.Close()