Browse Source

Merge remote-tracking branch 'origin/feature/deepseek_rag_1.0' into feature/deepseek_rag_1.0

Roc 6 days ago
parent
commit
bc93256eb1

+ 9 - 1
controllers/rag/llm_http/request.go

@@ -7,6 +7,14 @@ type LLMQuestionReq struct {
 }
 
 type UserChatReq struct {
-	ChatId   int    `json:"ChatId"`
+	ChatId    int    `json:"ChatId"`
 	ChatTitle string `json:"ChatTitle" description:"会话名称"`
 }
+
+type UserChatRecordReq struct {
+	Id           int    `json:"Id"`
+	ChatId       int    `json:"ChatId"`
+	Content      string `json:"Content" description:"会话名称"`
+	ChatUserType string `json:"ChatUserType" description:"用户类型"`
+	SendTime     string `json:"SendTime" description:"发送时间"`
+}

+ 5 - 3
controllers/rag/llm_http/response.go

@@ -1,7 +1,9 @@
 package llm_http
 
+import "eta/eta_api/models/llm"
 
-type LLMQuestionRes struct {
-	Answer      string `description:"回答"`
-	SessionId     string `description:"会话ID"`
+type UserChatListResp struct {
+	TodayList     []llm.UserLlmChatListViewItem
+	YesterdayList []llm.UserLlmChatListViewItem
+	WeekList      []llm.UserLlmChatListViewItem
 }

+ 130 - 6
controllers/rag/user_chat_controller.go

@@ -6,6 +6,7 @@ import (
 	"eta/eta_api/controllers/rag/llm_http"
 	"eta/eta_api/models"
 	"eta/eta_api/models/llm"
+	llmService "eta/eta_api/services/llm"
 	"eta/eta_api/utils"
 	"time"
 )
@@ -126,20 +127,143 @@ func (ucCtrl *UserChatController) GetUserChatList() {
 		return
 	}
 	//周日是0,周六是6
-	toDay := time.Now().Weekday()
-	offset := int(time.Monday - toDay)
+	weekDay := time.Now().Weekday()
+	offset := int(time.Monday - weekDay)
 	if offset > 0 {
 		offset -= 7
 	}
+	today := time.Now().Format(utils.FormatDate)
 	monDay := time.Now().AddDate(0, 0, offset).Format(utils.FormatDate)
+	yesterday := time.Now().AddDate(0, 0, -1).Format(utils.FormatDate)
 	chatList, err := llm.GetUserChatList(sysUser.AdminId, monDay, time.Now().Format(utils.FormatDate))
 	if err != nil {
-		br.Msg = "重命名失败"
-		br.ErrMsg = "重命名失败,Err:" + err.Error()
+		br.Msg = "获取用户聊天列表失败"
+		br.ErrMsg = "获取用户聊天列表失败,Err:" + err.Error()
 		return
 	}
-	br.Data = chatList
+	data := new(llm_http.UserChatListResp)
+	data.WeekList = make([]llm.UserLlmChatListViewItem, 0)
+	data.YesterdayList = make([]llm.UserLlmChatListViewItem, 0)
+	data.TodayList = make([]llm.UserLlmChatListViewItem, 0)
+	for _, v := range chatList {
+		if v.CreatedTime.Format(utils.FormatDate) == today {
+			data.TodayList = append(data.TodayList, llm.CovertItemToView(v))
+		} else if v.CreatedTime.Format(utils.FormatDate) == yesterday {
+			data.YesterdayList = append(data.YesterdayList, llm.CovertItemToView(v))
+		} else {
+			data.WeekList = append(data.WeekList, llm.CovertItemToView(v))
+		}
+	}
+
+	br.Data = data
 	br.Ret = 200
 	br.Success = true
-	br.Msg = "重命名成功"
+	br.Msg = "获取用户聊天列表成功"
+}
+
+// ChatRecordAdd @Title 保存聊天记录
+// @Description 保存聊天记录
+// @Success 101 {object} response.ListResp
+// @router /chat/chat_record_save [post]
+func (ucCtrl *UserChatController) ChatRecordAdd() {
+	br := new(models.BaseResponse).Init()
+	defer func() {
+		ucCtrl.Data["json"] = br
+		ucCtrl.ServeJSON()
+	}()
+	var req llm_http.UserChatRecordReq
+	err := json.Unmarshal(ucCtrl.Ctx.Input.RequestBody, &req)
+	if err != nil {
+		br.Msg = "参数解析异常!"
+		br.ErrMsg = "参数解析失败,Err:" + err.Error()
+		return
+	}
+	sysUser := ucCtrl.SysUser
+	if sysUser == nil {
+		br.Msg = "请登录"
+		br.ErrMsg = "请登录,SysUser Is Empty"
+		br.Ret = 408
+		return
+	}
+	if req.ChatId <= 0 {
+		br.Msg = "非法的对话框Id"
+		br.ErrMsg = "非法的对话框Id"
+		return
+	}
+	if req.Content == "" {
+		br.Msg = "聊天记录不能为空"
+		br.ErrMsg = "聊天记录不能为空"
+		return
+	}
+	if req.Id < 0 {
+		br.Msg = "非法的Id"
+		br.ErrMsg = "非法的Id"
+		return
+	}
+	if req.ChatUserType != "user" && req.ChatUserType != "assistant" {
+		br.Msg = "非法的用户类型"
+		br.ErrMsg = "非法的用户类型,用户类型支持:user/assistant"
+		return
+	}
+	if req.SendTime == "" {
+		req.SendTime = time.Now().Format(utils.FormatDateTime)
+	} else {
+		_, err = time.Parse(utils.FormatDateTime, req.SendTime)
+		if err != nil {
+			br.Msg = "非法的发送时间"
+			br.ErrMsg = "非法的发送时间,Err:" + err.Error()
+			return
+		}
+	}
+	record := llm.UserChatRecordRedis{
+		ChatId:       req.ChatId,
+		ChatUserType: req.ChatUserType,
+		Content:      req.Content,
+		SendTime:     req.SendTime,
+	}
+	err = llmService.AddChatRecord(&record)
+	if err != nil {
+		br.Msg = "添加聊天记录失败"
+		br.ErrMsg = "添加聊天记录失败,Err:" + err.Error()
+		return
+	}
+	br.Ret = 200
+	br.Success = true
+	br.Msg = "添加聊天记录成功"
+}
+
+// ChatRecordList @Title 获取聊天记录
+// @Description 获取聊天记录
+// @Success 101 {object} response.ListResp
+// @router /chat/chat_record_list [get]
+func (ucCtrl *UserChatController) ChatRecordList() {
+	br := new(models.BaseResponse).Init()
+	defer func() {
+		ucCtrl.Data["json"] = br
+		ucCtrl.ServeJSON()
+	}()
+	sysUser := ucCtrl.SysUser
+	if sysUser == nil {
+		br.Msg = "请登录"
+		br.ErrMsg = "请登录,SysUser Is Empty"
+		br.Ret = 408
+		return
+	}
+	chatId, _ := ucCtrl.GetInt("ChatId", 0)
+	if chatId <= 0 {
+		br.Msg = "非法的对话Id"
+		br.ErrMsg = "非法的对话Id"
+		return
+	}
+
+	list, err := llmService.GetChatRecordsFromRedis(chatId)
+	if err != nil {
+		br.Msg = "获取聊天记录失败"
+		br.ErrMsg = "获取聊天记录失败,Err:" + err.Error()
+		return
+	}
+	br.Data = list
+	br.Ret = 200
+	br.Success = true
+	br.Msg = "获取聊天记录成功"
 }

+ 24 - 1
models/llm/user_chat_record.go

@@ -1,17 +1,40 @@
 package llm
 
-import "time"
+import (
+	"eta/eta_api/global"
+	"eta/eta_api/utils"
+	"gorm.io/gorm"
+	"gorm.io/gorm/clause"
+	"time"
+)
 
 // UserChatRecord 定义用户聊天记录结构体
 type UserChatRecord struct {
 	Id           int       `gorm:"primaryKey;autoIncrement;comment:主键"`
 	ChatId       int       `gorm:"comment:会话id"`
 	ChatUserType string    `gorm:"type:enum('user','assistant');comment:用户方"`
+	Content      string    `gorm:"content:内容"`
 	SendTime     time.Time `gorm:"comment:发送时间"`
 	CreatedTime  time.Time `gorm:"comment:创建时间"`
 	UpdateTime   time.Time `gorm:"autoUpdateTime;comment:更新时间"`
 }
+type UserChatRecordRedis struct {
+	Id           int
+	ChatId       int
+	ChatUserType string
+	Content      string
+	SendTime     string
+}
 
 func (u *UserChatRecord) TableName() string {
 	return "user_chat_record"
 }
+
+func BatchInsertRecords(list []*UserChatRecord) (err error) {
+	o := global.DbMap[utils.DbNameAI]
+	err = o.Clauses(clause.OnConflict{
+		Columns:   []clause.Column{{Name: "chat_id"}, {Name: "chat_user_type"}, {Name: "send_time"}},
+		DoUpdates: clause.Assignments(map[string]interface{}{"update_time": gorm.Expr("VALUES(update_time)")}),
+	}).CreateInBatches(list, utils.MultiAddNum).Error
+	return
+}

+ 29 - 4
models/llm/user_llm_chat.go

@@ -16,6 +16,31 @@ type UserLlmChat struct {
 	UpdateTime  time.Time `gorm:"autoUpdateTime;comment:更新时间"`
 }
 
+type UserLlmChatListItem struct {
+	Id          int       `gorm:"primaryKey;autoIncrement;comment:会话主键"`
+	UserId      int       `gorm:"comment:用户id"`
+	ChatTitle   string    `gorm:"comment:会话标题"`
+	CreatedTime time.Time `gorm:"comment:创建时间"`
+	RecordCount int       `gorm:"comment:会话记录数"`
+}
+type UserLlmChatListViewItem struct {
+	Id          int    `gorm:"primaryKey;autoIncrement;comment:会话主键"`
+	UserId      int    `gorm:"comment:用户id"`
+	ChatTitle   string `gorm:"comment:会话标题"`
+	CreatedTime string `gorm:"comment:创建时间"`
+	RecordCount int    `gorm:"comment:会话记录数"`
+}
+
+func CovertItemToView(item UserLlmChatListItem) UserLlmChatListViewItem {
+	return UserLlmChatListViewItem{
+		Id:          item.Id,
+		UserId:      item.UserId,
+		ChatTitle:   item.ChatTitle,
+		CreatedTime: item.CreatedTime.Format(utils.FormatDateTime),
+		RecordCount: item.RecordCount,
+	}
+
+}
 func (u *UserLlmChat) TableName() string {
 	return "user_llm_chat"
 }
@@ -39,10 +64,10 @@ func (u *UserLlmChat) RenameChatSession() (err error) {
 	return
 }
 
-func GetUserChatList(userId int, monDay, toDay string) (chatList []UserLlmChat, err error) {
+func GetUserChatList(userId int, monDay, toDay string) (chatList []UserLlmChatListItem, err error) {
 	o := global.DbMap[utils.DbNameAI]
-	sql := `select ulc.id ,ulc.chat_title,ulc.created_time,COUNT(ucr.id) AS record_count from user_llm_chat ulc left join user_chat_record ucr
-    ON ucr.chat_id = ulc.id where ulc.user_id=? and ? BETWEEN ? and ? GROUP BY ulc.id`
-	err = o.Raw(sql, userId, utils.GenerateQuerySql(utils.ToDate, &utils.QueryParam{Column: "ulc.created_time"}), monDay, toDay).Find(&chatList).Error
+	sql := `select ulc.id AS id ,ulc.user_id as user_id,ulc.chat_title as chat_title,ulc.created_time,COUNT(ucr.id) AS record_count from user_llm_chat ulc left join user_chat_record ucr
+    ON ucr.chat_id = ulc.id where ulc.user_id=? and ` + utils.GenerateQuerySql(utils.ToDate, &utils.QueryParam{Column: "ulc.created_time"}) + ` BETWEEN ? and ? GROUP BY ulc.id order by ulc.created_time desc`
+	err = o.Raw(sql, userId, monDay, toDay).Find(&chatList).Error
 	return
 }

+ 18 - 0
routers/commentsRouter.go

@@ -8575,6 +8575,24 @@ func init() {
             Filters: nil,
             Params: nil})
 
+    beego.GlobalControllerRouter["eta/eta_api/controllers/rag:UserChatController"] = append(beego.GlobalControllerRouter["eta/eta_api/controllers/rag:UserChatController"],
+        beego.ControllerComments{
+            Method: "ChatRecordList",
+            Router: `/chat/chat_record_list`,
+            AllowHTTPMethods: []string{"get"},
+            MethodParams: param.Make(),
+            Filters: nil,
+            Params: nil})
+
+    beego.GlobalControllerRouter["eta/eta_api/controllers/rag:UserChatController"] = append(beego.GlobalControllerRouter["eta/eta_api/controllers/rag:UserChatController"],
+        beego.ControllerComments{
+            Method: "ChatRecordAdd",
+            Router: `/chat/chat_record_save`,
+            AllowHTTPMethods: []string{"post"},
+            MethodParams: param.Make(),
+            Filters: nil,
+            Params: nil})
+
     beego.GlobalControllerRouter["eta/eta_api/controllers/rag:UserChatController"] = append(beego.GlobalControllerRouter["eta/eta_api/controllers/rag:UserChatController"],
         beego.ControllerComments{
             Method: "NewChat",

+ 167 - 0
services/llm/chat_service.go

@@ -0,0 +1,167 @@
+package llm
+
+import (
+	"encoding/json"
+	"eta/eta_api/global"
+	"eta/eta_api/models/llm"
+	"eta/eta_api/utils"
+	"eta/eta_api/utils/lock"
+	"eta/eta_api/utils/redis"
+	"fmt"
+	"github.com/google/uuid"
+	"strconv"
+	"strings"
+	"time"
+)
+
+const (
+	redisChatPrefix = "chat:zet:"
+	RecordLock      = "lock:chat:record:"
+	redisTTL        = 24 * time.Hour // Redis 缓存过期时间
+)
+
+// AddChatRecord 添加聊天记录到 Redis
+func AddChatRecord(record *llm.UserChatRecordRedis) error {
+	key := fmt.Sprintf("%s%d", redisChatPrefix, record.ChatId)
+	lockKey := fmt.Sprintf("%s%s", RecordLock, key)
+	fmt.Printf("2%s", lockKey)
+	holder, _ := uuid.NewUUID()
+	if lock.AcquireLock(lockKey, 10, holder.String()) {
+		defer lock.ReleaseLock(key, holder.String())
+		data, err := json.Marshal(record)
+		if err != nil {
+			return fmt.Errorf("序列化聊天记录失败: %w", err)
+		}
+		zSet, _ := utils.Rc.ZRangeWithScores(key)
+		if len(zSet) == 0 {
+			// 设置过期时间
+			_ = utils.Rc.Expire(key, 24*time.Hour)
+		}
+		zSet = append(zSet, &redis.Zset{
+			Member: data,
+			Score:  float64(time.Now().Unix()),
+		})
+		fmt.Println(strconv.Itoa(len(zSet)))
+		err = utils.Rc.ZAdd(key, zSet...)
+		if err != nil {
+			return fmt.Errorf("保存聊天记录到 Redis 失败: %w", err)
+		}
+		return nil
+	}
+	return fmt.Errorf("获取锁失败,请稍后重试")
+}
+
+// GetChatRecordsFromRedis 从 Redis 获取聊天记录
+func GetChatRecordsFromRedis(chatId int) (redisList []*llm.UserChatRecordRedis, err error) {
+	key := fmt.Sprintf("%s%d", redisChatPrefix, chatId)
+	zSet, _ := utils.Rc.ZRangeWithScores(key)
+	if len(zSet) == 0 {
+		// 缓存不存在,从数据库拉取数据
+		records, dbErr := GetChatRecordsFromDB(chatId)
+		if dbErr != nil {
+			err = fmt.Errorf("从数据库获取聊天记录失败: %w", dbErr)
+			return
+		}
+		// 将数据保存到 Redis
+		for _, record := range records {
+			redisRecord := &llm.UserChatRecordRedis{
+				Id:           record.Id,
+				ChatId:       chatId,
+				ChatUserType: record.ChatUserType,
+				Content:      record.Content,
+				SendTime:     record.SendTime.Format(utils.FormatDateTime),
+			}
+			redisList = append(redisList, redisRecord)
+		}
+		return
+	}
+	for _, z := range zSet {
+		var redisRecord llm.UserChatRecordRedis
+		if err = json.Unmarshal([]byte(z.Member.(string)), &redisRecord); err != nil {
+			return nil, fmt.Errorf("解析聊天记录失败: %w", err)
+		}
+		redisList = append(redisList, &redisRecord)
+	}
+	return
+}
+
+// SaveChatRecordsToDB 将 Redis 中的聊天记录保存到数据库
+func SaveChatRecordsToDB(chatId int) error {
+	list, err := GetChatRecordsFromRedis(chatId)
+	if err != nil {
+		return err
+	}
+	var newRecords []*llm.UserChatRecord
+	for _, record := range list {
+		//if record.Id == 0 {
+		sendTime, parseErr := time.Parse(utils.FormatDateTime, record.SendTime)
+		if parseErr != nil {
+			sendTime = time.Now()
+		}
+		newRecords = append(newRecords, &llm.UserChatRecord{
+			ChatId:       record.ChatId,
+			ChatUserType: record.ChatUserType,
+			Content:      record.Content,
+			SendTime:     sendTime,
+			CreatedTime:  time.Now(),
+		})
+		//}
+	}
+	//先删除redis中的缓存
+	//_ = RemoveChatRecord(chatId)
+	err = llm.BatchInsertRecords(newRecords)
+	if err != nil {
+		utils.FileLog.Error("批量插入记录失败:", err.Error())
+		return fmt.Errorf("批量插入记录失败: %w", err)
+	}
+	//_ = RemoveChatRecord(chatId)
+	return nil
+}
+
+// SaveAllChatRecordsToDB 定时任务保存所有 Redis 中的聊天记录到数据库
+func SaveAllChatRecordsToDB() {
+	for {
+		fmt.Println("开始保存聊天记录到数据库...")
+		keys, err := utils.Rc.Keys(redisChatPrefix + "*")
+		if err != nil {
+			utils.FileLog.Error("获取 Redis 键失败: %v", err)
+			return
+		}
+		for _, key := range keys {
+			lockKey := fmt.Sprintf("%s%s", RecordLock, key)
+			fmt.Printf("1%s", lockKey)
+			chatIdStr := strings.TrimPrefix(key, redisChatPrefix)
+			chatId, parseErr := strconv.Atoi(chatIdStr)
+			if parseErr != nil {
+				utils.FileLog.Error("解析聊天ID失败: %v", err)
+				continue
+			}
+			if lock.AcquireLock(lockKey, 10, "system_task") {
+				if err = SaveChatRecordsToDB(chatId); err != nil {
+					utils.FileLog.Error("解析聊天ID失败: %v", err)
+				}
+				lock.ReleaseLock(key, "system_task")
+			}
+		}
+		time.Sleep(10 * time.Second)
+	}
+}
+
+// RemoveChatRecord 从 Redis 删除聊天记录
+func RemoveChatRecord(chatId int) error {
+	key := fmt.Sprintf("%s%d", redisChatPrefix, chatId)
+	err := utils.Rc.Delete(key)
+	if err != nil {
+		return fmt.Errorf("删除 Redis 缓存失败: %w", err)
+	}
+	return nil
+}
+
+func GetChatRecordsFromDB(chatID int) ([]*llm.UserChatRecord, error) {
+	o := global.DbMap[utils.DbNameAI]
+	var records []*llm.UserChatRecord
+	if err := o.Where("chat_id = ?", chatID).Find(&records).Error; err != nil {
+		return nil, fmt.Errorf("从数据库获取聊天记录失败: %w", err)
+	}
+	return records, nil
+}

+ 2 - 0
services/task.go

@@ -5,6 +5,7 @@ import (
 	"eta/eta_api/services/binlog"
 	"eta/eta_api/services/data"
 	edbmonitor "eta/eta_api/services/edb_monitor"
+	"eta/eta_api/services/llm"
 	"eta/eta_api/utils"
 	"fmt"
 	"strings"
@@ -61,6 +62,7 @@ func Task() {
 		go binlog.HandleDataSourceChange2Es()
 	}
 	go StartSessionManager()
+	go llm.SaveAllChatRecordsToDB()
 	// TODO:数据修复
 	//FixNewEs()
 	fmt.Println("task end")

+ 56 - 0
utils/lock/distrubtLock.go

@@ -0,0 +1,56 @@
+package lock
+
+import (
+	"context"
+	"eta/eta_api/utils"
+	"fmt"
+	"github.com/go-redis/redis/v8"
+)
+
+const (
+	lockName = "lock:"
+)
+
+var (
+	ctx = context.Background()
+)
+
+func AcquireLock(key string, expiration int, Holder string) bool {
+	script := redis.NewScript(`local key = KEYS[1]
+			local clientId = ARGV[1]
+			local expiration = tonumber(ARGV[2])
+			if redis.call("EXISTS", key) == 0 then
+				redis.call("SET", key, clientId, "EX", expiration)
+				return 1
+			else
+				return 0
+			end`)
+	lockey := fmt.Sprintf("%s%s", lockName, key)
+	result, err := script.Run(ctx, utils.Rc.RedisClient(), []string{lockey}, Holder, expiration).Int()
+	if err != nil {
+		return false
+	}
+	if result == 1 {
+		return true
+	}
+	return false
+}
+
+func ReleaseLock(key string, holder string) bool {
+	script := redis.NewScript(`
+	   if redis.call("get", KEYS[1]) == ARGV[1] then
+	       return redis.call("del", KEYS[1])
+	   else
+	       return 0
+	   end
+	`)
+	lockey := fmt.Sprintf("%s%s", lockName, key)
+	result, err := script.Run(ctx, utils.Rc.RedisClient(), []string{lockey}, holder).Int()
+	if err != nil {
+		return false
+	}
+	if result == 1 {
+		return true
+	}
+	return false
+}

+ 6 - 1
utils/redis.go

@@ -2,6 +2,7 @@ package utils
 
 import (
 	"eta/eta_api/utils/redis"
+	client "github.com/go-redis/redis/v8"
 	"time"
 )
 
@@ -24,6 +25,11 @@ type RedisClient interface {
 	SAdd(key string, args ...interface{}) (err error)
 	SRem(key string, args ...interface{}) (err error)
 	SIsMember(key string, args interface{}) (bool, error)
+	ZAdd(key string, members ...*redis.Zset) error
+	ZRangeWithScores(key string) ([]*redis.Zset, error)
+	Expire(key string, duration time.Duration) error
+	RedisClient() client.UniversalClient
+	Keys(pattern string) ([]string, error)
 }
 
 func initRedis(redisType string, conf string) (redisClient RedisClient, err error) {
@@ -33,7 +39,6 @@ func initRedis(redisType string, conf string) (redisClient RedisClient, err erro
 	default: // 默认走单机
 		redisClient, err = redis.InitStandaloneRedis(conf)
 	}
-
 	return
 }
 

+ 37 - 0
utils/redis/cluster_redis.go

@@ -16,6 +16,10 @@ import (
 type ClusterRedisClient struct {
 	redisClient *redis.ClusterClient
 }
+type Zset struct {
+	Score  float64
+	Member interface{}
+}
 
 var DefaultKey = "zcmRedis"
 
@@ -317,3 +321,36 @@ func (rc *ClusterRedisClient) SIsMember(key string, args interface{}) (isMember
 	isMember, err = rc.redisClient.SIsMember(context.TODO(), key, args).Result()
 	return
 }
+func (rc *ClusterRedisClient) ZAdd(key string, members ...*Zset) error {
+	var redisMembers []*redis.Z
+	for _, member := range members {
+		redisMembers = append(redisMembers, &redis.Z{
+			Member: member.Member,
+			Score:  member.Score,
+		})
+	}
+	return rc.redisClient.ZAdd(context.TODO(), key, redisMembers...).Err()
+}
+func (rc *ClusterRedisClient) ZRangeWithScores(key string) (result []*Zset, err error) {
+	redisZList, err := rc.redisClient.ZRangeWithScores(context.TODO(), key, 0, -1).Result()
+	if err != nil {
+		return
+	}
+	for _, redisZ := range redisZList {
+		result = append(result, &Zset{
+			Member: redisZ.Member,
+			Score:  redisZ.Score,
+		})
+	}
+	return
+}
+func (rc *ClusterRedisClient) Expire(key string, duration time.Duration) error {
+	return rc.redisClient.Expire(context.Background(), key, duration).Err()
+}
+
+func (rc *ClusterRedisClient) Keys(pattern string) (keys []string, err error) {
+	return rc.redisClient.Keys(context.Background(), pattern).Result()
+}
+func (rc *ClusterRedisClient) RedisClient() redis.UniversalClient {
+	return rc.redisClient
+}

+ 34 - 0
utils/redis/standalone_redis.go

@@ -309,3 +309,37 @@ func (rc *StandaloneRedisClient) SIsMember(key string, args interface{}) (isMemb
 	isMember, err = rc.redisClient.SIsMember(context.TODO(), key, args).Result()
 	return
 }
+
+func (rc *StandaloneRedisClient) ZAdd(key string, members ...*Zset) error {
+	var redisMembers []*redis.Z
+	for _, member := range members {
+		redisMembers = append(redisMembers, &redis.Z{
+			Member: member.Member,
+			Score:  member.Score,
+		})
+	}
+	return rc.redisClient.ZAdd(context.TODO(), key, redisMembers...).Err()
+}
+
+func (rc *StandaloneRedisClient) ZRangeWithScores(key string) (result []*Zset, err error) {
+	redisZList, err := rc.redisClient.ZRangeWithScores(context.TODO(), key, 0, -1).Result()
+	if err != nil {
+		return
+	}
+	for _, redisZ := range redisZList {
+		result = append(result, &Zset{
+			Member: redisZ.Member,
+			Score:  redisZ.Score,
+		})
+	}
+	return
+}
+func (rc *StandaloneRedisClient) Expire(key string, duration time.Duration) error {
+	return rc.redisClient.Expire(context.Background(), key, duration).Err()
+}
+func (rc *StandaloneRedisClient) Keys(pattern string) (keys []string, err error) {
+	return rc.redisClient.Keys(context.Background(), pattern).Result()
+}
+func (rc *StandaloneRedisClient) RedisClient() redis.UniversalClient {
+	return rc.redisClient
+}

+ 1 - 0
utils/ws/session.go

@@ -22,6 +22,7 @@ type Session struct {
 	mu          sync.RWMutex
 	sessionOnce sync.Once
 }
+
 type Message struct {
 	KbName     string   `json:"KbName"`
 	Query      string   `json:"Query"`

+ 2 - 0
utils/ws/session_manager.go

@@ -67,6 +67,8 @@ func (manager *ConnectionManager) HandleMessage(userID int, sessionID string, me
 	}
 	// 处理业务逻辑
 	session.History = append(session.History, userMessage.LastTopics...)
+
+	//TODO
 	resp, err := llmService.KnowledgeBaseChat(userMessage.Query, userMessage.KbName, session.History)
 	defer func() {
 		_ = resp.Body.Close()