chat_service.go 5.9 KB


  1. package llm
  2. import (
  3. "encoding/json"
  4. "eta/eta_api/global"
  5. "eta/eta_api/models/llm"
  6. "eta/eta_api/utils"
  7. "eta/eta_api/utils/lock"
  8. "eta/eta_api/utils/redis"
  9. "fmt"
  10. "github.com/google/uuid"
  11. "strconv"
  12. "strings"
  13. "sync"
  14. "time"
  15. )
  // redisChatPrefix is the prefix of the Redis keys holding cached chat
  // records; the chat ID is appended to it to form the full key.
  const redisChatPrefix = "chat:zet:"

  // redisTTL is the Redis cache expiration time.
  const redisTTL = 24 * time.Hour
  20. // AddChatRecord 添加聊天记录到 Redis
  21. func AddChatRecord(record *llm.UserChatRecordRedis) error {
  22. key := fmt.Sprintf("%s%d", redisChatPrefix, record.ChatId)
  23. holder, _ := uuid.NewRandom()
  24. holderStr := fmt.Sprintf("user_%s", holder.String())
  25. if err := lock.TryLock(key, 10, holderStr, 10*time.Second); err == nil {
  26. defer func() {
  27. lock.ReleaseLock(key, holderStr)
  28. }()
  29. data, parseErr := json.Marshal(record)
  30. if parseErr != nil {
  31. return fmt.Errorf("序列化聊天记录失败: %w", parseErr)
  32. }
  33. zSet, _ := utils.Rc.ZRangeWithScores(key)
  34. if len(zSet) == 0 {
  35. // 设置过期时间
  36. _ = utils.Rc.Expire(key, 24*time.Hour)
  37. }
  38. zSet = append(zSet, &redis.Zset{
  39. Member: data,
  40. Score: float64(time.Now().Unix()),
  41. })
  42. err = utils.Rc.ZAdd(key, zSet...)
  43. if err != nil {
  44. return fmt.Errorf("保存聊天记录到 Redis 失败: %w", err)
  45. }
  46. return nil
  47. }
  48. return fmt.Errorf("获取锁失败,请稍后重试")
  49. }
  50. // GetChatRecordsFromRedis 从 Redis 获取聊天记录
  51. func GetChatRecordsFromRedis(chatId int) (redisList []*llm.UserChatRecordRedis, err error) {
  52. key := fmt.Sprintf("%s%d", redisChatPrefix, chatId)
  53. zSet, _ := utils.Rc.ZRangeWithScores(key)
  54. if len(zSet) == 0 {
  55. // 缓存不存在,从数据库拉取数据
  56. records, dbErr := GetChatRecordsFromDB(chatId)
  57. if dbErr != nil {
  58. err = fmt.Errorf("从数据库获取聊天记录失败: %w", dbErr)
  59. return
  60. }
  61. // 将数据保存到 Redis
  62. for _, record := range records {
  63. redisRecord := &llm.UserChatRecordRedis{
  64. Id: record.Id,
  65. ChatId: chatId,
  66. ChatUserType: record.ChatUserType,
  67. Content: record.Content,
  68. SendTime: record.SendTime.Format(utils.FormatDateTime),
  69. }
  70. redisList = append(redisList, redisRecord)
  71. }
  72. return
  73. }
  74. for _, z := range zSet {
  75. var redisRecord llm.UserChatRecordRedis
  76. if err = json.Unmarshal([]byte(z.Member.(string)), &redisRecord); err != nil {
  77. return nil, fmt.Errorf("解析聊天记录失败: %w", err)
  78. }
  79. redisList = append(redisList, &redisRecord)
  80. }
  81. return
  82. }
  83. func flushRecordsToRedis(chatId int) (err error) {
  84. key := fmt.Sprintf("%s%d", redisChatPrefix, chatId)
  85. zSet, _ := utils.Rc.ZRangeWithScores(key)
  86. if len(zSet) == 0 {
  87. // 缓存不存在,从数据库拉取数据
  88. records, dbErr := GetChatRecordsFromDB(chatId)
  89. if dbErr != nil {
  90. err = fmt.Errorf("从数据库获取聊天记录失败: %w", dbErr)
  91. return
  92. }
  93. var zet []*redis.Zset
  94. // 将数据保存到 Redis
  95. for _, record := range records {
  96. redisRecord := &llm.UserChatRecordRedis{
  97. Id: record.Id,
  98. ChatId: chatId,
  99. ChatUserType: record.ChatUserType,
  100. Content: record.Content,
  101. SendTime: record.SendTime.Format(utils.FormatDateTime),
  102. }
  103. data, parseErr := json.Marshal(&redisRecord)
  104. if parseErr != nil {
  105. utils.FileLog.Error("解析聊天记录失败: %w", err)
  106. }
  107. zet = append(zet, &redis.Zset{
  108. Member: data,
  109. Score: float64(record.SendTime.Unix()),
  110. })
  111. }
  112. _ = utils.Rc.ZAdd(key, zet...)
  113. }
  114. return
  115. }
  116. // SaveChatRecordsToDB 将 Redis 中的聊天记录保存到数据库
  117. func SaveChatRecordsToDB(chatId int) error {
  118. list, err := GetChatRecordsFromRedis(chatId)
  119. if err != nil {
  120. return err
  121. }
  122. var newRecords []*llm.UserChatRecord
  123. for _, record := range list {
  124. if record.Id == 0 {
  125. sendTime, parseErr := time.ParseInLocation(utils.FormatDateTime, record.SendTime, time.Local)
  126. if parseErr != nil {
  127. sendTime = time.Now()
  128. }
  129. newRecords = append(newRecords, &llm.UserChatRecord{
  130. Id: record.Id,
  131. ChatId: record.ChatId,
  132. ChatUserType: record.ChatUserType,
  133. Content: record.Content,
  134. SendTime: sendTime,
  135. CreatedTime: time.Now(),
  136. })
  137. }
  138. }
  139. key := fmt.Sprintf("%s%d", redisChatPrefix, chatId)
  140. holder, _ := uuid.NewRandom()
  141. holderStr := fmt.Sprintf("sys_%s", holder.String())
  142. defer func() {
  143. lock.ReleaseLock(key, holderStr)
  144. }()
  145. if lock.AcquireLock(key, 10, holderStr) {
  146. //先删除redis中的缓存
  147. _ = RemoveChatRecord(chatId)
  148. err = llm.BatchInsertRecords(newRecords)
  149. if err != nil {
  150. utils.FileLog.Error("批量插入记录失败:", err.Error())
  151. return fmt.Errorf("批量插入记录失败: %w", err)
  152. }
  153. _ = RemoveChatRecord(chatId)
  154. //重新加载数据
  155. _ = flushRecordsToRedis(chatId)
  156. }
  157. return nil
  158. }
  159. // SaveAllChatRecordsToDB 定时任务保存所有 Redis 中的聊天记录到数据库
  160. func SaveAllChatRecordsToDB() {
  161. for {
  162. keys, err := utils.Rc.Keys(redisChatPrefix + "*")
  163. if err != nil {
  164. utils.FileLog.Error("获取 Redis 键失败: %v", err)
  165. return
  166. }
  167. var wg sync.WaitGroup
  168. wg.Add(len(keys))
  169. for _, key := range keys {
  170. go func(key string) {
  171. defer wg.Done()
  172. chatIdStr := strings.TrimPrefix(key, redisChatPrefix)
  173. chatId, parseErr := strconv.Atoi(chatIdStr)
  174. if parseErr != nil {
  175. utils.FileLog.Error("解析聊天ID失败: %v", err)
  176. return
  177. }
  178. if err = SaveChatRecordsToDB(chatId); err != nil {
  179. utils.FileLog.Error("解析聊天ID失败: %v", err)
  180. }
  181. }(key)
  182. }
  183. wg.Wait()
  184. time.Sleep(10 * time.Minute)
  185. }
  186. }
  187. // RemoveChatRecord 从 Redis 删除聊天记录
  188. func RemoveChatRecord(chatId int) error {
  189. key := fmt.Sprintf("%s%d", redisChatPrefix, chatId)
  190. err := utils.Rc.Delete(key)
  191. if err != nil {
  192. return fmt.Errorf("删除 Redis 缓存失败: %w", err)
  193. }
  194. return nil
  195. }
  196. func GetChatRecordsFromDB(chatId int) ([]*llm.UserChatRecord, error) {
  197. o := global.DbMap[utils.DbNameAI]
  198. var records []*llm.UserChatRecord
  199. if err := o.Where("chat_id = ?", chatId).Find(&records).Error; err != nil {
  200. return nil, fmt.Errorf("从数据库获取聊天记录失败: %w", err)
  201. }
  202. return records, nil
  203. }