chat_service.go 5.9 KB


  1. package llm
  2. import (
  3. "encoding/json"
  4. "eta/eta_api/global"
  5. "eta/eta_api/models/llm"
  6. "eta/eta_api/utils"
  7. "eta/eta_api/utils/lock"
  8. "eta/eta_api/utils/redis"
  9. "fmt"
  10. "github.com/google/uuid"
  11. "strconv"
  12. "strings"
  13. "sync"
  14. "time"
  15. )
  // Key layout and lifetime of the per-chat Redis record caches.
const (
	// redisTTL is the Redis cache expiration time for a chat's records.
	redisTTL = time.Hour * 24
	// redisChatPrefix is prepended to a chat id to build its sorted-set key, e.g. "chat:zet:42".
	redisChatPrefix = "chat:zet:"
)
  20. // AddChatRecord 添加聊天记录到 Redis
  21. func AddChatRecord(record *llm.UserChatRecordRedis) error {
  22. key := fmt.Sprintf("%s%d", redisChatPrefix, record.ChatId)
  23. holder, _ := uuid.NewRandom()
  24. holderStr := fmt.Sprintf("user_%s", holder.String())
  25. if lock.AcquireLock(key, 10, holderStr) {
  26. defer func() {
  27. fmt.Printf("用户释放锁:%s", key)
  28. lock.ReleaseLock(key, holderStr)
  29. }()
  30. data, err := json.Marshal(record)
  31. if err != nil {
  32. return fmt.Errorf("序列化聊天记录失败: %w", err)
  33. }
  34. zSet, _ := utils.Rc.ZRangeWithScores(key)
  35. if len(zSet) == 0 {
  36. // 设置过期时间
  37. _ = utils.Rc.Expire(key, 24*time.Hour)
  38. }
  39. zSet = append(zSet, &redis.Zset{
  40. Member: data,
  41. Score: float64(time.Now().Unix()),
  42. })
  43. err = utils.Rc.ZAdd(key, zSet...)
  44. if err != nil {
  45. return fmt.Errorf("保存聊天记录到 Redis 失败: %w", err)
  46. }
  47. return nil
  48. }
  49. return fmt.Errorf("获取锁失败,请稍后重试")
  50. }
  51. // GetChatRecordsFromRedis 从 Redis 获取聊天记录
  52. func GetChatRecordsFromRedis(chatId int) (redisList []*llm.UserChatRecordRedis, err error) {
  53. key := fmt.Sprintf("%s%d", redisChatPrefix, chatId)
  54. zSet, _ := utils.Rc.ZRangeWithScores(key)
  55. if len(zSet) == 0 {
  56. // 缓存不存在,从数据库拉取数据
  57. records, dbErr := GetChatRecordsFromDB(chatId)
  58. if dbErr != nil {
  59. err = fmt.Errorf("从数据库获取聊天记录失败: %w", dbErr)
  60. return
  61. }
  62. // 将数据保存到 Redis
  63. for _, record := range records {
  64. redisRecord := &llm.UserChatRecordRedis{
  65. Id: record.Id,
  66. ChatId: chatId,
  67. ChatUserType: record.ChatUserType,
  68. Content: record.Content,
  69. SendTime: record.SendTime.Format(utils.FormatDateTime),
  70. }
  71. redisList = append(redisList, redisRecord)
  72. }
  73. return
  74. }
  75. for _, z := range zSet {
  76. var redisRecord llm.UserChatRecordRedis
  77. if err = json.Unmarshal([]byte(z.Member.(string)), &redisRecord); err != nil {
  78. return nil, fmt.Errorf("解析聊天记录失败: %w", err)
  79. }
  80. redisList = append(redisList, &redisRecord)
  81. }
  82. return
  83. }
  84. func flushRecordsToRedis(chatId int) (err error) {
  85. key := fmt.Sprintf("%s%d", redisChatPrefix, chatId)
  86. zSet, _ := utils.Rc.ZRangeWithScores(key)
  87. if len(zSet) == 0 {
  88. // 缓存不存在,从数据库拉取数据
  89. records, dbErr := GetChatRecordsFromDB(chatId)
  90. if dbErr != nil {
  91. err = fmt.Errorf("从数据库获取聊天记录失败: %w", dbErr)
  92. return
  93. }
  94. var zet []*redis.Zset
  95. // 将数据保存到 Redis
  96. for _, record := range records {
  97. redisRecord := &llm.UserChatRecordRedis{
  98. Id: record.Id,
  99. ChatId: chatId,
  100. ChatUserType: record.ChatUserType,
  101. Content: record.Content,
  102. SendTime: record.SendTime.Format(utils.FormatDateTime),
  103. }
  104. data, parseErr := json.Marshal(&redisRecord)
  105. if parseErr != nil {
  106. utils.FileLog.Error("解析聊天记录失败: %w", err)
  107. }
  108. zet = append(zet, &redis.Zset{
  109. Member: data,
  110. Score: float64(record.SendTime.Unix()),
  111. })
  112. }
  113. _ = utils.Rc.ZAdd(key, zet...)
  114. }
  115. return
  116. }
  117. // SaveChatRecordsToDB 将 Redis 中的聊天记录保存到数据库
  118. func SaveChatRecordsToDB(chatId int) error {
  119. list, err := GetChatRecordsFromRedis(chatId)
  120. if err != nil {
  121. return err
  122. }
  123. var newRecords []*llm.UserChatRecord
  124. for _, record := range list {
  125. if record.Id == 0 {
  126. sendTime, parseErr := time.Parse(utils.FormatDateTime, record.SendTime)
  127. if parseErr != nil {
  128. sendTime = time.Now()
  129. }
  130. newRecords = append(newRecords, &llm.UserChatRecord{
  131. Id: record.Id,
  132. ChatId: record.ChatId,
  133. ChatUserType: record.ChatUserType,
  134. Content: record.Content,
  135. SendTime: sendTime,
  136. CreatedTime: time.Now(),
  137. })
  138. }
  139. }
  140. key := fmt.Sprintf("%s%d", redisChatPrefix, chatId)
  141. holder, _ := uuid.NewRandom()
  142. holderStr := fmt.Sprintf("sys_%s", holder.String())
  143. defer func() {
  144. lock.ReleaseLock(key, holderStr)
  145. }()
  146. if lock.AcquireLock(key, 10, holderStr) {
  147. //先删除redis中的缓存
  148. _ = RemoveChatRecord(chatId)
  149. err = llm.BatchInsertRecords(newRecords)
  150. if err != nil {
  151. utils.FileLog.Error("批量插入记录失败:", err.Error())
  152. return fmt.Errorf("批量插入记录失败: %w", err)
  153. }
  154. _ = RemoveChatRecord(chatId)
  155. //重新加载数据
  156. _ = flushRecordsToRedis(chatId)
  157. }
  158. return nil
  159. }
  160. // SaveAllChatRecordsToDB 定时任务保存所有 Redis 中的聊天记录到数据库
  161. func SaveAllChatRecordsToDB() {
  162. for {
  163. keys, err := utils.Rc.Keys(redisChatPrefix + "*")
  164. if err != nil {
  165. utils.FileLog.Error("获取 Redis 键失败: %v", err)
  166. return
  167. }
  168. var wg sync.WaitGroup
  169. wg.Add(len(keys))
  170. for _, key := range keys {
  171. go func(key string) {
  172. defer wg.Done()
  173. chatIdStr := strings.TrimPrefix(key, redisChatPrefix)
  174. chatId, parseErr := strconv.Atoi(chatIdStr)
  175. if parseErr != nil {
  176. utils.FileLog.Error("解析聊天ID失败: %v", err)
  177. return
  178. }
  179. if err = SaveChatRecordsToDB(chatId); err != nil {
  180. utils.FileLog.Error("解析聊天ID失败: %v", err)
  181. }
  182. }(key)
  183. }
  184. wg.Wait()
  185. time.Sleep(10 * time.Second)
  186. }
  187. }
  188. // RemoveChatRecord 从 Redis 删除聊天记录
  189. func RemoveChatRecord(chatId int) error {
  190. key := fmt.Sprintf("%s%d", redisChatPrefix, chatId)
  191. err := utils.Rc.Delete(key)
  192. if err != nil {
  193. return fmt.Errorf("删除 Redis 缓存失败: %w", err)
  194. }
  195. return nil
  196. }
  197. func GetChatRecordsFromDB(chatId int) ([]*llm.UserChatRecord, error) {
  198. o := global.DbMap[utils.DbNameAI]
  199. var records []*llm.UserChatRecord
  200. if err := o.Where("chat_id = ?", chatId).Find(&records).Error; err != nil {
  201. return nil, fmt.Errorf("从数据库获取聊天记录失败: %w", err)
  202. }
  203. return records, nil
  204. }