Browse Source

Merge branch 'feature/deepseek_rag_1.0' of http://8.136.199.33:3000/eta_server/eta_api into feature/deepseek_rag_1.0

kobe6258 6 days ago
parent
commit
3f3f6e8eac

+ 22 - 43
controllers/rag/wechat_platform_controller.go

@@ -432,17 +432,20 @@ func (c *WechatPlatformController) ArticleList() {
 	var pars []interface{}
 
 	if keyWord != "" {
-		condition = fmt.Sprintf(` AND %s = ?`, rag.WechatPlatformColumns.Nickname)
+		condition = fmt.Sprintf(` AND a.%s = ?`, rag.WechatPlatformColumns.Nickname)
 		pars = append(pars, `%`+keyWord+`%`)
 	}
 
 	if wechatPlatformId > 0 {
-		condition = fmt.Sprintf(` AND %s = ?`, rag.WechatArticleColumns.WechatPlatformID)
+		condition = fmt.Sprintf(` AND a.%s = ?`, rag.WechatArticleColumns.WechatPlatformID)
 		pars = append(pars, wechatPlatformId)
 	}
 
+	condition = fmt.Sprintf(` AND b.%s = ?`, rag.WechatPlatformColumns.Enabled)
+	pars = append(pars, 1)
+
 	obj := new(rag.WechatArticle)
-	total, list, err := obj.GetPageListByCondition(condition, pars, startSize, pageSize)
+	total, list, err := obj.GetPageListByPlatformCondition(condition, pars, startSize, pageSize)
 	if err != nil {
 		br.Msg = "获取失败"
 		br.ErrMsg = "获取失败,Err:" + err.Error()
@@ -451,44 +454,7 @@ func (c *WechatPlatformController) ArticleList() {
 
 	viewList := make([]rag.WechatArticleView, 0)
 	if list != nil && len(list) > 0 {
-		viewList = obj.ListToViewList(list)
-
-		wechatPlatformIdList := make([]int, 0)
-		wechatPlatformIdMap := make(map[int]bool)
-		for _, v := range list {
-			if _, ok := wechatPlatformIdMap[v.WechatPlatformId]; ok {
-				continue
-			}
-			wechatPlatformIdList = append(wechatPlatformIdList, v.WechatPlatformId)
-			wechatPlatformIdMap[v.WechatPlatformId] = true
-		}
-
-		wechatPlatformMap := make(map[int]*rag.WechatPlatform)
-		if len(wechatPlatformIdList) > 0 {
-			var wechatArticleCondition string
-			var wechatArticlePars []interface{}
-			wechatArticleCondition = fmt.Sprintf(` AND %s in (?)`, rag.WechatArticleColumns.WechatPlatformID)
-			wechatArticlePars = append(wechatArticlePars, wechatPlatformIdList)
-			wechatPlatformObj := new(rag.WechatPlatform)
-			wechatPlatformList, err := wechatPlatformObj.GetListByCondition(wechatArticleCondition, wechatArticlePars, startSize, 100000)
-			if err != nil {
-				br.Msg = "获取失败"
-				br.ErrMsg = "获取失败,Err:" + err.Error()
-				return
-			}
-			for _, v := range wechatPlatformList {
-				wechatPlatformMap[v.WechatPlatformId] = v
-			}
-		}
-
-		for k, v := range viewList {
-			wechatPlatformInfo, ok := wechatPlatformMap[v.WechatPlatformId]
-			if !ok {
-				continue
-			}
-			viewList[k].WechatPlatformName = wechatPlatformInfo.Nickname
-			viewList[k].WechatPlatformRoundHeadImg = wechatPlatformInfo.RoundHeadImg
-		}
+		viewList = obj.ArticleAndPlatformListToViewList(list)
 	}
 	page := paging.GetPaging(currentIndex, pageSize, total)
 	resp := response.WechatArticleListListResp{
@@ -540,12 +506,25 @@ func (c *WechatPlatformController) ArticleDetail() {
 		br.IsSendEmail = false
 		return
 	}
-	item.Content = html.UnescapeString(item.Content)
+	resp := item.ToView()
+	resp.Content = html.UnescapeString(item.Content)
+
+	// 获取摘要信息
+	{
+		abstractObj := rag.WechatArticleAbstract{}
+		abstractItem, err := abstractObj.GetByWechatArticleId(wechatArticleId)
+		if err != nil && !utils.IsErrNoRow(err) {
+			br.Msg = "获取失败"
+			br.ErrMsg = "获取失败,Err:" + err.Error()
+			return
+		}
+		resp.Abstract = abstractItem.Content
+	}
 
 	br.Ret = 200
 	br.Success = true
 	br.Msg = "获取成功"
-	br.Data = item
+	br.Data = resp
 }
 
 // ArticleList

+ 105 - 0
models/rag/wechat_article.go

@@ -194,3 +194,108 @@ func (m *WechatArticle) GetPageListByCondition(condition string, pars []interfac
 
 	return
 }
+
+type WechatArticleAndPlatform struct {
+	WechatArticleId   int       `gorm:"column:wechat_article_id;type:int(10) UNSIGNED;primaryKey;not null;" description:""`
+	WechatPlatformId  int       `gorm:"column:wechat_platform_id;type:int(11);comment:归属公众号id;default:0;" description:"归属公众号id"`
+	FakeId            string    `gorm:"column:fake_id;type:varchar(255);comment:公众号唯一id;" description:"公众号唯一id"`
+	Title             string    `gorm:"column:title;type:varchar(255);comment:标题;" description:"标题"`
+	Link              string    `gorm:"column:link;type:varchar(255);comment:链接;" description:"链接"`
+	CoverUrl          string    `gorm:"column:cover_url;type:varchar(255);comment:公众号封面;" description:"公众号封面"`
+	Description       string    `gorm:"column:description;type:varchar(255);comment:描述;" description:"描述"`
+	Content           string    `gorm:"column:content;type:longtext;comment:报告详情;" description:"报告详情"`
+	TextContent       string    `gorm:"column:text_content;type:text;comment:文本内容;" description:"文本内容"`
+	Abstract          string    `gorm:"column:abstract;type:text;comment:摘要;" description:"摘要"`
+	Country           string    `gorm:"column:country;type:varchar(255);comment:国家;" description:"国家"`
+	Province          string    `gorm:"column:province;type:varchar(255);comment:省;" description:"省"`
+	City              string    `gorm:"column:city;type:varchar(255);comment:市;" description:"市"`
+	ArticleCreateTime time.Time `gorm:"column:article_create_time;type:datetime;comment:报告创建时间;default:NULL;" description:"报告创建时间"`
+	IsDeleted         int       `gorm:"column:is_deleted;type:tinyint(4);comment:是否删除,0:未删除,1: 已删除;default:0;" description:"是否删除,0:未删除,1: 已删除"`
+	ModifyTime        time.Time `gorm:"column:modify_time;type:datetime;comment:修改时间;default:NULL;" description:"修改时间"`
+	CreateTime        time.Time `gorm:"column:create_time;type:datetime;comment:入库时间;default:NULL;" description:"入库时间"`
+	Nickname          string    `gorm:"column:nickname;type:varchar(255);comment:公众号名称;" description:"nickname"`          // 公众号名称
+	Alias             string    `gorm:"column:alias;type:varchar(255);comment:别名;" description:"alias"`                   // 别名
+	RoundHeadImg      string    `gorm:"column:round_head_img;type:varchar(255);comment:头像;" description:"round_head_img"` // 头像
+}
+
+func (m *WechatArticleAndPlatform) ToView() WechatArticleView {
+	var articleCreateTime, modifyTime, createTime string
+
+	if !m.ArticleCreateTime.IsZero() {
+		articleCreateTime = m.ArticleCreateTime.Format(utils.FormatDateTime)
+	}
+	if !m.CreateTime.IsZero() {
+		createTime = m.CreateTime.Format(utils.FormatDateTime)
+	}
+	if !m.ModifyTime.IsZero() {
+		modifyTime = m.ModifyTime.Format(utils.FormatDateTime)
+	}
+	return WechatArticleView{
+		WechatArticleId:            m.WechatArticleId,
+		WechatPlatformId:           m.WechatPlatformId,
+		FakeId:                     m.FakeId,
+		Title:                      m.Title,
+		Link:                       m.Link,
+		CoverUrl:                   m.CoverUrl,
+		Description:                m.Description,
+		Content:                    m.Content,
+		TextContent:                m.TextContent,
+		Abstract:                   m.Abstract,
+		Country:                    m.Country,
+		Province:                   m.Province,
+		City:                       m.City,
+		ArticleCreateTime:          articleCreateTime,
+		ModifyTime:                 modifyTime,
+		CreateTime:                 createTime,
+		WechatPlatformName:         m.Nickname,
+		WechatPlatformRoundHeadImg: m.RoundHeadImg,
+	}
+}
+
+func (m *WechatArticle) ArticleAndPlatformListToViewList(list []*WechatArticleAndPlatform) (wechatArticleViewList []WechatArticleView) {
+	wechatArticleViewList = make([]WechatArticleView, 0)
+
+	for _, v := range list {
+		wechatArticleViewList = append(wechatArticleViewList, v.ToView())
+	}
+	return
+}
+
+func (m *WechatArticle) GetListByPlatformCondition(field, condition string, pars []interface{}, startSize, pageSize int) (items []*WechatArticleAndPlatform, err error) {
+	if field == "" {
+		field = "*"
+	}
+	sqlStr := fmt.Sprintf(`SELECT %s FROM %s AS a 
+          JOIN wechat_platform AS b ON a.wechat_platform_id=b.wechat_platform_id
+          WHERE 1=1 AND a.is_deleted=0 %s  order by a.article_create_time DESC,a.wechat_article_id DESC LIMIT ?,?`, field, m.TableName(), condition)
+	pars = append(pars, startSize, pageSize)
+	err = global.DbMap[utils.DbNameAI].Raw(sqlStr, pars...).Find(&items).Error
+
+	return
+}
+
+func (m *WechatArticle) GetCountByPlatformCondition(condition string, pars []interface{}) (total int, err error) {
+	var intNull sql.NullInt64
+	sqlStr := fmt.Sprintf(`SELECT COUNT(1) total FROM %s AS a 
+          JOIN wechat_platform AS b ON a.wechat_platform_id=b.wechat_platform_id 
+          WHERE 1=1 AND a.is_deleted=0 %s`, m.TableName(), condition)
+	err = global.DbMap[utils.DbNameAI].Raw(sqlStr, pars...).Scan(&intNull).Error
+	if err == nil && intNull.Valid {
+		total = int(intNull.Int64)
+	}
+
+	return
+}
+
+func (m *WechatArticle) GetPageListByPlatformCondition(condition string, pars []interface{}, startSize, pageSize int) (total int, items []*WechatArticleAndPlatform, err error) {
+
+	total, err = m.GetCountByPlatformCondition(condition, pars)
+	if err != nil {
+		return
+	}
+	if total > 0 {
+		items, err = m.GetListByPlatformCondition(`a.wechat_article_id,a.wechat_platform_id,a.fake_id,a.title,a.link,a.cover_url,a.description,a.country,a.province,a.city,a.article_create_time,a.modify_time,a.create_time,b.nickname,b.round_head_img`, condition, pars, startSize, pageSize)
+	}
+
+	return
+}

+ 113 - 0
models/rag/wechat_article_abstract.go

@@ -0,0 +1,113 @@
+package rag
+
+import (
+	"database/sql"
+	"eta/eta_api/global"
+	"eta/eta_api/utils"
+	"fmt"
+	"time"
+)
+
+type WechatArticleAbstract struct {
+	WechatArticleAbstractId int       `gorm:"column:wechat_article_abstract_id;type:int(9) UNSIGNED;primaryKey;not null;" description:"wechat_article_abstract_id"`
+	WechatArticleId         int       `gorm:"column:wechat_article_id;type:int(9) UNSIGNED;comment:关联的微信报告id;default:0;" description:"关联的微信报告id"`
+	Content                 string    `gorm:"column:content;type:longtext;comment:摘要内容;" description:"content"` // 摘要内容
+	Version                 int       `gorm:"column:version;type:int(10) UNSIGNED;comment:版本号;default:1;" description:"版本号"`
+	VectorKey               string    `gorm:"column:vector_key;type:varchar(255);comment:向量key标识;" json:"vector_key"` // 向量key标识
+	ModifyTime              time.Time `gorm:"column:modify_time;type:datetime;default:NULL;" description:"modify_time"`
+	CreateTime              time.Time `gorm:"column:create_time;type:datetime;default:NULL;" description:"create_time"`
+}
+
+// TableName get sql table name.获取数据库表名
+func (m *WechatArticleAbstract) TableName() string {
+	return "wechat_article_abstract"
+}
+
+// WechatArticleAbstractColumns get sql column name.获取数据库列名
+var WechatArticleAbstractColumns = struct {
+	WechatArticleAbstractID string
+	WechatArticleID         string
+	Content                 string
+	Version                 string
+	ModifyTime              string
+	CreateTime              string
+}{
+	WechatArticleAbstractID: "wechat_article_abstract_id",
+	WechatArticleID:         "wechat_article_id",
+	Content:                 "content",
+	Version:                 "version",
+	ModifyTime:              "modify_time",
+	CreateTime:              "create_time",
+}
+
+func (m *WechatArticleAbstract) Create() (err error) {
+	err = global.DbMap[utils.DbNameAI].Create(&m).Error
+
+	return
+}
+
+func (m *WechatArticleAbstract) Update(updateCols []string) (err error) {
+	err = global.DbMap[utils.DbNameAI].Select(updateCols).Updates(&m).Error
+
+	return
+}
+
+func (m *WechatArticleAbstract) GetById(id int) (item *WechatArticleAbstract, err error) {
+	err = global.DbMap[utils.DbNameAI].Where(fmt.Sprintf("%s = ?", WechatArticleAbstractColumns.WechatArticleAbstractID), id).First(&item).Error
+
+	return
+}
+
+// GetByWechatArticleId
+// @Description: 根据报告id获取摘要
+// @author: Roc
+// @receiver m
+// @datetime 2025-03-07 10:00:59
+// @param id int
+// @return item *WechatArticleAbstract
+// @return err error
+func (m *WechatArticleAbstract) GetByWechatArticleId(id int) (item *WechatArticleAbstract, err error) {
+	err = global.DbMap[utils.DbNameAI].Where(fmt.Sprintf("%s = ?", WechatArticleAbstractColumns.WechatArticleID), id).Order(fmt.Sprintf(`%s DESC`, WechatArticleAbstractColumns.WechatArticleAbstractID)).First(&item).Error
+
+	return
+}
+
+func (m *WechatArticleAbstract) GetListByPlatformCondition(field, condition string, pars []interface{}, startSize, pageSize int) (items []*WechatArticleAndPlatform, err error) {
+	if field == "" {
+		field = "*"
+	}
+	sqlStr := fmt.Sprintf(`SELECT %s FROM %s AS a 
+          JOIN wechat_platform AS b ON a.wechat_platform_id=b.wechat_platform_id
+          JOIN wechat_platform AS b ON a.wechat_platform_id=b.wechat_platform_id
+          WHERE 1=1 AND a.is_deleted=0 %s  order by a.article_create_time DESC,a.wechat_article_id DESC LIMIT ?,?`, field, m.TableName(), condition)
+	pars = append(pars, startSize, pageSize)
+	err = global.DbMap[utils.DbNameAI].Raw(sqlStr, pars...).Find(&items).Error
+
+	return
+}
+
+func (m *WechatArticleAbstract) GetCountByPlatformCondition(condition string, pars []interface{}) (total int, err error) {
+	var intNull sql.NullInt64
+	sqlStr := fmt.Sprintf(`SELECT COUNT(1) total FROM %s AS a 
+          JOIN wechat_platform AS b ON a.wechat_platform_id=b.wechat_platform_id 
+          WHERE 1=1 AND a.is_deleted=0 %s`, m.TableName(), condition)
+	err = global.DbMap[utils.DbNameAI].Raw(sqlStr, pars...).Scan(&intNull).Error
+	if err == nil && intNull.Valid {
+		total = int(intNull.Int64)
+	}
+
+	return
+}
+
+func (m *WechatArticleAbstract) GetPageListByPlatformCondition(condition string, pars []interface{}, startSize, pageSize int) (total int, items []*WechatArticleAndPlatform, err error) {
+
+	total, err = m.GetCountByPlatformCondition(condition, pars)
+	if err != nil {
+		return
+	}
+	if total > 0 {
+		items, err = m.GetListByPlatformCondition(`a.wechat_article_id,a.wechat_platform_id,a.fake_id,a.title,a.link,a.cover_url,a.description,a.country,a.province,a.city,a.article_create_time,a.modify_time,a.create_time,b.nickname,b.round_head_img`, condition, pars, startSize, pageSize)
+	}
+
+	return
+}

+ 60 - 0
models/rag/wechat_article_chat_record.go

@@ -0,0 +1,60 @@
+package rag
+
+import (
+	"eta/eta_api/global"
+	"eta/eta_api/utils"
+	"fmt"
+	"time"
+)
+
+type WechatArticleChatRecord struct {
+	WechatArticleChatRecordId int32     `gorm:"column:wechat_article_chat_record_id;type:int(11);comment:主键;primaryKey;not null;" json:"wechat_article_chat_record_id"` // 主键
+	WechatArticleId           int32     `gorm:"column:wechat_article_id;type:int(11);comment:文章id;default:NULL;" json:"wechat_article_id"`                              // 文章id
+	ChatUserType              string    `gorm:"column:chat_user_type;type:enum('user', 'assistant');comment:用户方;default:NULL;" json:"chat_user_type"`                   // 用户方
+	Content                   string    `gorm:"column:content;type:longtext;comment:对话内容;" json:"content"`                                                              // 对话内容
+	SendTime                  time.Time `gorm:"column:send_time;type:datetime;comment:发送时间;default:NULL;" json:"send_time"`                                             // 发送时间
+	CreatedTime               time.Time `gorm:"column:created_time;type:datetime;comment:创建时间;default:NULL;" json:"created_time"`                                       // 创建时间
+	UpdateTime                time.Time `gorm:"column:update_time;type:datetime;comment:更新时间;default:NULL;" json:"update_time"`                                         // 更新时间
+}
+
+// TableName get sql table name.获取数据库表名
+func (m *WechatArticleChatRecord) TableName() string {
+	return "wechat_article_chat_record"
+}
+
+// WechatArticleChatRecordColumns get sql column name.获取数据库列名
+var WechatArticleChatRecordColumns = struct {
+	WechatArticleChatRecordID string
+	WechatArticleID           string
+	ChatUserType              string
+	Content                   string
+	SendTime                  string
+	CreatedTime               string
+	UpdateTime                string
+}{
+	WechatArticleChatRecordID: "wechat_article_chat_record_id",
+	WechatArticleID:           "wechat_article_id",
+	ChatUserType:              "chat_user_type",
+	Content:                   "content",
+	SendTime:                  "send_time",
+	CreatedTime:               "created_time",
+	UpdateTime:                "update_time",
+}
+
+func (m *WechatArticleChatRecord) Create() (err error) {
+	err = global.DbMap[utils.DbNameAI].Create(&m).Error
+
+	return
+}
+
+func (m *WechatArticleChatRecord) Update(updateCols []string) (err error) {
+	err = global.DbMap[utils.DbNameAI].Select(updateCols).Updates(&m).Error
+
+	return
+}
+
+func (m *WechatArticleChatRecord) GetById(id int) (item *WechatArticle, err error) {
+	err = global.DbMap[utils.DbNameAI].Where(fmt.Sprintf("%s = ?", WechatArticleChatRecordColumns.WechatArticleChatRecordID), id).First(&item).Error
+
+	return
+}

+ 102 - 0
services/llm/chat.go

@@ -0,0 +1,102 @@
+package llm
+
+import (
+	"bytes"
+	"eta/eta_api/utils"
+	"fmt"
+	"io"
+	"mime/multipart"
+	"net/http"
+	"os"
+)
+
+type UpdateTempDocsResp struct {
+	Code int    `json:"code"`
+	Msg  string `json:"msg"`
+	Data struct {
+		Id          string        `json:"id"`
+		FailedFiles []interface{} `json:"failed_files"`
+	} `json:"data"`
+}
+
+func UpdateTempDocs(filePath string) {
+	postUrl := utils.LLM_SERVER + "/knowledge_base/upload_temp_docs"
+
+	params := make(map[string]string)
+	//params[`prev_id`] = ``
+	params[`chunk_size`] = `750`
+	params[`chunk_overlap`] = `150`
+	params[`zh_title_enhance`] = `true`
+
+	files := make(map[string]string)
+	files[`files`] = filePath
+
+	result, err := PostFormData(postUrl, params, files)
+	if err != nil {
+		return
+	}
+
+	str := string(result)
+	fmt.Println(str)
+
+	return
+}
+
+func ChatByFile(question string) {
+	// 没有问题那就直接返回
+	if question == `` {
+		return
+	}
+}
+
+// PostFormData sends a POST request with form-data
+func PostFormData(url string, params map[string]string, files map[string]string) ([]byte, error) {
+
+	body := &bytes.Buffer{}
+	writer := multipart.NewWriter(body)
+
+	for key, val := range params {
+		if err := writer.WriteField(key, val); err != nil {
+			return nil, err
+		}
+	}
+
+	for fieldName, filePath := range files {
+		file, err := os.Open(filePath)
+		if err != nil {
+			return nil, err
+		}
+		defer file.Close()
+
+		part, err := writer.CreateFormFile(fieldName, filePath)
+		if err != nil {
+			return nil, err
+		}
+		_, err = io.Copy(part, file)
+		if err != nil {
+			return nil, err
+		}
+	}
+
+	err := writer.Close()
+	if err != nil {
+		return nil, err
+	}
+
+	req, err := http.NewRequest("POST", url, body)
+	if err != nil {
+		return nil, err
+	}
+	req.Header.Set("Content-Type", writer.FormDataContentType())
+
+	client := &http.Client{}
+	resp, err := client.Do(req)
+	if err != nil {
+		return nil, err
+	}
+	defer resp.Body.Close()
+
+	result, err := io.ReadAll(resp.Body)
+
+	return result, nil
+}