浏览代码

Merge branch 'feature/deepseek_rag_1.0' into debug

kobe6258 2 周之前
父节点
当前提交
aa18be0ef3

+ 3 - 45
controllers/rag/chat_controller.go → controllers/rag/chat_ws_controller.go

@@ -1,10 +1,8 @@
 package rag
 
 import (
-	"encoding/json"
 	"eta/eta_api/controllers"
 	"eta/eta_api/models"
-	"eta/eta_api/models/llm"
 	"eta/eta_api/models/system"
 	"eta/eta_api/services/llm/facade"
 	"eta/eta_api/utils"
@@ -17,11 +15,11 @@ import (
 	"time"
 )
 
-type ChatController struct {
+type ChatWsController struct {
 	controllers.BaseAuthController
 }
 
-func (cc *ChatController) Prepare() {
+func (cc *ChatWsController) Prepare() {
 	method := cc.Ctx.Input.Method()
 	uri := cc.Ctx.Input.URI()
 	if method == "GET" {
@@ -142,51 +140,11 @@ func (cc *ChatController) Prepare() {
 	}
 }
 
-// NewChat @Title 新建对话框
-// @Description 新建对话框
-// @Success 101 {object} response.ListResp
-// @router /chat/new_chat [post]
-func (kbctrl *KbController) NewChat() {
-	br := new(models.BaseResponse).Init()
-	defer func() {
-		kbctrl.Data["json"] = br
-		kbctrl.ServeJSON()
-	}()
-	var req facade.LLMKnowledgeSearch
-	err := json.Unmarshal(kbctrl.Ctx.Input.RequestBody, &req)
-	if err != nil {
-		br.Msg = "参数解析异常!"
-		br.ErrMsg = "参数解析失败,Err:" + err.Error()
-		return
-	}
-	sysUser := kbctrl.SysUser
-	if sysUser == nil {
-		br.Msg = "请登录"
-		br.ErrMsg = "请登录,SysUser Is Empty"
-		br.Ret = 408
-		return
-	}
-	session := llm.UserLlmChat{
-		UserId:      sysUser.AdminId,
-		CreatedTime: time.Now(),
-		ChatTitle:   "新会话",
-	}
-	err = session.CreateChatSession()
-	if err != nil {
-		br.Msg = "创建失败"
-		br.ErrMsg = "创建失败,Err:" + err.Error()
-		return
-	}
-	br.Ret = 200
-	br.Success = true
-	br.Msg = "创建成功"
-}
-
 // ChatConnect @Title 知识库问答创建对话连接
 // @Description 知识库问答创建对话连接
 // @Success 101 {object} response.ListResp
 // @router /chat/connect [get]
-func (cc *ChatController) ChatConnect() {
+func (cc *ChatWsController) ChatConnect() {
 	if !ws.Allow(cc.SysUser.AdminId, ws.CONNECT_LIMITER) {
 		utils.FileLog.Error("WebSocket连接太频繁,主动拒绝链接")
 		cc.Ctx.ResponseWriter.WriteHeader(http.StatusTooManyRequests)

+ 10 - 1
controllers/rag/llm_http/request.go

@@ -6,6 +6,15 @@ type LLMQuestionReq struct {
 	SessionId     string `description:"会话ID"`
 }
 
-type CreateChatReq struct {
+type UserChatReq struct {
+	ChatId    int    `json:"ChatId"`
 	ChatTitle string `json:"ChatTitle" description:"会话名称"`
 }
+
+type UserChatRecordReq struct {
+	Id           int    `json:"Id"`
+	ChatId       int    `json:"ChatId"`
+	Content      string `json:"Content" description:"会话名称"`
+	ChatUserType string `json:"ChatUserType" description:"用户类型"`
+	SendTime     string `json:"SendTime" description:"发送时间"`
+}

+ 5 - 3
controllers/rag/llm_http/response.go

@@ -1,7 +1,9 @@
 package llm_http
 
+import "eta/eta_api/models/llm"
 
-type LLMQuestionRes struct {
-	Answer      string `description:"回答"`
-	SessionId     string `description:"会话ID"`
+type UserChatListResp struct {
+	TodayList     []llm.UserLlmChatListViewItem
+	YesterdayList []llm.UserLlmChatListViewItem
+	WeekList      []llm.UserLlmChatListViewItem
 }

+ 1 - 1
controllers/rag/question.go

@@ -26,7 +26,7 @@ type QuestionController struct {
 // @Param   PageSize   query   int  true       "每页数据条数"
 // @Param   CurrentIndex   query   int  true       "当前页页码,从1开始"
 // @Param   KeyWord   query   string  true       "搜索关键词"
-// @Success 200 {object} []*rag.WechatPlatform
+// @Success 200 {object} []*rag.QuestionListListResp
 // @router /question/list [get]
 func (c *QuestionController) List() {
 	br := new(models.BaseResponse).Init()

+ 269 - 0
controllers/rag/user_chat_controller.go

@@ -0,0 +1,269 @@
+package rag
+
+import (
+	"encoding/json"
+	"eta/eta_api/controllers"
+	"eta/eta_api/controllers/rag/llm_http"
+	"eta/eta_api/models"
+	"eta/eta_api/models/llm"
+	llmService "eta/eta_api/services/llm"
+	"eta/eta_api/utils"
+	"time"
+)
+
+type UserChatController struct {
+	controllers.BaseAuthController
+}
+
+// NewChat @Title 新建对话框
+// @Description 新建对话框
+// @Success 101 {object} response.ListResp
+// @router /chat/new_chat [post]
+func (ucCtrl *UserChatController) NewChat() {
+	br := new(models.BaseResponse).Init()
+	defer func() {
+		ucCtrl.Data["json"] = br
+		ucCtrl.ServeJSON()
+	}()
+	var req llm_http.UserChatReq
+	err := json.Unmarshal(ucCtrl.Ctx.Input.RequestBody, &req)
+	if err != nil {
+		br.Msg = "参数解析异常!"
+		br.ErrMsg = "参数解析失败,Err:" + err.Error()
+		return
+	}
+	sysUser := ucCtrl.SysUser
+	if sysUser == nil {
+		br.Msg = "请登录"
+		br.ErrMsg = "请登录,SysUser Is Empty"
+		br.Ret = 408
+		return
+	}
+	if req.ChatTitle == "" {
+		req.ChatTitle = "新会话"
+	}
+	session := llm.UserLlmChat{
+		UserId:      sysUser.AdminId,
+		CreatedTime: time.Now(),
+		ChatTitle:   req.ChatTitle,
+	}
+	err = session.CreateChatSession()
+	if err != nil {
+		br.Msg = "创建失败"
+		br.ErrMsg = "创建失败,Err:" + err.Error()
+		return
+	}
+	br.Ret = 200
+	br.Success = true
+	br.Msg = "创建成功"
+}
+
+// RenameChat @Title 新建对话框
+// @Description 新建对话框
+// @Success 101 {object} response.ListResp
+// @router /chat/rename_chat [post]
+func (ucCtrl *UserChatController) RenameChat() {
+	br := new(models.BaseResponse).Init()
+	defer func() {
+		ucCtrl.Data["json"] = br
+		ucCtrl.ServeJSON()
+	}()
+	var req llm_http.UserChatReq
+	err := json.Unmarshal(ucCtrl.Ctx.Input.RequestBody, &req)
+	if err != nil {
+		br.Msg = "参数解析异常!"
+		br.ErrMsg = "参数解析失败,Err:" + err.Error()
+		return
+	}
+	sysUser := ucCtrl.SysUser
+	if sysUser == nil {
+		br.Msg = "请登录"
+		br.ErrMsg = "请登录,SysUser Is Empty"
+		br.Ret = 408
+		return
+	}
+	if req.ChatId <= 0 {
+		br.Msg = "非法的对话框Id"
+		br.ErrMsg = "非法的对话框Id"
+		return
+	}
+	if req.ChatTitle == "" {
+		br.Msg = "重命名不能为空"
+		br.ErrMsg = "重命名不能为空"
+		return
+	}
+	session := llm.UserLlmChat{
+		Id:         req.ChatId,
+		UpdateTime: time.Now(),
+		UserId:     sysUser.AdminId,
+		ChatTitle:  req.ChatTitle,
+	}
+	err = session.RenameChatSession()
+	if err != nil {
+		br.Msg = "重命名失败"
+		br.ErrMsg = "重命名失败,Err:" + err.Error()
+		return
+	}
+	br.Ret = 200
+	br.Success = true
+	br.Msg = "重命名成功"
+}
+
+// GetUserChatList @Title 获取用户对话框列表
+// @Description  获取用户对话框列表
+// @Success 101 {object} response.ListResp
+// @router /chat/user_chat_list [get]
+func (ucCtrl *UserChatController) GetUserChatList() {
+	br := new(models.BaseResponse).Init()
+	defer func() {
+		ucCtrl.Data["json"] = br
+		ucCtrl.ServeJSON()
+	}()
+	sysUser := ucCtrl.SysUser
+	if sysUser == nil {
+		br.Msg = "请登录"
+		br.ErrMsg = "请登录,SysUser Is Empty"
+		br.Ret = 408
+		return
+	}
+	//周日是0,周六是6
+	weekDay := time.Now().Weekday()
+	offset := int(time.Monday - weekDay)
+	if offset > 0 {
+		offset -= 7
+	}
+	today := time.Now().Format(utils.FormatDate)
+	monDay := time.Now().AddDate(0, 0, offset).Format(utils.FormatDate)
+	yesterday := time.Now().AddDate(0, 0, -1).Format(utils.FormatDate)
+	chatList, err := llm.GetUserChatList(sysUser.AdminId, monDay, time.Now().Format(utils.FormatDate))
+	if err != nil {
+		br.Msg = "获取用户聊天列表失败"
+		br.ErrMsg = "获取用户聊天列表失败,Err:" + err.Error()
+		return
+	}
+	data := new(llm_http.UserChatListResp)
+	data.WeekList = make([]llm.UserLlmChatListViewItem, 0)
+	data.YesterdayList = make([]llm.UserLlmChatListViewItem, 0)
+	data.TodayList = make([]llm.UserLlmChatListViewItem, 0)
+	for _, v := range chatList {
+		if v.CreatedTime.Format(utils.FormatDate) == today {
+			data.TodayList = append(data.TodayList, llm.CovertItemToView(v))
+		} else if v.CreatedTime.Format(utils.FormatDate) == yesterday {
+			data.YesterdayList = append(data.YesterdayList, llm.CovertItemToView(v))
+		} else {
+			data.WeekList = append(data.WeekList, llm.CovertItemToView(v))
+		}
+	}
+
+	br.Data = data
+	br.Ret = 200
+	br.Success = true
+	br.Msg = "获取用户聊天列表成功"
+}
+
+// ChatRecordAdd @Title 保存聊天记录
+// @Description 保存聊天记录
+// @Success 101 {object} response.ListResp
+// @router /chat/chat_record_save [post]
+func (ucCtrl *UserChatController) ChatRecordAdd() {
+	br := new(models.BaseResponse).Init()
+	defer func() {
+		ucCtrl.Data["json"] = br
+		ucCtrl.ServeJSON()
+	}()
+	var req llm_http.UserChatRecordReq
+	err := json.Unmarshal(ucCtrl.Ctx.Input.RequestBody, &req)
+	if err != nil {
+		br.Msg = "参数解析异常!"
+		br.ErrMsg = "参数解析失败,Err:" + err.Error()
+		return
+	}
+	sysUser := ucCtrl.SysUser
+	if sysUser == nil {
+		br.Msg = "请登录"
+		br.ErrMsg = "请登录,SysUser Is Empty"
+		br.Ret = 408
+		return
+	}
+	if req.ChatId <= 0 {
+		br.Msg = "非法的对话框Id"
+		br.ErrMsg = "非法的对话框Id"
+		return
+	}
+	if req.Content == "" {
+		br.Msg = "聊天记录不能为空"
+		br.ErrMsg = "聊天记录不能为空"
+		return
+	}
+	if req.Id < 0 {
+		br.Msg = "非法的Id"
+		br.ErrMsg = "非法的Id"
+		return
+	}
+	if req.ChatUserType != "user" && req.ChatUserType != "assistant" {
+		br.Msg = "非法的用户类型"
+		br.ErrMsg = "非法的用户类型,用户类型支持:user/assistant"
+		return
+	}
+	if req.SendTime == "" {
+		req.SendTime = time.Now().Format(utils.FormatDateTime)
+	} else {
+		_, err = time.Parse(utils.FormatDateTime, req.SendTime)
+		if err != nil {
+			br.Msg = "非法的发送时间"
+			br.ErrMsg = "非法的发送时间,Err:" + err.Error()
+			return
+		}
+	}
+	record := llm.UserChatRecordRedis{
+		ChatId:       req.ChatId,
+		ChatUserType: req.ChatUserType,
+		Content:      req.Content,
+		SendTime:     req.SendTime,
+	}
+	err = llmService.AddChatRecord(&record)
+	if err != nil {
+		br.Msg = "添加聊天记录失败"
+		br.ErrMsg = "添加聊天记录失败,Err:" + err.Error()
+		return
+	}
+	br.Ret = 200
+	br.Success = true
+	br.Msg = "添加聊天记录成功"
+}
+
+// ChatRecordList @Title 获取聊天记录
+// @Description 获取聊天记录
+// @Success 101 {object} response.ListResp
+// @router /chat/chat_record_list [get]
+func (ucCtrl *UserChatController) ChatRecordList() {
+	br := new(models.BaseResponse).Init()
+	defer func() {
+		ucCtrl.Data["json"] = br
+		ucCtrl.ServeJSON()
+	}()
+	sysUser := ucCtrl.SysUser
+	if sysUser == nil {
+		br.Msg = "请登录"
+		br.ErrMsg = "请登录,SysUser Is Empty"
+		br.Ret = 408
+		return
+	}
+	chatId, _ := ucCtrl.GetInt("ChatId", 0)
+	if chatId <= 0 {
+		br.Msg = "非法的对话Id"
+		br.ErrMsg = "非法的对话Id"
+		return
+	}
+
+	list, err := llmService.GetChatRecordsFromRedis(chatId)
+	if err != nil {
+		br.Msg = "获取聊天记录失败"
+		br.ErrMsg = "获取聊天记录失败,Err:" + err.Error()
+		return
+	}
+	br.Data = list
+	br.Ret = 200
+	br.Success = true
+	br.Msg = "获取聊天记录成功"
+}

+ 40 - 0
models/llm/user_chat_record.go

@@ -0,0 +1,40 @@
+package llm
+
+import (
+	"eta/eta_api/global"
+	"eta/eta_api/utils"
+	"gorm.io/gorm"
+	"gorm.io/gorm/clause"
+	"time"
+)
+
+// UserChatRecord 定义用户聊天记录结构体
+type UserChatRecord struct {
+	Id           int       `gorm:"primaryKey;autoIncrement;comment:主键"`
+	ChatId       int       `gorm:"chat_id;comment:会话id"`
+	ChatUserType string    `gorm:"type:enum('user','assistant');comment:用户方"`
+	Content      string    `gorm:"content:内容"`
+	SendTime     time.Time `gorm:"comment:发送时间"`
+	CreatedTime  time.Time `gorm:"comment:创建时间"`
+	UpdateTime   time.Time `gorm:"autoUpdateTime;comment:更新时间"`
+}
+type UserChatRecordRedis struct {
+	Id           int
+	ChatId       int
+	ChatUserType string
+	Content      string
+	SendTime     string
+}
+
+func (u *UserChatRecord) TableName() string {
+	return "user_chat_record"
+}
+
+func BatchInsertRecords(list []*UserChatRecord) (err error) {
+	o := global.DbMap[utils.DbNameAI]
+	err = o.Clauses(clause.OnConflict{
+		Columns:   []clause.Column{{Name: "chat_id"}, {Name: "chat_user_type"}, {Name: "send_time"}},
+		DoUpdates: clause.Assignments(map[string]interface{}{"update_time": gorm.Expr("VALUES(update_time)")}),
+	}).CreateInBatches(list, utils.MultiAddNum).Error
+	return
+}

+ 48 - 0
models/llm/user_llm_chat.go

@@ -1,6 +1,7 @@
 package llm
 
 import (
+	"errors"
 	"eta/eta_api/global"
 	"eta/eta_api/utils"
 	"time"
@@ -15,6 +16,31 @@ type UserLlmChat struct {
 	UpdateTime  time.Time `gorm:"autoUpdateTime;comment:更新时间"`
 }
 
+type UserLlmChatListItem struct {
+	Id          int       `gorm:"primaryKey;autoIncrement;comment:会话主键"`
+	UserId      int       `gorm:"comment:用户id"`
+	ChatTitle   string    `gorm:"comment:会话标题"`
+	CreatedTime time.Time `gorm:"comment:创建时间"`
+	RecordCount int       `gorm:"comment:会话记录数"`
+}
+type UserLlmChatListViewItem struct {
+	Id          int    `gorm:"primaryKey;autoIncrement;comment:会话主键"`
+	UserId      int    `gorm:"comment:用户id"`
+	ChatTitle   string `gorm:"comment:会话标题"`
+	CreatedTime string `gorm:"comment:创建时间"`
+	RecordCount int    `gorm:"comment:会话记录数"`
+}
+
+func CovertItemToView(item UserLlmChatListItem) UserLlmChatListViewItem {
+	return UserLlmChatListViewItem{
+		Id:          item.Id,
+		UserId:      item.UserId,
+		ChatTitle:   item.ChatTitle,
+		CreatedTime: item.CreatedTime.Format(utils.FormatDateTime),
+		RecordCount: item.RecordCount,
+	}
+
+}
 func (u *UserLlmChat) TableName() string {
 	return "user_llm_chat"
 }
@@ -23,3 +49,25 @@ func (u *UserLlmChat) CreateChatSession() (err error) {
 	err = o.Create(u).Error
 	return
 }
+func (u *UserLlmChat) RenameChatSession() (err error) {
+	o := global.DbMap[utils.DbNameAI]
+	var exists bool
+	err = o.Model(&u).Select("1").Where("id = ?", u.Id).Scan(&exists).Error
+	if err != nil {
+		return
+	}
+	if !exists {
+		err = errors.New("当前会话不存在")
+		return
+	}
+	err = o.Select("chat_title").Updates(u).Error
+	return
+}
+
+func GetUserChatList(userId int, monDay, toDay string) (chatList []UserLlmChatListItem, err error) {
+	o := global.DbMap[utils.DbNameAI]
+	sql := `select ulc.id AS id ,ulc.user_id as user_id,ulc.chat_title as chat_title,ulc.created_time,COUNT(ucr.id) AS record_count from user_llm_chat ulc left join user_chat_record ucr
+    ON ucr.chat_id = ulc.id where ulc.user_id=? and ` + utils.GenerateQuerySql(utils.ToDate, &utils.QueryParam{Column: "ulc.created_time"}) + ` BETWEEN ? and ? GROUP BY ulc.id order by ulc.created_time desc`
+	err = o.Raw(sql, userId, monDay, toDay).Find(&chatList).Error
+	return
+}

+ 6 - 6
models/rag/wechat_article.go

@@ -19,9 +19,9 @@ type WechatArticle struct {
 	Content           string    `gorm:"column:content;type:longtext;comment:报告详情;" description:"报告详情"`
 	TextContent       string    `gorm:"column:text_content;type:text;comment:文本内容;" description:"文本内容"`
 	Abstract          string    `gorm:"column:abstract;type:text;comment:摘要;" description:"摘要"`
-	Country           string    `gorm:"column:country;type:varchar(255);comment:国家;" json:"country"`  // 国家
-	Province          string    `gorm:"column:province;type:varchar(255);comment:省;" json:"province"` // 省
-	City              string    `gorm:"column:city;type:varchar(255);comment:市;" json:"city"`         // 市
+	Country           string    `gorm:"column:country;type:varchar(255);comment:国家;" description:"国家"`
+	Province          string    `gorm:"column:province;type:varchar(255);comment:省;" description:"省"`
+	City              string    `gorm:"column:city;type:varchar(255);comment:市;" description:"市"`
 	ArticleCreateTime time.Time `gorm:"column:article_create_time;type:datetime;comment:报告创建时间;default:NULL;" description:"报告创建时间"`
 	IsDeleted         int       `gorm:"column:is_deleted;type:tinyint(4);comment:是否删除,0:未删除,1: 已删除;default:0;" description:"是否删除,0:未删除,1: 已删除"`
 	ModifyTime        time.Time `gorm:"column:modify_time;type:datetime;comment:修改时间;default:NULL;" description:"修改时间"`
@@ -83,9 +83,9 @@ type WechatArticleView struct {
 	Content                    string `gorm:"column:content;type:longtext;comment:报告详情;" description:"报告详情"`
 	TextContent                string `gorm:"column:text_content;type:text;comment:文本内容;" description:"文本内容"`
 	Abstract                   string `gorm:"column:abstract;type:text;comment:摘要;" description:"摘要"`
-	Country                    string `gorm:"column:country;type:varchar(255);comment:国家;" json:"country"`  // 国家
-	Province                   string `gorm:"column:province;type:varchar(255);comment:省;" json:"province"` // 省
-	City                       string `gorm:"column:city;type:varchar(255);comment:市;" json:"city"`         // 市
+	Country                    string `gorm:"column:country;type:varchar(255);comment:国家;" description:"国家"`
+	Province                   string `gorm:"column:province;type:varchar(255);comment:省;" description:"省"`
+	City                       string `gorm:"column:city;type:varchar(255);comment:市;" description:"市"`
 	ArticleCreateTime          string `gorm:"column:article_create_time;type:datetime;comment:报告创建时间;default:NULL;" description:"报告创建时间"`
 	ModifyTime                 string `gorm:"column:modify_time;type:datetime;comment:修改时间;default:NULL;" description:"修改时间"`
 	CreateTime                 string `gorm:"column:create_time;type:datetime;comment:入库时间;default:NULL;" description:"入库时间"`

+ 48 - 11
routers/commentsRouter.go

@@ -8557,7 +8557,7 @@ func init() {
             Filters: nil,
             Params: nil})
 
-    beego.GlobalControllerRouter["eta/eta_api/controllers/rag:ChatController"] = append(beego.GlobalControllerRouter["eta/eta_api/controllers/rag:ChatController"],
+    beego.GlobalControllerRouter["eta/eta_api/controllers/rag:ChatWsController"] = append(beego.GlobalControllerRouter["eta/eta_api/controllers/rag:ChatWsController"],
         beego.ControllerComments{
             Method: "ChatConnect",
             Router: `/chat/connect`,
@@ -8566,15 +8566,6 @@ func init() {
             Filters: nil,
             Params: nil})
 
-    beego.GlobalControllerRouter["eta/eta_api/controllers/rag:KbController"] = append(beego.GlobalControllerRouter["eta/eta_api/controllers/rag:KbController"],
-        beego.ControllerComments{
-            Method: "NewChat",
-            Router: `/chat/new_chat`,
-            AllowHTTPMethods: []string{"post"},
-            MethodParams: param.Make(),
-            Filters: nil,
-            Params: nil})
-
     beego.GlobalControllerRouter["eta/eta_api/controllers/rag:KbController"] = append(beego.GlobalControllerRouter["eta/eta_api/controllers/rag:KbController"],
         beego.ControllerComments{
             Method: "SearchDocs",
@@ -8620,6 +8611,51 @@ func init() {
             Filters: nil,
             Params: nil})
 
+    beego.GlobalControllerRouter["eta/eta_api/controllers/rag:UserChatController"] = append(beego.GlobalControllerRouter["eta/eta_api/controllers/rag:UserChatController"],
+        beego.ControllerComments{
+            Method: "ChatRecordList",
+            Router: `/chat/chat_record_list`,
+            AllowHTTPMethods: []string{"get"},
+            MethodParams: param.Make(),
+            Filters: nil,
+            Params: nil})
+
+    beego.GlobalControllerRouter["eta/eta_api/controllers/rag:UserChatController"] = append(beego.GlobalControllerRouter["eta/eta_api/controllers/rag:UserChatController"],
+        beego.ControllerComments{
+            Method: "ChatRecordAdd",
+            Router: `/chat/chat_record_save`,
+            AllowHTTPMethods: []string{"post"},
+            MethodParams: param.Make(),
+            Filters: nil,
+            Params: nil})
+
+    beego.GlobalControllerRouter["eta/eta_api/controllers/rag:UserChatController"] = append(beego.GlobalControllerRouter["eta/eta_api/controllers/rag:UserChatController"],
+        beego.ControllerComments{
+            Method: "NewChat",
+            Router: `/chat/new_chat`,
+            AllowHTTPMethods: []string{"post"},
+            MethodParams: param.Make(),
+            Filters: nil,
+            Params: nil})
+
+    beego.GlobalControllerRouter["eta/eta_api/controllers/rag:UserChatController"] = append(beego.GlobalControllerRouter["eta/eta_api/controllers/rag:UserChatController"],
+        beego.ControllerComments{
+            Method: "RenameChat",
+            Router: `/chat/rename_chat`,
+            AllowHTTPMethods: []string{"post"},
+            MethodParams: param.Make(),
+            Filters: nil,
+            Params: nil})
+
+    beego.GlobalControllerRouter["eta/eta_api/controllers/rag:UserChatController"] = append(beego.GlobalControllerRouter["eta/eta_api/controllers/rag:UserChatController"],
+        beego.ControllerComments{
+            Method: "GetUserChatList",
+            Router: `/chat/user_chat_list`,
+            AllowHTTPMethods: []string{"get"},
+            MethodParams: param.Make(),
+            Filters: nil,
+            Params: nil})
+
     beego.GlobalControllerRouter["eta/eta_api/controllers/rag:WechatPlatformController"] = append(beego.GlobalControllerRouter["eta/eta_api/controllers/rag:WechatPlatformController"],
         beego.ControllerComments{
             Method: "TagList",
@@ -13281,4 +13317,5 @@ func init() {
             MethodParams: param.Make(),
             Filters: nil,
             Params: nil})
-}
+
+}

+ 3 - 4
routers/router.go

@@ -72,13 +72,12 @@ func init() {
 		),
 		web.NSNamespace("/llm",
 			web.NSInclude(
-				&rag.ChatController{},
+				&rag.ChatWsController{},
+				&rag.UserChatController{},
+				&rag.KbController{},
 				&rag.WechatPlatformController{},
 				&rag.QuestionController{},
 			),
-			web.NSInclude(
-				&rag.KbController{},
-			),
 		),
 		web.NSNamespace("/banner",
 			web.NSInclude(

+ 218 - 0
services/llm/chat_service.go

@@ -0,0 +1,218 @@
+package llm
+
+import (
+	"encoding/json"
+	"eta/eta_api/global"
+	"eta/eta_api/models/llm"
+	"eta/eta_api/utils"
+	"eta/eta_api/utils/lock"
+	"eta/eta_api/utils/redis"
+	"fmt"
+	"github.com/google/uuid"
+	"strconv"
+	"strings"
+	"sync"
+	"time"
+)
+
+const (
+	redisChatPrefix = "chat:zet:"
+	redisTTL        = 24 * time.Hour // Redis 缓存过期时间
+)
+
+// AddChatRecord 添加聊天记录到 Redis
+func AddChatRecord(record *llm.UserChatRecordRedis) error {
+	key := fmt.Sprintf("%s%d", redisChatPrefix, record.ChatId)
+	holder, _ := uuid.NewRandom()
+	holderStr := fmt.Sprintf("user_%s", holder.String())
+	if lock.AcquireLock(key, 10, holderStr) {
+		defer func() {
+			fmt.Printf("用户释放锁:%s", key)
+			lock.ReleaseLock(key, holderStr)
+		}()
+		data, err := json.Marshal(record)
+		if err != nil {
+			return fmt.Errorf("序列化聊天记录失败: %w", err)
+		}
+		zSet, _ := utils.Rc.ZRangeWithScores(key)
+		if len(zSet) == 0 {
+			// 设置过期时间
+			_ = utils.Rc.Expire(key, 24*time.Hour)
+		}
+		zSet = append(zSet, &redis.Zset{
+			Member: data,
+			Score:  float64(time.Now().Unix()),
+		})
+
+		err = utils.Rc.ZAdd(key, zSet...)
+		if err != nil {
+			return fmt.Errorf("保存聊天记录到 Redis 失败: %w", err)
+		}
+		return nil
+	}
+	return fmt.Errorf("获取锁失败,请稍后重试")
+}
+
+// GetChatRecordsFromRedis 从 Redis 获取聊天记录
+func GetChatRecordsFromRedis(chatId int) (redisList []*llm.UserChatRecordRedis, err error) {
+	key := fmt.Sprintf("%s%d", redisChatPrefix, chatId)
+	zSet, _ := utils.Rc.ZRangeWithScores(key)
+	if len(zSet) == 0 {
+		// 缓存不存在,从数据库拉取数据
+		records, dbErr := GetChatRecordsFromDB(chatId)
+		if dbErr != nil {
+			err = fmt.Errorf("从数据库获取聊天记录失败: %w", dbErr)
+			return
+		}
+		// 将数据保存到 Redis
+		for _, record := range records {
+			redisRecord := &llm.UserChatRecordRedis{
+				Id:           record.Id,
+				ChatId:       chatId,
+				ChatUserType: record.ChatUserType,
+				Content:      record.Content,
+				SendTime:     record.SendTime.Format(utils.FormatDateTime),
+			}
+			redisList = append(redisList, redisRecord)
+		}
+		return
+	}
+	for _, z := range zSet {
+		var redisRecord llm.UserChatRecordRedis
+		if err = json.Unmarshal([]byte(z.Member.(string)), &redisRecord); err != nil {
+			return nil, fmt.Errorf("解析聊天记录失败: %w", err)
+		}
+		redisList = append(redisList, &redisRecord)
+	}
+	return
+}
+
+func flushRecordsToRedis(chatId int) (err error) {
+	key := fmt.Sprintf("%s%d", redisChatPrefix, chatId)
+	zSet, _ := utils.Rc.ZRangeWithScores(key)
+	if len(zSet) == 0 {
+		// 缓存不存在,从数据库拉取数据
+		records, dbErr := GetChatRecordsFromDB(chatId)
+		if dbErr != nil {
+			err = fmt.Errorf("从数据库获取聊天记录失败: %w", dbErr)
+			return
+		}
+		var zet []*redis.Zset
+		// 将数据保存到 Redis
+		for _, record := range records {
+			redisRecord := &llm.UserChatRecordRedis{
+				Id:           record.Id,
+				ChatId:       chatId,
+				ChatUserType: record.ChatUserType,
+				Content:      record.Content,
+				SendTime:     record.SendTime.Format(utils.FormatDateTime),
+			}
+			data, parseErr := json.Marshal(&redisRecord)
+			if parseErr != nil {
+				utils.FileLog.Error("解析聊天记录失败: %w", err)
+			}
+			zet = append(zet, &redis.Zset{
+				Member: data,
+				Score:  float64(record.SendTime.Unix()),
+			})
+		}
+		_ = utils.Rc.ZAdd(key, zet...)
+	}
+	return
+}
+
+// SaveChatRecordsToDB 将 Redis 中的聊天记录保存到数据库
+func SaveChatRecordsToDB(chatId int) error {
+	list, err := GetChatRecordsFromRedis(chatId)
+	if err != nil {
+		return err
+	}
+	var newRecords []*llm.UserChatRecord
+	for _, record := range list {
+		if record.Id == 0 {
+			sendTime, parseErr := time.Parse(utils.FormatDateTime, record.SendTime)
+			if parseErr != nil {
+				sendTime = time.Now()
+			}
+			newRecords = append(newRecords, &llm.UserChatRecord{
+				Id:           record.Id,
+				ChatId:       record.ChatId,
+				ChatUserType: record.ChatUserType,
+				Content:      record.Content,
+				SendTime:     sendTime,
+				CreatedTime:  time.Now(),
+			})
+		}
+	}
+	key := fmt.Sprintf("%s%d", redisChatPrefix, chatId)
+	holder, _ := uuid.NewRandom()
+	holderStr := fmt.Sprintf("sys_%s", holder.String())
+	defer func() {
+		fmt.Printf("系统释放锁:%s", key)
+		lock.ReleaseLock(key, holderStr)
+	}()
+	if lock.AcquireLock(key, 10, holderStr) {
+		//先删除redis中的缓存
+		_ = RemoveChatRecord(chatId)
+		err = llm.BatchInsertRecords(newRecords)
+		if err != nil {
+			utils.FileLog.Error("批量插入记录失败:", err.Error())
+			return fmt.Errorf("批量插入记录失败: %w", err)
+		}
+		_ = RemoveChatRecord(chatId)
+		//重新加载数据
+		_ = flushRecordsToRedis(chatId)
+	}
+	return nil
+}
+
+// SaveAllChatRecordsToDB 定时任务保存所有 Redis 中的聊天记录到数据库
+func SaveAllChatRecordsToDB() {
+	for {
+		fmt.Println("开始保存聊天记录到数据库...")
+		keys, err := utils.Rc.Keys(redisChatPrefix + "*")
+		if err != nil {
+			utils.FileLog.Error("获取 Redis 键失败: %v", err)
+			return
+		}
+		var wg sync.WaitGroup
+		wg.Add(len(keys))
+		for _, key := range keys {
+			go func(key string) {
+				defer wg.Done()
+				chatIdStr := strings.TrimPrefix(key, redisChatPrefix)
+				chatId, parseErr := strconv.Atoi(chatIdStr)
+				if parseErr != nil {
+					utils.FileLog.Error("解析聊天ID失败: %v", err)
+					return
+				}
+				if err = SaveChatRecordsToDB(chatId); err != nil {
+					utils.FileLog.Error("解析聊天ID失败: %v", err)
+				}
+				fmt.Println("保存聊天记录到数据库完成")
+			}(key)
+		}
+		wg.Wait()
+		fmt.Printf("计划任务完成")
+		time.Sleep(10 * time.Second)
+	}
+}
+
+// RemoveChatRecord 从 Redis 删除聊天记录
+func RemoveChatRecord(chatId int) error {
+	key := fmt.Sprintf("%s%d", redisChatPrefix, chatId)
+	err := utils.Rc.Delete(key)
+	if err != nil {
+		return fmt.Errorf("删除 Redis 缓存失败: %w", err)
+	}
+	return nil
+}
+
+func GetChatRecordsFromDB(chatId int) ([]*llm.UserChatRecord, error) {
+	o := global.DbMap[utils.DbNameAI]
+	var records []*llm.UserChatRecord
+	if err := o.Where("chat_id = ?", chatId).Find(&records).Error; err != nil {
+		return nil, fmt.Errorf("从数据库获取聊天记录失败: %w", err)
+	}
+	return records, nil
+}

+ 1 - 1
services/llm/wechat_platform.go

@@ -147,7 +147,7 @@ func AddWechatArticle(item *rag.WechatPlatform, articleLink string, articleDetai
 		obj.Title = articleMenu.Title
 		//obj.Link = articleMenu.Link
 		obj.CoverUrl = articleMenu.Cover
-		obj.Abstract = articleMenu.Digest
+		obj.Description = articleMenu.Digest
 	}
 	err = obj.Create()
 }

+ 2 - 0
services/task.go

@@ -5,6 +5,7 @@ import (
 	"eta/eta_api/services/binlog"
 	"eta/eta_api/services/data"
 	edbmonitor "eta/eta_api/services/edb_monitor"
+	"eta/eta_api/services/llm"
 	"eta/eta_api/utils"
 	"fmt"
 	"strings"
@@ -61,6 +62,7 @@ func Task() {
 		go binlog.HandleDataSourceChange2Es()
 	}
 	go StartSessionManager()
+	go llm.SaveAllChatRecordsToDB()
 	// TODO:数据修复
 	//FixNewEs()
 	fmt.Println("task end")

+ 62 - 0
utils/lock/distrubtLock.go

@@ -0,0 +1,62 @@
+package lock
+
+import (
+	"context"
+	"eta/eta_api/utils"
+	"fmt"
+	"github.com/go-redis/redis/v8"
+)
+
+const (
+	lockName = "lock:"
+)
+
+var (
+	ctx = context.Background()
+)
+
+func AcquireLock(key string, expiration int, Holder string) bool {
+	script := redis.NewScript(`local key = KEYS[1]
+			local clientId = ARGV[1]
+			local expiration = tonumber(ARGV[2])
+			if redis.call("EXISTS", key) == 0 then
+				redis.call("SET", key, clientId, "EX", expiration)
+				return 1
+			else
+				return 0
+			end`)
+	fmt.Println("lockName:"+lockName+key, "Holder:"+Holder, "expiration:"+fmt.Sprintf("%d", expiration))
+	lockey := fmt.Sprintf("%s%s", lockName, key)
+	result, err := script.Run(ctx, utils.Rc.RedisClient(), []string{lockey}, Holder, expiration).Int()
+	if err != nil {
+		fmt.Printf("加锁失败:err: %v", err)
+		return false
+	}
+	if result == 1 {
+		return true
+	}
+	fmt.Printf("加锁失败:")
+	return false
+}
+
+func ReleaseLock(key string, holder string) bool {
+	script := redis.NewScript(`
+	   if redis.call("get", KEYS[1]) == ARGV[1] then
+	       return redis.call("del", KEYS[1])
+	   else
+	       return 0
+	   end
+	`)
+	fmt.Println("lockName:"+lockName+key, "Holder:"+holder)
+	lockey := fmt.Sprintf("%s%s", lockName, key)
+	result, err := script.Run(ctx, utils.Rc.RedisClient(), []string{lockey}, holder).Int()
+	if err != nil {
+		fmt.Printf("解锁失败:err: %v", err)
+		return false
+	}
+	if result == 1 {
+		return true
+	}
+	fmt.Printf("解锁失败:")
+	return false
+}

+ 6 - 1
utils/redis.go

@@ -2,6 +2,7 @@ package utils
 
 import (
 	"eta/eta_api/utils/redis"
+	client "github.com/go-redis/redis/v8"
 	"time"
 )
 
@@ -24,6 +25,11 @@ type RedisClient interface {
 	SAdd(key string, args ...interface{}) (err error)
 	SRem(key string, args ...interface{}) (err error)
 	SIsMember(key string, args interface{}) (bool, error)
+	ZAdd(key string, members ...*redis.Zset) error
+	ZRangeWithScores(key string) ([]*redis.Zset, error)
+	Expire(key string, duration time.Duration) error
+	RedisClient() client.UniversalClient
+	Keys(pattern string) ([]string, error)
 }
 
 func initRedis(redisType string, conf string) (redisClient RedisClient, err error) {
@@ -33,7 +39,6 @@ func initRedis(redisType string, conf string) (redisClient RedisClient, err erro
 	default: // 默认走单机
 		redisClient, err = redis.InitStandaloneRedis(conf)
 	}
-
 	return
 }
 

+ 37 - 0
utils/redis/cluster_redis.go

@@ -16,6 +16,10 @@ import (
 type ClusterRedisClient struct {
 	redisClient *redis.ClusterClient
 }
+type Zset struct {
+	Score  float64
+	Member interface{}
+}
 
 var DefaultKey = "zcmRedis"
 
@@ -317,3 +321,36 @@ func (rc *ClusterRedisClient) SIsMember(key string, args interface{}) (isMember
 	isMember, err = rc.redisClient.SIsMember(context.TODO(), key, args).Result()
 	return
 }
+func (rc *ClusterRedisClient) ZAdd(key string, members ...*Zset) error {
+	var redisMembers []*redis.Z
+	for _, member := range members {
+		redisMembers = append(redisMembers, &redis.Z{
+			Member: member.Member,
+			Score:  member.Score,
+		})
+	}
+	return rc.redisClient.ZAdd(context.TODO(), key, redisMembers...).Err()
+}
+func (rc *ClusterRedisClient) ZRangeWithScores(key string) (result []*Zset, err error) {
+	redisZList, err := rc.redisClient.ZRangeWithScores(context.TODO(), key, 0, -1).Result()
+	if err != nil {
+		return
+	}
+	for _, redisZ := range redisZList {
+		result = append(result, &Zset{
+			Member: redisZ.Member,
+			Score:  redisZ.Score,
+		})
+	}
+	return
+}
+func (rc *ClusterRedisClient) Expire(key string, duration time.Duration) error {
+	return rc.redisClient.Expire(context.Background(), key, duration).Err()
+}
+
+func (rc *ClusterRedisClient) Keys(pattern string) (keys []string, err error) {
+	return rc.redisClient.Keys(context.Background(), pattern).Result()
+}
+func (rc *ClusterRedisClient) RedisClient() redis.UniversalClient {
+	return rc.redisClient
+}

+ 34 - 0
utils/redis/standalone_redis.go

@@ -309,3 +309,37 @@ func (rc *StandaloneRedisClient) SIsMember(key string, args interface{}) (isMemb
 	isMember, err = rc.redisClient.SIsMember(context.TODO(), key, args).Result()
 	return
 }
+
+func (rc *StandaloneRedisClient) ZAdd(key string, members ...*Zset) error {
+	var redisMembers []*redis.Z
+	for _, member := range members {
+		redisMembers = append(redisMembers, &redis.Z{
+			Member: member.Member,
+			Score:  member.Score,
+		})
+	}
+	return rc.redisClient.ZAdd(context.TODO(), key, redisMembers...).Err()
+}
+
+func (rc *StandaloneRedisClient) ZRangeWithScores(key string) (result []*Zset, err error) {
+	redisZList, err := rc.redisClient.ZRangeWithScores(context.TODO(), key, 0, -1).Result()
+	if err != nil {
+		return
+	}
+	for _, redisZ := range redisZList {
+		result = append(result, &Zset{
+			Member: redisZ.Member,
+			Score:  redisZ.Score,
+		})
+	}
+	return
+}
+func (rc *StandaloneRedisClient) Expire(key string, duration time.Duration) error {
+	return rc.redisClient.Expire(context.Background(), key, duration).Err()
+}
+func (rc *StandaloneRedisClient) Keys(pattern string) (keys []string, err error) {
+	return rc.redisClient.Keys(context.Background(), pattern).Result()
+}
+func (rc *StandaloneRedisClient) RedisClient() redis.UniversalClient {
+	return rc.redisClient
+}

+ 72 - 0
utils/sql.go

@@ -20,6 +20,8 @@ const (
 	Order         SqlCondition = "Order"
 	Delimiter     SqlCondition = "Delimiter"
 	ConvertColumn SqlCondition = "ConvertColumn"
+
+	ToDate SqlCondition = "ToDate"
 )
 
 var TemplateMap = map[SqlCondition]map[Driver]string{
@@ -31,6 +33,10 @@ var TemplateMap = map[SqlCondition]map[Driver]string{
 		MySql: `CONVERT({{.ConvertColumn}} USING gbk )`,
 		DM:    `{{.ConvertColumn}}`,
 	},
+	ToDate: {
+		MySql: `DATE({{.Column}})`,
+		DM:    `TO_DATE({{.Column}})`,
+	},
 }
 
 var supportDriverMap = map[string]Driver{
@@ -67,6 +73,71 @@ func (distinctParam *DistinctParam) GetFormatConditionStr(param *QueryParam) str
 	return ""
 }
 
+type ToDateParam struct {
+}
+
+func (toDateParam *ToDateParam) GetParamName() string {
+	return "ToDate"
+}
+func (toDateParam *ToDateParam) GetFormatConditionStr(param *QueryParam) (sql string) {
+	dbDriver, _ := getDriverInstance(param.Driver)
+	if param.Column == "" {
+		FileLog.Error("聚合字段为空,无法生成聚合sql")
+		return
+	}
+	var templateSqlStr string
+	if _, ok := TemplateMap[ToDate][dbDriver]; !ok {
+		templateSqlStr = TemplateMap[ToDate][MySql]
+	} else {
+		templateSqlStr = TemplateMap[ToDate][dbDriver]
+	}
+	if templateSqlStr == "" {
+		FileLog.Error("聚合sql模板不存在,无法生成聚合sql")
+		return
+	}
+	templateSql, err := template.New("ToDate").Parse(templateSqlStr)
+	if err != nil {
+		FileLog.Error("failed to parse template: %v", err)
+		return
+	}
+	//反射获取结构体的值
+	value := reflect.ValueOf(param)
+	// 检查是否是指针
+	if value.Kind() != reflect.Ptr {
+		fmt.Println("请求参数必须是一个结构体")
+		return
+	}
+	// 获取结构体的元素
+	elem := value.Elem()
+	// 检查是否是结构体
+	if elem.Kind() != reflect.Struct {
+		fmt.Println("请求参数必须是一个结构体")
+		return
+	}
+	// 获取字段的值
+	fieldValue := elem.FieldByName("ConvertColumn")
+	// 检查字段是否存在
+	if !fieldValue.IsValid() {
+		fmt.Printf("Error: field %s not found\n", "ConvertColumn")
+		return
+	}
+	// 检查字段是否可导出
+	if !fieldValue.CanSet() {
+		fmt.Printf("Error: field %s is not exported and cannot be set\n", "ConvertColumn")
+		return
+	}
+	// 渲染模板
+	var buf bytes.Buffer
+	err = templateSql.Execute(&buf, param)
+	if err != nil {
+		fmt.Sprintf("执行模板填充失败: %v", err)
+		return
+	}
+	sql = buf.String()
+	fmt.Printf("生成的转换日期语句为:%s\n", sql)
+	return sql
+}
+
 type ConvertParam struct {
 }
 
@@ -145,6 +216,7 @@ var sqlGeneratorFactory = map[SqlCondition]SqlParam{
 	Delimiter:     &DelimiterParam{},
 	Distinct:      &DistinctParam{},
 	ConvertColumn: &ConvertParam{},
+	ToDate:        &ToDateParam{},
 }
 
 type DelimiterParam struct {

+ 1 - 0
utils/ws/session.go

@@ -22,6 +22,7 @@ type Session struct {
 	mu          sync.RWMutex
 	sessionOnce sync.Once
 }
+
 type Message struct {
 	KbName     string   `json:"KbName"`
 	Query      string   `json:"Query"`

+ 2 - 0
utils/ws/session_manager.go

@@ -67,6 +67,8 @@ func (manager *ConnectionManager) HandleMessage(userID int, sessionID string, me
 	}
 	// 处理业务逻辑
 	session.History = append(session.History, userMessage.LastTopics...)
+
+	//TODO
 	resp, err := llmService.KnowledgeBaseChat(userMessage.Query, userMessage.KbName, session.History)
 	defer func() {
 		_ = resp.Body.Close()