kobe6258 1 сар өмнө
parent
commit
69a2b6be18

+ 9 - 1
controllers/rag/llm_http/request.go

@@ -7,6 +7,14 @@ type LLMQuestionReq struct {
 }
 
 type UserChatReq struct {
-	ChatId   int    `json:"ChatId"`
+	ChatId    int    `json:"ChatId"`
 	ChatTitle string `json:"ChatTitle" description:"会话名称"`
 }
+
+type UserChatRecordReq struct {
+	Id           int    `json:"Id"`
+	ChatId       int    `json:"ChatId"`
+	Content      string `json:"Content" description:"会话名称"`
+	ChatUserType string `json:"ChatUserType" description:"用户类型"`
+	SendTime     string `json:"SendTime" description:"发送时间"`
+}

+ 5 - 3
controllers/rag/llm_http/response.go

@@ -1,7 +1,9 @@
 package llm_http
 
+import "eta/eta_api/models/llm"
 
-type LLMQuestionRes struct {
-	Answer      string `description:"回答"`
-	SessionId     string `description:"会话ID"`
+type UserChatListResp struct {
+	TodayList     []llm.UserLlmChatListViewItem
+	YesterdayList []llm.UserLlmChatListViewItem
+	WeekList      []llm.UserLlmChatListViewItem
 }

+ 130 - 6
controllers/rag/user_chat_controller.go

@@ -6,6 +6,7 @@ import (
 	"eta/eta_api/controllers/rag/llm_http"
 	"eta/eta_api/models"
 	"eta/eta_api/models/llm"
+	llmService "eta/eta_api/services/llm"
 	"eta/eta_api/utils"
 	"time"
 )
@@ -126,20 +127,143 @@ func (ucCtrl *UserChatController) GetUserChatList() {
 		return
 	}
 	//周日是0,周六是6
-	toDay := time.Now().Weekday()
-	offset := int(time.Monday - toDay)
+	weekDay := time.Now().Weekday()
+	offset := int(time.Monday - weekDay)
 	if offset > 0 {
 		offset -= 7
 	}
+	today := time.Now().Format(utils.FormatDate)
 	monDay := time.Now().AddDate(0, 0, offset).Format(utils.FormatDate)
+	yesterday := time.Now().AddDate(0, 0, -1).Format(utils.FormatDate)
 	chatList, err := llm.GetUserChatList(sysUser.AdminId, monDay, time.Now().Format(utils.FormatDate))
 	if err != nil {
-		br.Msg = "重命名失败"
-		br.ErrMsg = "重命名失败,Err:" + err.Error()
+		br.Msg = "获取用户聊天列表失败"
+		br.ErrMsg = "获取用户聊天列表失败,Err:" + err.Error()
 		return
 	}
-	br.Data = chatList
+	data := new(llm_http.UserChatListResp)
+	data.WeekList = make([]llm.UserLlmChatListViewItem, 0)
+	data.YesterdayList = make([]llm.UserLlmChatListViewItem, 0)
+	data.TodayList = make([]llm.UserLlmChatListViewItem, 0)
+	for _, v := range chatList {
+		if v.CreatedTime.Format(utils.FormatDate) == today {
+			data.TodayList = append(data.TodayList, llm.CovertItemToView(v))
+		} else if v.CreatedTime.Format(utils.FormatDate) == yesterday {
+			data.YesterdayList = append(data.YesterdayList, llm.CovertItemToView(v))
+		} else {
+			data.WeekList = append(data.WeekList, llm.CovertItemToView(v))
+		}
+	}
+
+	br.Data = data
 	br.Ret = 200
 	br.Success = true
-	br.Msg = "重命名成功"
+	br.Msg = "获取用户聊天列表成功"
+}
+
+// ChatRecordAdd @Title 保存聊天记录
+// @Description 保存聊天记录
+// @Success 101 {object} response.ListResp
+// @router /chat/chat_record_save [post]
+func (ucCtrl *UserChatController) ChatRecordAdd() {
+	br := new(models.BaseResponse).Init()
+	defer func() {
+		ucCtrl.Data["json"] = br
+		ucCtrl.ServeJSON()
+	}()
+	var req llm_http.UserChatRecordReq
+	err := json.Unmarshal(ucCtrl.Ctx.Input.RequestBody, &req)
+	if err != nil {
+		br.Msg = "参数解析异常!"
+		br.ErrMsg = "参数解析失败,Err:" + err.Error()
+		return
+	}
+	sysUser := ucCtrl.SysUser
+	if sysUser == nil {
+		br.Msg = "请登录"
+		br.ErrMsg = "请登录,SysUser Is Empty"
+		br.Ret = 408
+		return
+	}
+	if req.ChatId <= 0 {
+		br.Msg = "非法的对话框Id"
+		br.ErrMsg = "非法的对话框Id"
+		return
+	}
+	if req.Content == "" {
+		br.Msg = "聊天记录不能为空"
+		br.ErrMsg = "聊天记录不能为空"
+		return
+	}
+	if req.Id < 0 {
+		br.Msg = "非法的Id"
+		br.ErrMsg = "非法的Id"
+		return
+	}
+	if req.ChatUserType != "user" && req.ChatUserType != "assistant" {
+		br.Msg = "非法的用户类型"
+		br.ErrMsg = "非法的用户类型,用户类型支持:user/assistant"
+		return
+	}
+	if req.SendTime == "" {
+		req.SendTime = time.Now().Format(utils.FormatDateTime)
+	} else {
+		_, err = time.Parse(utils.FormatDateTime, req.SendTime)
+		if err != nil {
+			br.Msg = "非法的发送时间"
+			br.ErrMsg = "非法的发送时间,Err:" + err.Error()
+			return
+		}
+	}
+	record := llm.UserChatRecordRedis{
+		ChatId:       req.ChatId,
+		ChatUserType: req.ChatUserType,
+		Content:      req.Content,
+		SendTime:     req.SendTime,
+	}
+	err = llmService.AddChatRecord(&record)
+	if err != nil {
+		br.Msg = "添加聊天记录失败"
+		br.ErrMsg = "添加聊天记录失败,Err:" + err.Error()
+		return
+	}
+	br.Ret = 200
+	br.Success = true
+	br.Msg = "添加聊天记录成功"
+}
+
+// ChatRecordList @Title 获取聊天记录
+// @Description 获取聊天记录
+// @Success 101 {object} response.ListResp
+// @router /chat/chat_record_list [get]
+func (ucCtrl *UserChatController) ChatRecordList() {
+	br := new(models.BaseResponse).Init()
+	defer func() {
+		ucCtrl.Data["json"] = br
+		ucCtrl.ServeJSON()
+	}()
+	sysUser := ucCtrl.SysUser
+	if sysUser == nil {
+		br.Msg = "请登录"
+		br.ErrMsg = "请登录,SysUser Is Empty"
+		br.Ret = 408
+		return
+	}
+	chatId, _ := ucCtrl.GetInt("ChatId", 0)
+	if chatId <= 0 {
+		br.Msg = "非法的对话Id"
+		br.ErrMsg = "非法的对话Id"
+		return
+	}
+
+	list, err := llmService.GetChatRecordsFromRedis(chatId)
+	if err != nil {
+		br.Msg = "获取聊天记录失败"
+		br.ErrMsg = "获取聊天记录失败,Err:" + err.Error()
+		return
+	}
+	br.Data = list
+	br.Ret = 200
+	br.Success = true
+	br.Msg = "获取聊天记录成功"
 }

+ 8 - 0
models/llm/user_chat_record.go

@@ -7,10 +7,18 @@ type UserChatRecord struct {
 	Id           int       `gorm:"primaryKey;autoIncrement;comment:主键"`
 	ChatId       int       `gorm:"comment:会话id"`
 	ChatUserType string    `gorm:"type:enum('user','assistant');comment:用户方"`
+	Content      string    `gorm:"content:内容"`
 	SendTime     time.Time `gorm:"comment:发送时间"`
 	CreatedTime  time.Time `gorm:"comment:创建时间"`
 	UpdateTime   time.Time `gorm:"autoUpdateTime;comment:更新时间"`
 }
+type UserChatRecordRedis struct {
+	Id           int
+	ChatId       int
+	ChatUserType string
+	Content      string
+	SendTime     string
+}
 
 func (u *UserChatRecord) TableName() string {
 	return "user_chat_record"

+ 29 - 4
models/llm/user_llm_chat.go

@@ -16,6 +16,31 @@ type UserLlmChat struct {
 	UpdateTime  time.Time `gorm:"autoUpdateTime;comment:更新时间"`
 }
 
+type UserLlmChatListItem struct {
+	Id          int       `gorm:"primaryKey;autoIncrement;comment:会话主键"`
+	UserId      int       `gorm:"comment:用户id"`
+	ChatTitle   string    `gorm:"comment:会话标题"`
+	CreatedTime time.Time `gorm:"comment:创建时间"`
+	RecordCount int       `gorm:"comment:会话记录数"`
+}
+type UserLlmChatListViewItem struct {
+	Id          int    `gorm:"primaryKey;autoIncrement;comment:会话主键"`
+	UserId      int    `gorm:"comment:用户id"`
+	ChatTitle   string `gorm:"comment:会话标题"`
+	CreatedTime string `gorm:"comment:创建时间"`
+	RecordCount int    `gorm:"comment:会话记录数"`
+}
+
+func CovertItemToView(item UserLlmChatListItem) UserLlmChatListViewItem {
+	return UserLlmChatListViewItem{
+		Id:          item.Id,
+		UserId:      item.UserId,
+		ChatTitle:   item.ChatTitle,
+		CreatedTime: item.CreatedTime.Format(utils.FormatDateTime),
+		RecordCount: item.RecordCount,
+	}
+
+}
 func (u *UserLlmChat) TableName() string {
 	return "user_llm_chat"
 }
@@ -39,10 +64,10 @@ func (u *UserLlmChat) RenameChatSession() (err error) {
 	return
 }
 
-func GetUserChatList(userId int, monDay, toDay string) (chatList []UserLlmChat, err error) {
+func GetUserChatList(userId int, monDay, toDay string) (chatList []UserLlmChatListItem, err error) {
 	o := global.DbMap[utils.DbNameAI]
-	sql := `select ulc.id ,ulc.chat_title,ulc.created_time,COUNT(ucr.id) AS record_count from user_llm_chat ulc left join user_chat_record ucr
-    ON ucr.chat_id = ulc.id where ulc.user_id=? and ? BETWEEN ? and ? GROUP BY ulc.id`
-	err = o.Raw(sql, userId, utils.GenerateQuerySql(utils.ToDate, &utils.QueryParam{Column: "ulc.created_time"}), monDay, toDay).Find(&chatList).Error
+	sql := `select ulc.id AS id ,ulc.user_id as user_id,ulc.chat_title as chat_title,ulc.created_time,COUNT(ucr.id) AS record_count from user_llm_chat ulc left join user_chat_record ucr
+    ON ucr.chat_id = ulc.id where ulc.user_id=? and ` + utils.GenerateQuerySql(utils.ToDate, &utils.QueryParam{Column: "ulc.created_time"}) + ` BETWEEN ? and ? GROUP BY ulc.id order by ulc.created_time desc`
+	err = o.Raw(sql, userId, monDay, toDay).Find(&chatList).Error
 	return
 }

+ 18 - 0
routers/commentsRouter.go

@@ -8575,6 +8575,24 @@ func init() {
             Filters: nil,
             Params: nil})
 
+    beego.GlobalControllerRouter["eta/eta_api/controllers/rag:UserChatController"] = append(beego.GlobalControllerRouter["eta/eta_api/controllers/rag:UserChatController"],
+        beego.ControllerComments{
+            Method: "ChatRecordList",
+            Router: `/chat/chat_record_list`,
+            AllowHTTPMethods: []string{"get"},
+            MethodParams: param.Make(),
+            Filters: nil,
+            Params: nil})
+
+    beego.GlobalControllerRouter["eta/eta_api/controllers/rag:UserChatController"] = append(beego.GlobalControllerRouter["eta/eta_api/controllers/rag:UserChatController"],
+        beego.ControllerComments{
+            Method: "ChatRecordAdd",
+            Router: `/chat/chat_record_save`,
+            AllowHTTPMethods: []string{"post"},
+            MethodParams: param.Make(),
+            Filters: nil,
+            Params: nil})
+
     beego.GlobalControllerRouter["eta/eta_api/controllers/rag:UserChatController"] = append(beego.GlobalControllerRouter["eta/eta_api/controllers/rag:UserChatController"],
         beego.ControllerComments{
             Method: "NewChat",

+ 147 - 0
services/llm/chat_service.go

@@ -0,0 +1,147 @@
+package llm
+
+import (
+	"encoding/json"
+	"eta/eta_api/global"
+	"eta/eta_api/models/llm"
+	"eta/eta_api/utils"
+	"eta/eta_api/utils/redis"
+	"fmt"
+	"time"
+)
+
+const (
+	redisChatPrefix = "chat:zet:"
+	redisTTL        = 24 * time.Hour // Redis 缓存过期时间
+)
+
+// AddChatRecord 添加聊天记录到 Redis
+func AddChatRecord(record *llm.UserChatRecordRedis) error {
+	key := fmt.Sprintf("%s%d", redisChatPrefix, record.ChatId)
+	data, err := json.Marshal(record)
+	if err != nil {
+		return fmt.Errorf("序列化聊天记录失败: %w", err)
+	}
+	zSet, _ := utils.Rc.ZRangeWithScores(key)
+	if len(zSet) == 0 {
+		// 设置过期时间
+		_ = utils.Rc.Expire(key, 24*time.Hour)
+	}
+	zSet = append(zSet, &redis.Zset{
+		Member: data,
+		Score:  float64(time.Now().Unix()),
+	})
+	err = utils.Rc.ZAdd(key, zSet...)
+	if err != nil {
+		return fmt.Errorf("保存聊天记录到 Redis 失败: %w", err)
+	}
+	return nil
+}
+
+// GetChatRecordsFromRedis 从 Redis 获取聊天记录
+func GetChatRecordsFromRedis(chatId int) (redisList []*llm.UserChatRecordRedis, err error) {
+	key := fmt.Sprintf("%s%d", redisChatPrefix, chatId)
+	zSet, _ := utils.Rc.ZRangeWithScores(key)
+	if len(zSet) == 0 {
+		// 缓存不存在,从数据库拉取数据
+		records, dbErr := GetChatRecordsFromDB(chatId)
+		if dbErr != nil {
+			err = fmt.Errorf("从数据库获取聊天记录失败: %w", dbErr)
+			return
+		}
+		// 将数据保存到 Redis
+		for _, record := range records {
+			redisRecord := &llm.UserChatRecordRedis{
+				ChatId:       chatId,
+				ChatUserType: record.ChatUserType,
+				Content:      record.Content,
+				SendTime:     record.SendTime.Format(utils.FormatDateTime),
+			}
+			redisList = append(redisList, redisRecord)
+			if err = AddChatRecord(redisRecord); err != nil {
+				err = fmt.Errorf("保存聊天记录到 Redis 失败: %w", err)
+				return
+			}
+		}
+		return
+	}
+	for _, z := range zSet {
+		var redisRecord llm.UserChatRecordRedis
+		if err = json.Unmarshal([]byte(z.Member.(string)), &redisRecord); err != nil {
+			return nil, fmt.Errorf("解析聊天记录失败: %w", err)
+		}
+		redisList = append(redisList, &redisRecord)
+	}
+	return
+}
+
+//
+//// SaveChatRecordsToDB 将 Redis 中的聊天记录保存到数据库
+//func SaveChatRecordsToDB(chatID int) error {
+//	key := fmt.Sprintf("%s%d", redisChatPrefix, chatID)
+//	val, err := global.Redis.Get(global.Context, key).Result()
+//	if err == redis.Nil {
+//		return nil // 缓存不存在,无需保存
+//	} else if err != nil {
+//		return fmt.Errorf("从 Redis 获取聊天记录失败: %w", err)
+//	}
+//
+//	var records []*llm.UserChatRecord
+//	if err := json.Unmarshal([]byte(val), &records); err != nil {
+//		return fmt.Errorf("解析聊天记录失败: %w", err)
+//	}
+//
+//	o := global.DbMap[utils.DbNameAI]
+//	for _, record := range records {
+//		if err := o.Create(record).Error; err != nil {
+//			return fmt.Errorf("保存聊天记录到数据库失败: %w", err)
+//		}
+//	}
+//
+//	// 删除 Redis 缓存
+//	if err := global.Redis.Del(global.Context, key).Err(); err != nil {
+//		return fmt.Errorf("删除 Redis 缓存失败: %w", err)
+//	}
+//
+//	return nil
+//}
+//
+//// SaveAllChatRecordsToDB 定时任务保存所有 Redis 中的聊天记录到数据库
+//func SaveAllChatRecordsToDB() {
+//	keys, err := global.Redis.Keys(global.Context, redisChatPrefix+"*").Result()
+//	if err != nil {
+//		log.Printf("获取 Redis 键失败: %v", err)
+//		return
+//	}
+//
+//	for _, key := range keys {
+//		chatIDStr := strings.TrimPrefix(key, redisChatPrefix)
+//		chatID, err := strconv.Atoi(chatIDStr)
+//		if err != nil {
+//			log.Printf("解析聊天ID失败: %v", err)
+//			continue
+//		}
+//
+//		if err := SaveChatRecordsToDB(chatID); err != nil {
+//			log.Printf("保存聊天记录到数据库失败: %v", err)
+//		}
+//	}
+//}
+//
+//// RemoveChatRecord 从 Redis 删除聊天记录
+//func RemoveChatRecord(chatID int) error {
+//	key := fmt.Sprintf("%s%d", redisChatPrefix, chatID)
+//	if err := global.Redis.Del(global.Context, key).Err(); err != nil {
+//		return fmt.Errorf("删除 Redis 缓存失败: %w", err)
+//	}
+//	return nil
+//}
+
+func GetChatRecordsFromDB(chatID int) ([]*llm.UserChatRecord, error) {
+	o := global.DbMap[utils.DbNameAI]
+	var records []*llm.UserChatRecord
+	if err := o.Where("chat_id = ?", chatID).Find(&records).Error; err != nil {
+		return nil, fmt.Errorf("从数据库获取聊天记录失败: %w", err)
+	}
+	return records, nil
+}

+ 3 - 1
utils/redis.go

@@ -24,6 +24,9 @@ type RedisClient interface {
 	SAdd(key string, args ...interface{}) (err error)
 	SRem(key string, args ...interface{}) (err error)
 	SIsMember(key string, args interface{}) (bool, error)
+	ZAdd(key string, members ...*redis.Zset) error
+	ZRangeWithScores(key string) ([]*redis.Zset, error)
+	Expire(key string, duration time.Duration) error
 }
 
 func initRedis(redisType string, conf string) (redisClient RedisClient, err error) {
@@ -33,7 +36,6 @@ func initRedis(redisType string, conf string) (redisClient RedisClient, err erro
 	default: // 默认走单机
 		redisClient, err = redis.InitStandaloneRedis(conf)
 	}
-
 	return
 }
 

+ 30 - 1
utils/redis/cluster_redis.go

@@ -16,7 +16,10 @@ import (
 type ClusterRedisClient struct {
 	redisClient *redis.ClusterClient
 }
-
+type Zset struct {
+	Score  float64
+	Member interface{}
+}
 var DefaultKey = "zcmRedis"
 
 // InitClusterRedis
@@ -317,3 +320,29 @@ func (rc *ClusterRedisClient) SIsMember(key string, args interface{}) (isMember
 	isMember, err = rc.redisClient.SIsMember(context.TODO(), key, args).Result()
 	return
 }
+func (rc *ClusterRedisClient) ZAdd(key string, members ...*Zset) error {
+	var redisMembers []*redis.Z
+	for _, member := range members {
+		redisMembers = append(redisMembers, &redis.Z{
+			Member: member.Member,
+			Score:  member.Score,
+		})
+	}
+	return rc.redisClient.ZAdd(context.TODO(), key, redisMembers...).Err()
+}
+func (rc *ClusterRedisClient) ZRangeWithScores(key string) (result []*Zset, err error) {
+	redisZList, err := rc.redisClient.ZRangeWithScores(context.TODO(), key, 0, -1).Result()
+	if err != nil {
+		return
+	}
+	for _, redisZ := range redisZList {
+		result = append(result, &Zset{
+			Member: redisZ.Member,
+			Score:  redisZ.Score,
+		})
+	}
+	return
+}
+func (rc *ClusterRedisClient) Expire(key string, duration time.Duration) error {
+	return rc.redisClient.Expire(context.Background(), key, duration).Err()
+}

+ 28 - 0
utils/redis/standalone_redis.go

@@ -309,3 +309,31 @@ func (rc *StandaloneRedisClient) SIsMember(key string, args interface{}) (isMemb
 	isMember, err = rc.redisClient.SIsMember(context.TODO(), key, args).Result()
 	return
 }
+
+func (rc *StandaloneRedisClient) ZAdd(key string, members ...*Zset) error {
+	var redisMembers []*redis.Z
+	for _, member := range members {
+		redisMembers = append(redisMembers, &redis.Z{
+			Member: member.Member,
+			Score:  member.Score,
+		})
+	}
+	return rc.redisClient.ZAdd(context.TODO(), key, redisMembers...).Err()
+}
+
+func (rc *StandaloneRedisClient) ZRangeWithScores(key string) (result []*Zset, err error) {
+	redisZList, err := rc.redisClient.ZRangeWithScores(context.TODO(), key, 0, -1).Result()
+	if err != nil {
+		return
+	}
+	for _, redisZ := range redisZList {
+		result = append(result, &Zset{
+			Member: redisZ.Member,
+			Score:  redisZ.Score,
+		})
+	}
+	return
+}
+func (rc *StandaloneRedisClient) Expire(key string, duration time.Duration) error {
+	return rc.redisClient.Expire(context.Background(), key, duration).Err()
+}