浏览代码

Merge remote-tracking branch 'origin/rag/3.1' into debug

# Conflicts:
#	controllers/data_manage/excel/balance_table.go
#	controllers/data_manage/excel/excel_info.go
Roc 3 天之前
父节点
当前提交
5338b87494

+ 6 - 6
controllers/data_manage/excel/balance_table.go

@@ -1425,12 +1425,12 @@ func downloadBalanceTable(excelInfo *excel.ExcelInfo, lang string) (savePath, zi
 				err = fmt.Errorf("转换成table失败,Err:" + err.Error())
 				return
 			}
-			// tableData, err = excel2.HandleRuleToTableCell(childExcelInfo.ExcelInfoId, tableData)
-			// if err != nil {
-			// 	errMsg = "获取失败"
-			// 	err = fmt.Errorf("处理条件格式管理规则失败,Err:%w", err)
-			// 	return
-			// }
+			//tableData, err = excel2.HandleRuleToTableCell(childExcelInfo.ExcelInfoId, tableData)
+			//if err != nil {
+			//	errMsg = "获取失败"
+			//	err = fmt.Errorf("处理条件格式管理规则失败,Err:%w", err)
+			//	return
+			//}
 			// 将单个sheet的数据写入到excel
 			err = tableData.WriteExcelSheetData(xlsxFile, childExcelInfo.ExcelName)
 			if err != nil {

+ 6 - 6
controllers/data_manage/excel/excel_info.go

@@ -2708,12 +2708,12 @@ func (c *ExcelInfoController) Download() {
 			br.ErrMsg = "转换成table失败,Err:" + err.Error()
 			return
 		}
-		// tableData, err = excel.HandleRuleToTableCell(excelInfo.ExcelInfoId, tableData)
-		// if err != nil {
-		// 	br.Msg = "获取失败"
-		// 	br.ErrMsg = "处理条件格式管理规则失败,Err:" + err.Error()
-		// 	return
-		// }
+		//tableData, err = excel.HandleRuleToTableCell(excelInfo.ExcelInfoId, tableData)
+		//if err != nil {
+		//	br.Msg = "获取失败"
+		//	br.ErrMsg = "处理条件格式管理规则失败,Err:" + err.Error()
+		//	return
+		//}
 	case utils.BALANCE_TABLE: // 混合表格
 		savePath, fileName, uploadDir, err, errMsg := downloadBalanceTable(excelInfo, c.Lang)
 		if err != nil {

+ 4 - 2
controllers/llm/llm_http/request.go

@@ -22,12 +22,14 @@ type UserChatRecordReq struct {
 }
 
 type GenerateContentReq struct {
-	WechatArticleId int    `json:"WechatArticleId" description:"公众号Id"`
+	Source          int    `json:"Source" description:"来源,0:公众号文章,1:eta报告"`
+	WechatArticleId int    `json:"WechatArticleId" description:"公众号文章Id"`
 	Promote         string `json:"Promote" description:"提示词"`
 	LLMModel        string `json:"LLMModel"`
 }
 type SaveContentReq struct {
-	WechatArticleId int             `json:"WechatArticleId" description:"公众号Id"`
+	Source          int             `json:"Source" description:"来源,0:公众号文章,1:eta报告"`
+	WechatArticleId int             `json:"WechatArticleId" description:"公众号文章Id"`
 	Title           string          `json:"Title" description:"标题"`
 	Llm             string          `json:"LLM"`
 	Promote         json.RawMessage `json:"Promote" description:"提示词"`

+ 5 - 1
controllers/llm/promote_controller.go

@@ -75,6 +75,7 @@ func (pCtrl *PromoteController) GenerateContent() {
 	}
 	res, err := facade.AIGCBaseOnPromote(facade.AIGC{
 		Promote:   gcReq.Promote,
+		Source:    gcReq.Source,
 		ArticleId: gcReq.WechatArticleId,
 		LLMModel:  gcReq.LLMModel,
 	})
@@ -95,6 +96,7 @@ func (pCtrl *PromoteController) GenerateContent() {
 	}
 	llm := strings.ReplaceAll(gcReq.LLMModel, ":", "")
 	saveContentReq := rag.PromoteTrainRecord{
+		Source:          gcReq.Source,
 		WechatArticleId: gcReq.WechatArticleId,
 		Title:           userContent.Content,
 		Llm:             llm,
@@ -200,6 +202,7 @@ func (pCtrl *PromoteController) SavePromoteContent() {
 	}
 	llm := strings.ReplaceAll(gcReq.Llm, ":", "")
 	saveContentReq := rag.PromoteTrainRecord{
+		Source:          gcReq.Source,
 		WechatArticleId: gcReq.WechatArticleId,
 		Title:           titile,
 		Llm:             llm,
@@ -273,6 +276,7 @@ func (pCtrl *PromoteController) PromoteContentList() {
 		pCtrl.ServeJSON()
 	}()
 	wechatArticleId, _ := pCtrl.GetInt("WechatArticleId")
+	source, _ := pCtrl.GetInt("Source")
 	sysUser := pCtrl.SysUser
 	if sysUser == nil {
 		br.Msg = "请登录"
@@ -287,7 +291,7 @@ func (pCtrl *PromoteController) PromoteContentList() {
 		return
 	}
 
-	list, err := rag.GetRecordList(wechatArticleId)
+	list, err := rag.GetRecordList(wechatArticleId, source)
 	if err != nil {
 		br.Msg = "查询列表失败"
 		br.ErrMsg = "查询列表失败,err:" + err.Error()

+ 228 - 0
controllers/llm/report.go

@@ -0,0 +1,228 @@
+package llm
+
+import (
+	"eta/eta_api/controllers"
+	"eta/eta_api/models"
+	"eta/eta_api/models/rag"
+	"eta/eta_api/models/rag/response"
+	"eta/eta_api/services"
+	"eta/eta_api/utils"
+	"fmt"
+	"github.com/rdlucklib/rdluck_tools/paging"
+	"html"
+)
+
+// RagEtaReportController
+// @Description: eta报告的接口
+type RagEtaReportController struct {
+	controllers.BaseAuthController
+}
+
+// ArticleList
+// @Title 我关注的接口
+// @Description 我关注的接口
+// @Param   PageSize   query   int  true       "每页数据条数"
+// @Param   CurrentIndex   query   int  true       "当前页页码,从1开始"
+// @Param   KeyWord   query   string  true       "搜索关键词"
+// @Success 200 {object} *rag.RagEtaReportListListResp
+// @router /eta_report/article/list [get]
+func (c *RagEtaReportController) ArticleList() {
+	br := new(models.BaseResponse).Init()
+	defer func() {
+		c.Data["json"] = br
+		c.ServeJSON()
+	}()
+
+	sysUser := c.SysUser
+	if sysUser == nil {
+		br.Msg = "请登录"
+		br.ErrMsg = "请登录,SysUser Is Empty"
+		return
+	}
+	pageSize, _ := c.GetInt("PageSize")
+	currentIndex, _ := c.GetInt("CurrentIndex")
+	keyWord := c.GetString("KeyWord")
+
+	var startSize int
+	if pageSize <= 0 {
+		pageSize = utils.PageSize20
+	}
+	if currentIndex <= 0 {
+		currentIndex = 1
+	}
+	startSize = utils.StartIndex(currentIndex, pageSize)
+
+	var total int
+	viewList := make([]rag.RagEtaReportView, 0)
+
+	var condition string
+	var pars []interface{}
+
+	if keyWord != "" {
+		condition += fmt.Sprintf(` AND %s like ? `, rag.RagEtaReportColumns.Title)
+		pars = append(pars, `%`+keyWord+`%`)
+	}
+
+	obj := new(rag.RagEtaReport)
+	tmpTotal, list, err := obj.GetPageListByCondition(condition, pars, startSize, pageSize)
+	if err != nil {
+		br.Msg = "获取失败"
+		br.ErrMsg = "获取失败,Err:" + err.Error()
+		return
+	}
+	total = tmpTotal
+
+	if list != nil && len(list) > 0 {
+		viewList = obj.ListToViewList(list)
+	}
+
+	page := paging.GetPaging(currentIndex, pageSize, total)
+	resp := response.RagEtaReportListListResp{
+		List:   viewList,
+		Paging: page,
+	}
+
+	br.Ret = 200
+	br.Success = true
+	br.Msg = "获取成功"
+	br.Data = resp
+}
+
+// ArticleDetail
+// @Title 文章详情
+// @Description 我关注的接口
+// @Param   RagEtaReportId   query   int  true       "知识库与eta报告关联的id"
+// @Success 200 {object} []*rag.WechatArticle
+// @router /eta_report/article/detail [get]
+func (c *RagEtaReportController) ArticleDetail() {
+	br := new(models.BaseResponse).Init()
+	defer func() {
+		c.Data["json"] = br
+		c.ServeJSON()
+	}()
+
+	sysUser := c.SysUser
+	if sysUser == nil {
+		br.Msg = "请登录"
+		br.ErrMsg = "请登录,SysUser Is Empty"
+		return
+	}
+	ragEtaReportId, _ := c.GetInt("RagEtaReportId")
+	if ragEtaReportId <= 0 {
+		br.Msg = "请选择文章"
+		br.IsSendEmail = false
+		return
+	}
+	obj := new(rag.RagEtaReport)
+	item, err := obj.GetById(ragEtaReportId)
+	if err != nil {
+		br.Msg = "获取失败"
+		br.ErrMsg = "获取失败,Err:" + err.Error()
+		return
+	}
+
+	if item.IsDeleted == 1 {
+		br.Msg = "文章已删除"
+		br.IsSendEmail = false
+		return
+	}
+	resp := item.ToView()
+
+	content := ``
+
+	// 获取源报告信息
+	{
+		if item.ReportChapterId <= 0 {
+			// 普通报告
+			reportInfo, err := models.GetReportByReportId(item.ReportId)
+			if err != nil && !utils.IsErrNoRow(err) {
+				br.Msg = "获取报告详情失败"
+				br.ErrMsg = "获取源报告详情失败,Err:" + err.Error()
+				return
+			}
+			content = reportInfo.Content
+		} else {
+			// 章节报告
+			reportChapterInfo, err := models.GetReportChapterInfoById(item.ReportChapterId)
+			if err != nil && !utils.IsErrNoRow(err) {
+				br.Msg = "获取报告详情失败"
+				br.ErrMsg = "获取源报告章节详情失败,Err:" + err.Error()
+				return
+			}
+			content = reportChapterInfo.Content
+		}
+
+		if content != `` {
+			content = html.UnescapeString(content)
+			businessConf, err := models.GetBusinessConfByKey(models.BusinessConfIsOpenChartExpired)
+			if err != nil {
+				br.Msg = "获取失败"
+				br.ErrMsg = "获取配置失败,Err:" + err.Error()
+				return
+			}
+
+			if businessConf.ConfVal == `true` {
+				tokenMap := make(map[string]string)
+				content = services.HandleReportContent(content, "add", tokenMap)
+			}
+		}
+	}
+
+	resp.Content = content
+
+	br.Ret = 200
+	br.Success = true
+	br.Msg = "获取成功"
+	br.Data = resp
+}
+
+// ArticleDel
+// @Title 删除文章
+// @Description 我关注的接口
+// @Param   RagEtaReportId   query   int  true       "知识库与eta报告关联的id"
+// @Success 200 {object} []*rag.WechatPlatform
+// @router /eta_report/article/del [get]
+func (c *RagEtaReportController) ArticleDel() {
+	br := new(models.BaseResponse).Init()
+	defer func() {
+		c.Data["json"] = br
+		c.ServeJSON()
+	}()
+
+	sysUser := c.SysUser
+	if sysUser == nil {
+		br.Msg = "请登录"
+		br.ErrMsg = "请登录,SysUser Is Empty"
+		return
+	}
+	ragEtaReportId, _ := c.GetInt("RagEtaReportId")
+	if ragEtaReportId <= 0 {
+		br.Msg = "请选择文章"
+		br.IsSendEmail = false
+		return
+	}
+	obj := new(rag.RagEtaReport)
+	item, err := obj.GetById(ragEtaReportId)
+	if err != nil {
+		br.Msg = "获取失败"
+		br.ErrMsg = "获取失败,Err:" + err.Error()
+		return
+	}
+
+	if item.IsDeleted == 1 {
+		br.Msg = "文章已删除"
+		br.IsSendEmail = false
+		return
+	}
+	item.IsDeleted = 1
+	err = item.Update([]string{"is_deleted"})
+	if err != nil {
+		br.Msg = "删除失败"
+		br.ErrMsg = "删除失败,Err:" + err.Error()
+		return
+	}
+
+	br.Ret = 200
+	br.Success = true
+	br.Msg = "删除成功"
+}

+ 3 - 2
models/rag/article_kb_mapping.go

@@ -10,6 +10,7 @@ import (
 type ArticleKbMapping struct {
 	Id              int `gorm:"id;primaryKey"`
 	WechatArticleId int
+	Source          int
 	KbId            string
 	CreatedTime     time.Time
 	UpdateTime      time.Time
@@ -19,8 +20,8 @@ func (a *ArticleKbMapping) TableName() string {
 	return "article_kb_mapping"
 }
 
-func GetArticleKbMapping(articleId int) (articleKbMapping *ArticleKbMapping, err error) {
-	err = global.DbMap[utils.DbNameAI].Where("wechat_article_id = ?", articleId).First(&articleKbMapping).Error
+func GetArticleKbMapping(articleId, source int) (articleKbMapping *ArticleKbMapping, err error) {
+	err = global.DbMap[utils.DbNameAI].Where("wechat_article_id = ? AND source = ? ", articleId, source).First(&articleKbMapping).Error
 	return
 }
 

+ 5 - 2
models/rag/promote_train_record.go

@@ -10,6 +10,7 @@ type PromoteTrainRecord struct {
 	Id              int       `gorm:"id;primaryKey"`
 	Title           string    `gorm:"title"`
 	Llm             string    `gorm:"llm"`
+	Source          int       `gorm:"source" description:"来源,0:公众号文章,1:eta报告"`
 	WechatArticleId int       `gorm:"wechat_article_id"`
 	TemplatePromote string    `gorm:"template_promote"`
 	PromoteSendTime time.Time `gorm:"promote_send_time"`
@@ -25,6 +26,7 @@ func (p *PromoteTrainRecord) ToView() *PromoteTrainRecordView {
 		Id:              p.Id,
 		Title:           p.Title,
 		Llm:             p.Llm,
+		Source:          p.Source,
 		WechatArticleId: p.WechatArticleId,
 		TemplatePromote: p.TemplatePromote,
 		PromoteSendTime: p.PromoteSendTime.Format(utils.FormatDateTime),
@@ -37,6 +39,7 @@ type PromoteTrainRecordView struct {
 	Id              int
 	Title           string
 	Llm             string
+	Source          int
 	WechatArticleId int
 	TemplatePromote string
 	PromoteSendTime string
@@ -55,9 +58,9 @@ func DeleteContent(id int) error {
 	return global.DbMap[utils.DbNameAI].Model(&PromoteTrainRecord{}).Where("id = ?", id).Update("is_deleted", true).Error
 }
 
-func GetRecordList(wechatArticleId int) (list []*PromoteTrainRecordView, err error) {
+func GetRecordList(wechatArticleId, source int) (list []*PromoteTrainRecordView, err error) {
 	var ormList []PromoteTrainRecord
-	err = global.DbMap[utils.DbNameAI].Model(&PromoteTrainRecord{}).Where("wechat_article_id = ? and is_deleted=?", wechatArticleId, false).Order(`created_time DESC`).Find(&ormList).Error
+	err = global.DbMap[utils.DbNameAI].Model(&PromoteTrainRecord{}).Where("wechat_article_id = ? AND source = ?  and is_deleted=?", wechatArticleId, source, false).Order(`created_time DESC`).Find(&ormList).Error
 	if err != nil {
 		return
 	}

+ 159 - 0
models/rag/rag_eta_report.go

@@ -0,0 +1,159 @@
+package rag
+
+import (
+	"database/sql"
+	"eta/eta_api/global"
+	"eta/eta_api/utils"
+	"fmt"
+	"time"
+)
+
+// RagEtaReport eta报告
+type RagEtaReport struct {
+	RagEtaReportId  int       `gorm:"primaryKey;column:rag_eta_report_id" `
+	ReportId        int       `gorm:"column:report_id" description:"报告id"`
+	ReportChapterId int       `gorm:"column:report_chapter_id" description:"报告章节id"`
+	Title           string    `gorm:"column:title" description:"报告标题(完整标题,含期数)"`
+	Author          string    `gorm:"column:author" description:"作者"`
+	TextContent     string    `gorm:"column:text_content" description:"报告内容(去除html)"`
+	VectorKey       string    `gorm:"column:vector_key" description:"向量库的key"`
+	IsDeleted       int       `gorm:"column:is_deleted;type:tinyint(4);comment:是否删除,0:未删除,1:已删除;default:0;" description:"否删除,0:未删除,1:已删除"`
+	PublishTime     time.Time `gorm:"column:publish_time" description:"发布时间"`
+	ModifyTime      time.Time `gorm:"column:modify_time" description:"修改时间"`
+	CreateTime      time.Time `gorm:"column:create_time" description:"新增时间"`
+}
+
+// TableName get sql table name.获取数据库表名
+func (m *RagEtaReport) TableName() string {
+	return "rag_eta_report"
+}
+
+// RagEtaReportColumns get sql column name.获取数据库列名
+var RagEtaReportColumns = struct {
+	RagEtaReportID  string
+	ReportID        string
+	ReportChapterID string
+	Title           string
+	Author          string
+	TextContent     string
+	VectorKey       string
+	IsDeleted       string
+	PublishTime     string
+	ModifyTime      string
+	CreateTime      string
+}{
+	RagEtaReportID:  "rag_eta_report_id",
+	ReportID:        "report_id",
+	ReportChapterID: "report_chapter_id",
+	Title:           "title",
+	Author:          "author",
+	TextContent:     "text_content",
+	VectorKey:       "vector_key",
+	IsDeleted:       "is_deleted",
+	PublishTime:     "publish_time",
+	ModifyTime:      "modify_time",
+	CreateTime:      "create_time",
+}
+
+type RagEtaReportView struct {
+	RagEtaReportId  int    `gorm:"primaryKey;column:rag_eta_report_id" `
+	ReportId        int    `gorm:"column:report_id" description:"报告id"`
+	ReportChapterId int    `gorm:"column:report_chapter_id" description:"报告章节id"`
+	Title           string `gorm:"column:title" description:"报告标题(完整标题,含期数)"`
+	Author          string `gorm:"column:author" description:"作者"`
+	Content         string `gorm:"column:content" description:"报告内容(包含html)"`
+	TextContent     string `gorm:"column:text_content" description:"报告内容(去除html)"`
+	VectorKey       string `gorm:"column:vector_key" description:"向量库的key"`
+	IsDeleted       int    `gorm:"column:is_deleted;type:tinyint(4);comment:是否删除,0:未删除,1:已删除;default:0;" description:"否删除,0:未删除,1:已删除"`
+	PublishTime     string `gorm:"column:publish_time" description:"发布时间"`
+	ModifyTime      string `gorm:"column:modify_time" description:"修改时间"`
+	CreateTime      string `gorm:"column:create_time" description:"新增时间"`
+}
+
+func (m *RagEtaReport) ToView() RagEtaReportView {
+	var publishTime, modifyTime, createTime string
+
+	if !m.PublishTime.IsZero() {
+		publishTime = m.PublishTime.Format(utils.FormatDateTime)
+	}
+	if !m.CreateTime.IsZero() {
+		createTime = m.CreateTime.Format(utils.FormatDateTime)
+	}
+	if !m.ModifyTime.IsZero() {
+		modifyTime = m.ModifyTime.Format(utils.FormatDateTime)
+	}
+	return RagEtaReportView{
+		RagEtaReportId:  m.RagEtaReportId,
+		ReportId:        m.ReportId,
+		ReportChapterId: m.ReportChapterId,
+		Title:           m.Title,
+		Author:          m.Author,
+		TextContent:     m.TextContent,
+		VectorKey:       m.VectorKey,
+		PublishTime:     publishTime,
+		ModifyTime:      modifyTime,
+		CreateTime:      createTime,
+	}
+}
+
+func (m *RagEtaReport) ListToViewList(list []*RagEtaReport) (RagEtaReportViewList []RagEtaReportView) {
+	RagEtaReportViewList = make([]RagEtaReportView, 0)
+
+	for _, v := range list {
+		RagEtaReportViewList = append(RagEtaReportViewList, v.ToView())
+	}
+	return
+}
+
+func (m *RagEtaReport) Create() (err error) {
+	err = global.DbMap[utils.DbNameAI].Create(&m).Error
+
+	return
+}
+
+func (m *RagEtaReport) Update(updateCols []string) (err error) {
+	err = global.DbMap[utils.DbNameAI].Select(updateCols).Updates(&m).Error
+
+	return
+}
+
+func (m *RagEtaReport) GetById(id int) (item *RagEtaReport, err error) {
+	err = global.DbMap[utils.DbNameAI].Where(fmt.Sprintf("%s = ?", RagEtaReportColumns.RagEtaReportID), id).First(&item).Error
+
+	return
+}
+
+func (m *RagEtaReport) GetListByCondition(field, condition string, pars []interface{}, startSize, pageSize int) (items []*RagEtaReport, err error) {
+	if field == "" {
+		field = "*"
+	}
+	sqlStr := fmt.Sprintf(`SELECT %s FROM %s WHERE is_deleted=0 %s  order by publish_time desc,report_id desc,report_chapter_id desc LIMIT ?,?`, field, m.TableName(), condition)
+	pars = append(pars, startSize, pageSize)
+	err = global.DbMap[utils.DbNameAI].Raw(sqlStr, pars...).Find(&items).Error
+
+	return
+}
+
+func (m *RagEtaReport) GetCountByCondition(condition string, pars []interface{}) (total int, err error) {
+	var intNull sql.NullInt64
+	sqlStr := fmt.Sprintf(`SELECT COUNT(1) total FROM %s WHERE 1=1 AND is_deleted=0 %s`, m.TableName(), condition)
+	err = global.DbMap[utils.DbNameAI].Raw(sqlStr, pars...).Scan(&intNull).Error
+	if err == nil && intNull.Valid {
+		total = int(intNull.Int64)
+	}
+
+	return
+}
+
+func (m *RagEtaReport) GetPageListByCondition(condition string, pars []interface{}, startSize, pageSize int) (total int, items []*RagEtaReport, err error) {
+
+	total, err = m.GetCountByCondition(condition, pars)
+	if err != nil {
+		return
+	}
+	if total > 0 {
+		items, err = m.GetListByCondition(``, condition, pars, startSize, pageSize)
+	}
+
+	return
+}

+ 11 - 0
models/rag/response/eta_report.go

@@ -0,0 +1,11 @@
+package response
+
+import (
+	"eta/eta_api/models/rag"
+	"github.com/rdlucklib/rdluck_tools/paging"
+)
+
+type RagEtaReportListListResp struct {
+	List   []rag.RagEtaReportView
+	Paging *paging.PagingItem `description:"分页数据"`
+}

+ 27 - 0
routers/commentsRouter.go

@@ -8791,6 +8791,33 @@ func init() {
             Filters: nil,
             Params: nil})
 
+    beego.GlobalControllerRouter["eta/eta_api/controllers/llm:RagEtaReportController"] = append(beego.GlobalControllerRouter["eta/eta_api/controllers/llm:RagEtaReportController"],
+        beego.ControllerComments{
+            Method: "ArticleDel",
+            Router: `/eta_report/article/del`,
+            AllowHTTPMethods: []string{"get"},
+            MethodParams: param.Make(),
+            Filters: nil,
+            Params: nil})
+
+    beego.GlobalControllerRouter["eta/eta_api/controllers/llm:RagEtaReportController"] = append(beego.GlobalControllerRouter["eta/eta_api/controllers/llm:RagEtaReportController"],
+        beego.ControllerComments{
+            Method: "ArticleDetail",
+            Router: `/eta_report/article/detail`,
+            AllowHTTPMethods: []string{"get"},
+            MethodParams: param.Make(),
+            Filters: nil,
+            Params: nil})
+
+    beego.GlobalControllerRouter["eta/eta_api/controllers/llm:RagEtaReportController"] = append(beego.GlobalControllerRouter["eta/eta_api/controllers/llm:RagEtaReportController"],
+        beego.ControllerComments{
+            Method: "ArticleList",
+            Router: `/eta_report/article/list`,
+            AllowHTTPMethods: []string{"get"},
+            MethodParams: param.Make(),
+            Filters: nil,
+            Params: nil})
+
     beego.GlobalControllerRouter["eta/eta_api/controllers/llm:UserChatController"] = append(beego.GlobalControllerRouter["eta/eta_api/controllers/llm:UserChatController"],
         beego.ControllerComments{
             Method: "ChatRecordList",

+ 1 - 0
routers/router.go

@@ -79,6 +79,7 @@ func init() {
 				&llm.QuestionController{},
 				&llm.AbstractController{},
 				&llm.PromoteController{},
+				&llm.RagEtaReportController{},
 			),
 		),
 		web.NSNamespace("/banner",

+ 53 - 17
services/llm/facade/llm_service.go

@@ -53,7 +53,7 @@ func AIGCBaseOnPromote(aigc AIGC) (resp bus_response.AIGCEtaResponse, err error)
 	if aigc.LLMModel != "" {
 		param["LLM"] = aigc.LLMModel
 	}
-	mapping, queryErr := rag.GetArticleKbMapping(aigc.ArticleId)
+	mapping, queryErr := rag.GetArticleKbMapping(aigc.ArticleId, aigc.Source)
 	if queryErr != nil && !errors.Is(queryErr, gorm.ErrRecordNotFound) {
 		utils.FileLog.Error("获取文章知识库信息失败,err: %v", queryErr)
 		err = fmt.Errorf("获取文章知识库信息失败,err: %v", queryErr)
@@ -62,20 +62,37 @@ func AIGCBaseOnPromote(aigc AIGC) (resp bus_response.AIGCEtaResponse, err error)
 		var kbId string
 		var file *os.File
 		if mapping.Id == 0 || mapping.KbId == "" {
-			article, fileErr := rag.GetArticleById(aigc.ArticleId)
-			if fileErr != nil {
-				// 找不到就处理失败
-				utils.FileLog.Error("公众号文章不存在")
-				err = fmt.Errorf("公众号文章不存在")
-				return
+			var title, textContent string
+			switch aigc.Source {
+			case 0:
+				article, fileErr := rag.GetArticleById(aigc.ArticleId)
+				if fileErr != nil {
+					// 找不到就处理失败
+					utils.FileLog.Error("公众号文章不存在")
+					err = fmt.Errorf("公众号文章不存在")
+					return
+				}
+				textContent = article.TextContent
+				title = article.Title
+			case 1:
+				ragEtaReportObj := rag.RagEtaReport{}
+				article, fileErr := ragEtaReportObj.GetById(aigc.ArticleId)
+				if fileErr != nil {
+					// 找不到就处理失败
+					utils.FileLog.Error("ETA文章不存在")
+					err = fmt.Errorf("ETA文章不存在")
+					return
+				}
+				textContent = article.TextContent
+				title = article.Title
 			}
-			if article.TextContent == "" {
+			if textContent == "" {
 				utils.FileLog.Error("暂不支持纯文本以外的内容生成")
 				err = fmt.Errorf("暂不支持纯文本以外的内容生成")
 				return
 			}
 			// 文章加入到知识库
-			path, fileErr := localService.CreateArticleFile(article)
+			path, fileErr := localService.CreateArticleFile(title, textContent)
 			if fileErr != nil {
 				utils.FileLog.Error("创建文章文件失败,err: %v", fileErr)
 				err = fmt.Errorf("创建文章文件失败,err: %v", fileErr)
@@ -103,6 +120,7 @@ func AIGCBaseOnPromote(aigc AIGC) (resp bus_response.AIGCEtaResponse, err error)
 				return
 			}
 			err = rag.CreateArticleKbMapping(rag.ArticleKbMapping{
+				Source:          aigc.Source,
 				WechatArticleId: aigc.ArticleId,
 				KbId:            data.Id,
 				CreatedTime:     time.Now(),
@@ -134,20 +152,37 @@ func AIGCBaseOnPromote(aigc AIGC) (resp bus_response.AIGCEtaResponse, err error)
 		}
 		if gcResp.Code == 404 {
 			param["PrevId"] = kbId
-			article, fileErr := rag.GetArticleById(aigc.ArticleId)
-			if fileErr != nil {
-				// 找不到就处理失败
-				utils.FileLog.Error("公众号文章不存在")
-				err = fmt.Errorf("公众号文章不存在")
-				return
+			var title, textContent string
+			switch aigc.Source {
+			case 0:
+				article, fileErr := rag.GetArticleById(aigc.ArticleId)
+				if fileErr != nil {
+					// 找不到就处理失败
+					utils.FileLog.Error("公众号文章不存在")
+					err = fmt.Errorf("公众号文章不存在")
+					return
+				}
+				textContent = article.TextContent
+				title = article.Title
+			case 1:
+				ragEtaReportObj := rag.RagEtaReport{}
+				article, fileErr := ragEtaReportObj.GetById(aigc.ArticleId)
+				if fileErr != nil {
+					// 找不到就处理失败
+					utils.FileLog.Error("ETA文章不存在")
+					err = fmt.Errorf("ETA文章不存在")
+					return
+				}
+				textContent = article.TextContent
+				title = article.Title
 			}
-			if article.TextContent == "" {
+			if textContent == "" {
 				utils.FileLog.Error("暂不支持纯文本以外的内容生成")
 				err = fmt.Errorf("暂不支持纯文本以外的内容生成")
 				return
 			}
 			// 文章加入到知识库
-			path, fileErr := localService.CreateArticleFile(article)
+			path, fileErr := localService.CreateArticleFile(title, textContent)
 			if fileErr != nil {
 				utils.FileLog.Error("创建文章文件失败,err: %v", fileErr)
 				err = fmt.Errorf("创建文章文件失败,err: %v", fileErr)
@@ -202,6 +237,7 @@ type LLMKnowledgeSearch struct {
 
 type AIGC struct {
 	Promote   string
+	Source    int
 	ArticleId int
 	LLMModel  string
 }

+ 4 - 5
services/llm/promote_service.go

@@ -1,14 +1,13 @@
 package llm
 
 import (
-	"eta/eta_api/models/rag"
 	"eta/eta_api/utils"
 	"fmt"
 	"os"
 )
 
-func CreateArticleFile(item *rag.WechatArticle) (tmpFilePath string, err error) {
-	if item.TextContent == `` {
+func CreateArticleFile(title, textContent string) (tmpFilePath string, err error) {
+	if textContent == `` {
 		err = fmt.Errorf("生成文章原文文本失败,文章内容为空")
 		return
 	}
@@ -19,9 +18,9 @@ func CreateArticleFile(item *rag.WechatArticle) (tmpFilePath string, err error)
 		err = fmt.Errorf("存储目录创建失败,Err:" + err.Error())
 		return
 	}
-	fileName := utils.RemoveSpecialChars(item.Title) + `.md`
+	fileName := utils.RemoveSpecialChars(title) + `.md`
 	tmpFilePath = uploadDir + "/" + fileName
-	err = utils.SaveToFile(item.TextContent, tmpFilePath)
+	err = utils.SaveToFile(textContent, tmpFilePath)
 	if err != nil {
 		err = fmt.Errorf("生成临时文件失败,Err:" + err.Error())
 		return

+ 8 - 7
services/wechat_platform.go

@@ -271,7 +271,7 @@ func GenerateArticleAbstract(item *rag.WechatArticle) {
 
 	// 生成临时文件
 	dateDir := time.Now().Format("20060102")
-	uploadDir := utils.STATIC_DIR + "ai/" + dateDir
+	uploadDir := "./static/ai/" + dateDir
 	err = os.MkdirAll(uploadDir, utils.DIR_MOD)
 	if err != nil {
 		err = fmt.Errorf("存储目录创建失败,Err:" + err.Error())
@@ -382,7 +382,7 @@ func ReGenerateArticleAbstract(item *rag.WechatArticle) {
 
 	// 生成临时文件
 	dateDir := time.Now().Format("20060102")
-	uploadDir := utils.STATIC_DIR + "ai/" + dateDir
+	uploadDir := "./static/ai/" + dateDir
 	err = os.MkdirAll(uploadDir, utils.DIR_MOD)
 	if err != nil {
 		err = fmt.Errorf("存储目录创建失败,Err:" + err.Error())
@@ -610,14 +610,15 @@ func ArticleToKnowledge(item *rag.WechatArticle) {
 
 	// 生成临时文件
 	//dateDir := time.Now().Format("20060102")
-	//uploadDir := utils.STATIC_DIR + "ai/article/" + dateDir
-	uploadDir := utils.STATIC_DIR + "ai/article"
+	//uploadDir :=   "./static/ai/article/" + dateDir
+	uploadDir := "./static/ai/article"
 	err = os.MkdirAll(uploadDir, utils.DIR_MOD)
 	if err != nil {
 		err = fmt.Errorf("存储目录创建失败,Err:" + err.Error())
 		return
 	}
-	fileName := utils.RemoveSpecialChars(item.Title) + `.md`
+	//fileName := utils.RemoveSpecialChars(item.Title) + `.md`
+	fileName := utils.MD5(item.Title) + `.md`
 	tmpFilePath := uploadDir + "/" + fileName
 	err = utils.SaveToFile(item.TextContent, tmpFilePath)
 	if err != nil {
@@ -675,8 +676,8 @@ func AbstractToKnowledge(wechatArticleItem *rag.WechatArticle, abstractItem *rag
 
 	// 生成临时文件
 	//dateDir := time.Now().Format("20060102")
-	//uploadDir := utils.STATIC_DIR + "ai/article/" + dateDir
-	uploadDir := utils.STATIC_DIR + "ai/abstract"
+	//uploadDir :=  + "./static/ai/article/" + dateDir
+	uploadDir := "./static/ai/abstract"
 	err = os.MkdirAll(uploadDir, utils.DIR_MOD)
 	if err != nil {
 		err = fmt.Errorf("存储目录创建失败,Err:" + err.Error())