Roc 4 天之前
父节点
当前提交
55eeb4b48e
共有 4 个文件被更改,包括 282 次插入90 次删除
  1. 10 0
      controllers/llm/abstract.go
  2. 23 1
      models/rag/tag.go
  3. 53 11
      services/llm_report.go
  4. 196 78
      services/wechat_platform.go

+ 10 - 0
controllers/llm/abstract.go

@@ -418,3 +418,13 @@ func (c *AbstractController) AddVector() {
 	br.Success = true
 	br.Msg = `添加向量库中,请稍后查看`
 }
+
+func init() {
+	obj := rag.WechatArticle{}
+	item, tmpErr := obj.GetById(1722)
+	if tmpErr != nil {
+		// 找不到就处理失败
+		return
+	}
+	services.GenerateWechatArticleAbstract(item, false)
+}

+ 23 - 1
models/rag/tag.go

@@ -5,6 +5,7 @@ import (
 	"eta/eta_api/global"
 	"eta/eta_api/utils"
 	"fmt"
+	"strings"
 	"time"
 )
 
@@ -37,6 +38,12 @@ var TagColumns = struct {
 	CreateTime: "create_time",
 }
 
+func (m *Tag) Create() (err error) {
+	err = global.DbMap[utils.DbNameAI].Create(&m).Error
+
+	return
+}
+
 func (m *Tag) GetByID(TagId int) (item *Tag, err error) {
 	err = global.DbMap[utils.DbNameAI].Where(fmt.Sprintf("%s = ?", TagColumns.TagID), TagId).First(&item).Error
 
@@ -70,7 +77,6 @@ func (m *Tag) GetCountByCondition(condition string, pars []interface{}) (total i
 }
 
 func (m *Tag) GetPageListByCondition(condition string, pars []interface{}, startSize, pageSize int) (total int, items []*Tag, err error) {
-
 	total, err = m.GetCountByCondition(condition, pars)
 	if err != nil {
 		return
@@ -81,3 +87,19 @@ func (m *Tag) GetPageListByCondition(condition string, pars []interface{}, start
 
 	return
 }
+
+var aiAbstractTagMap map[string]int
+
+func (m *Tag) GetTagIdByName(tagName string) (tagId int, err error) {
+	tagName = strings.TrimSpace(tagName)
+	tagId, ok := aiAbstractTagMap[tagName]
+	if ok {
+		return
+	}
+
+	var item *Tag
+	sqlStr := fmt.Sprintf(`SELECT * FROM %s WHERE %s = ? `, m.TableName(), TagColumns.TagName)
+	err = global.DbMap[utils.DbNameAI].Raw(sqlStr, tagName).First(&item).Error
+
+	return
+}

+ 53 - 11
services/llm_report.go

@@ -1,6 +1,7 @@
 package services
 
 import (
+	"encoding/json"
 	"eta/eta_api/models"
 	"eta/eta_api/models/rag"
 	"eta/eta_api/services/elastic"
@@ -348,7 +349,7 @@ func GenerateRagEtaReportAbstractByQuestion(item *rag.RagEtaReport, question *ra
 	//你现在是一名资深的期货行业分析师,请基于以下的问题进行汇总总结,如果不能正常总结出来,那么就只需要回复我:sorry
 	questionStr := fmt.Sprintf(`%s\n%s`, `你现在是一名资深的期货行业分析师,请基于以下的问题进行汇总总结,如果不能正常总结出来,那么就只需要回复我:sorry。以下是问题:`, question.QuestionContent)
 	//开始对话
-	abstract, _, tmpErr := getAnswerByContent(item.RagEtaReportId, utils.AI_ARTICLE_SOURCE_ETA_REPORT, questionStr)
+	abstract, industryTags, _, tmpErr := getAnswerByContent(item.RagEtaReportId, utils.AI_ARTICLE_SOURCE_ETA_REPORT, questionStr)
 	if tmpErr != nil {
 		err = fmt.Errorf("LLM对话失败,Err:" + tmpErr.Error())
 		return
@@ -377,6 +378,51 @@ func GenerateRagEtaReportAbstractByQuestion(item *rag.RagEtaReport, question *ra
 	//item.ModifyTime = time.Now()
 	//err = item.Update([]string{"AbstractStatus", "ModifyTime"})
 
+	var tagIdJsonStr string
+	// 标签ID
+	{
+		tagIdList := make([]int, 0)
+		tagIdMap := make(map[int]bool)
+
+		if abstractItem != nil && abstractItem.Tags != `` {
+			tmpErr = json.Unmarshal([]byte(abstractItem.Tags), &tagIdList)
+			if tmpErr != nil {
+				utils.FileLog.Info(fmt.Sprintf("json.Unmarshal 失败,标签数据:%s,Err:%s", abstractItem.Tags, tmpErr.Error()))
+			} else {
+				for _, tagId := range tagIdList {
+					tagIdMap[tagId] = true
+				}
+			}
+		}
+		for _, tagName := range industryTags {
+			tagId, tmpErr := GetTagIdByName(tagName)
+			if tmpErr != nil {
+				utils.FileLog.Info(fmt.Sprintf("获取标签ID失败,标签名称:%s,Err:%s", tagName, tmpErr.Error()))
+			}
+			if _, ok := tagIdMap[tagId]; !ok {
+				tagIdList = append(tagIdList, tagId)
+				tagIdMap[tagId] = true
+			}
+		}
+		//for _, tagName := range varietyTags {
+		//	tagId, tmpErr := GetTagIdByName(tagName)
+		//	if tmpErr != nil {
+		//		utils.FileLog.Info(fmt.Sprintf("获取标签ID失败,标签名称:%s,Err:%s", tagName, tmpErr.Error()))
+		//	}
+		//	if _, ok := tagIdMap[tagId]; !ok {
+		//		tagIdList = append(tagIdList, tagId)
+		//		tagIdMap[tagId] = true
+		//	}
+		//}
+
+		tagIdJsonByte, err := json.Marshal(tagIdList)
+		if err != nil {
+			utils.FileLog.Info(fmt.Sprintf("标签ID序列化失败,Err:%s", tmpErr.Error()))
+		} else {
+			tagIdJsonStr = string(tagIdJsonByte)
+		}
+	}
+
 	if abstractItem == nil || abstractItem.RagEtaReportAbstractId <= 0 {
 		abstractItem = &rag.RagEtaReportAbstract{
 			RagEtaReportAbstractId: 0,
@@ -385,7 +431,7 @@ func GenerateRagEtaReportAbstractByQuestion(item *rag.RagEtaReport, question *ra
 			QuestionId:             question.QuestionId,
 			QuestionContent:        question.QuestionContent,
 			Version:                1,
-			Tags:                   "",
+			Tags:                   tagIdJsonStr,
 			VectorKey:              "",
 			ModifyTime:             time.Now(),
 			CreateTime:             time.Now(),
@@ -410,7 +456,7 @@ func GenerateRagEtaReportAbstractByQuestion(item *rag.RagEtaReport, question *ra
 	ReportAbstractToKnowledge(item, abstractItem, false)
 }
 
-// AddOrEditEsWechatArticleAbstract
+// AddOrEditEsRagEtaReportAbstract
 // @Description: 新增/编辑微信文章摘要入ES
 // @author: Roc
 // @datetime 2025-03-13 14:13:47
@@ -442,14 +488,10 @@ func AddOrEditEsRagEtaReportAbstract(ragEtaReportAbstractId int) {
 
 	tagIdList := make([]int, 0)
 	if abstractInfo.Tags != `` {
-		tagIdStrList := strings.Split(abstractInfo.Tags, ",")
-		for _, tagIdStr := range tagIdStrList {
-			tagId, tmpErr := strconv.Atoi(tagIdStr)
-			if tmpErr != nil {
-				err = fmt.Errorf("报告标签ID转int失败,Err:" + tmpErr.Error())
-				return
-			}
-			tagIdList = append(tagIdList, tagId)
+		err = json.Unmarshal([]byte(abstractInfo.Tags), &tagIdList)
+		if err != nil {
+			err = fmt.Errorf("报告标签ID转int失败,Err:" + err.Error())
+			utils.FileLog.Info(fmt.Sprintf("json.Unmarshal 报告标签ID转int失败,标签数据:%s,Err:%s", abstractInfo.Tags, err.Error()))
 		}
 	}
 

+ 196 - 78
services/wechat_platform.go

@@ -16,6 +16,7 @@ import (
 	"html"
 	"os"
 	"path"
+	"regexp"
 	"strconv"
 	"strings"
 	"time"
@@ -368,8 +369,8 @@ func GenerateWechatArticleAbstractByQuestion(item *rag.WechatArticle, question *
 	var err error
 	defer func() {
 		if err != nil {
-			utils.FileLog.Error("文章转临时文件失败,err:%v", err)
-			fmt.Println("文章转临时文件失败,err:", err)
+			utils.FileLog.Error("摘要生成失败,err:%v", err)
+			fmt.Println("摘要生成失败,err:", err)
 		}
 	}()
 
@@ -394,64 +395,102 @@ func GenerateWechatArticleAbstractByQuestion(item *rag.WechatArticle, question *
 	//你现在是一名资深的期货行业分析师,请基于以下的问题进行汇总总结,如果不能正常总结出来,那么就只需要回复我:sorry
 	questionStr := fmt.Sprintf(`%s\n%s`, `你现在是一名资深的期货行业分析师,请基于以下的问题进行汇总总结,如果不能正常总结出来,那么就只需要回复我:sorry。以下是问题:`, question.QuestionContent)
 	//开始对话
-	abstract, addArticleChatRecordList, tmpErr := getAnswerByContent(item.WechatArticleId, utils.AI_ARTICLE_SOURCE_ETA_REPORT, questionStr)
+	abstract, industryTags, _, tmpErr := getAnswerByContent(item.WechatArticleId, utils.AI_ARTICLE_SOURCE_WECHAT, questionStr)
 	if tmpErr != nil {
 		err = fmt.Errorf("LLM对话失败,Err:" + tmpErr.Error())
 		return
 	}
 
-	// 添加问答记录
-	if len(addArticleChatRecordList) > 0 {
-		recordObj := rag.WechatArticleChatRecord{}
-		err = recordObj.CreateInBatches(addArticleChatRecordList)
-		if err != nil {
-			return
-		}
+	if abstract == `` {
+		return
 	}
 
-	if abstract != `` {
-		if abstract == `sorry` || strings.Index(abstract, `根据已知信息无法回答该问题`) == 0 {
-			item.AbstractStatus = 2
-			item.ModifyTime = time.Now()
-			err = item.Update([]string{"AbstractStatus", "ModifyTime"})
-			return
-		}
-		item.AbstractStatus = 1
-		item.ModifyTime = time.Now()
-		err = item.Update([]string{"AbstractStatus", "ModifyTime"})
+	var tagIdJsonStr string
+	// 标签ID
+	{
+		tagIdList := make([]int, 0)
+		tagIdMap := make(map[int]bool)
 
-		if abstractItem == nil || abstractItem.WechatArticleAbstractId <= 0 {
-			abstractItem = &rag.WechatArticleAbstract{
-				WechatArticleAbstractId: 0,
-				WechatArticleId:         item.WechatArticleId,
-				Content:                 abstract,
-				Version:                 1,
-				VectorKey:               "",
-				ModifyTime:              time.Now(),
-				CreateTime:              time.Now(),
-				QuestionId:              question.QuestionId,
-				Tags:                    "",
-				QuestionContent:         question.QuestionContent,
+		if abstractItem != nil && abstractItem.Tags != `` {
+			tmpErr = json.Unmarshal([]byte(abstractItem.Tags), &tagIdList)
+			if tmpErr != nil {
+				utils.FileLog.Info(fmt.Sprintf("json.Unmarshal 失败,标签数据:%s,Err:%s", abstractItem.Tags, tmpErr.Error()))
+			} else {
+				for _, tagId := range tagIdList {
+					tagIdMap[tagId] = true
+				}
 			}
-			err = abstractItem.Create()
-		} else {
-			abstractItem.Content = abstract
-			abstractItem.Version++
-			abstractItem.ModifyTime = time.Now()
-			abstractItem.Tags = ""
-			abstractItem.QuestionContent = question.QuestionContent
-			err = abstractItem.Update([]string{"content", "version", "modify_time", "tags", "question_content"})
 		}
-
+		for _, tagName := range industryTags {
+			tagId, tmpErr := GetTagIdByName(tagName)
+			if tmpErr != nil {
+				utils.FileLog.Info(fmt.Sprintf("获取标签ID失败,标签名称:%s,Err:%s", tagName, tmpErr.Error()))
+			}
+			if _, ok := tagIdMap[tagId]; !ok {
+				tagIdList = append(tagIdList, tagId)
+				tagIdMap[tagId] = true
+			}
+		}
+		//for _, tagName := range varietyTags {
+		//	tagId, tmpErr := GetTagIdByName(tagName)
+		//	if tmpErr != nil {
+		//		utils.FileLog.Info(fmt.Sprintf("获取标签ID失败,标签名称:%s,Err:%s", tagName, tmpErr.Error()))
+		//	}
+		//	if _, ok := tagIdMap[tagId]; !ok {
+		//		tagIdList = append(tagIdList, tagId)
+		//		tagIdMap[tagId] = true
+		//	}
+		//}
+
+		tagIdJsonByte, err := json.Marshal(tagIdList)
 		if err != nil {
-			return
+			utils.FileLog.Info(fmt.Sprintf("标签ID序列化失败,Err:%s", tmpErr.Error()))
+		} else {
+			tagIdJsonStr = string(tagIdJsonByte)
 		}
+	}
 
-		// 数据入ES库
-		go AddOrEditEsWechatArticleAbstract(abstractItem.WechatArticleAbstractId)
+	if abstract == `sorry` || strings.Index(abstract, `根据已知信息无法回答该问题`) == 0 {
+		item.AbstractStatus = 2
+		item.ModifyTime = time.Now()
+		err = item.Update([]string{"AbstractStatus", "ModifyTime"})
+		return
+	}
+	item.AbstractStatus = 1
+	item.ModifyTime = time.Now()
+	err = item.Update([]string{"AbstractStatus", "ModifyTime"})
+
+	if abstractItem == nil || abstractItem.WechatArticleAbstractId <= 0 {
+		abstractItem = &rag.WechatArticleAbstract{
+			WechatArticleAbstractId: 0,
+			WechatArticleId:         item.WechatArticleId,
+			Content:                 abstract,
+			Version:                 1,
+			VectorKey:               "",
+			ModifyTime:              time.Now(),
+			CreateTime:              time.Now(),
+			QuestionId:              question.QuestionId,
+			Tags:                    tagIdJsonStr,
+			QuestionContent:         question.QuestionContent,
+		}
+		err = abstractItem.Create()
+	} else {
+		abstractItem.Content = abstract
+		abstractItem.Version++
+		abstractItem.ModifyTime = time.Now()
+		abstractItem.Tags = ""
+		abstractItem.QuestionContent = question.QuestionContent
+		err = abstractItem.Update([]string{"content", "version", "modify_time", "tags", "question_content"})
+	}
 
-		WechatArticleAbstractToKnowledge(item, abstractItem, false)
+	if err != nil {
+		return
 	}
+
+	// 数据入ES库
+	go AddOrEditEsWechatArticleAbstract(abstractItem.WechatArticleAbstractId)
+
+	WechatArticleAbstractToKnowledge(item, abstractItem, false)
 }
 
 // DelDoc
@@ -527,8 +566,8 @@ func DelLlmDoc(vectorKeyList []string, wechatArticleAbstractIdList []int) (err e
 	return
 }
 
-func getAnswerByContent(articleId int, source int, questionStr string) (answer string, addArticleChatRecordList []*rag.WechatArticleChatRecord, err error) {
-	addArticleChatRecordList = make([]*rag.WechatArticleChatRecord, 0)
+func getAnswerByContent(articleId int, source int, questionStr string) (answer string, industryTags, varietyTags []string, err error) {
+	//addArticleChatRecordList = make([]*rag.WechatArticleChatRecord, 0)
 
 	result, err := facade.AIGCBaseOnPromote(facade.AIGC{
 		Promote:   questionStr,
@@ -541,11 +580,11 @@ func getAnswerByContent(articleId int, source int, questionStr string) (answer s
 	}
 
 	// JSON字符串转字节
-	answerByte, err := json.Marshal(result)
-	if err != nil {
-		return
-	}
-	originalAnswer := string(answerByte)
+	//answerByte, err := json.Marshal(result)
+	//if err != nil {
+	//	return
+	//}
+	//originalAnswer := string(answerByte)
 
 	// 提取 </think> 后面的内容
 	thinkEndIndex := strings.Index(result.Answer, "</think>")
@@ -557,24 +596,27 @@ func getAnswerByContent(articleId int, source int, questionStr string) (answer s
 
 	answer = strings.TrimSpace(answer)
 
-	// 待入库的数据
-	addArticleChatRecordList = append(addArticleChatRecordList, &rag.WechatArticleChatRecord{
-		WechatArticleChatRecordId: 0,
-		WechatArticleId:           articleId,
-		ChatUserType:              "user",
-		Content:                   questionStr,
-		SendTime:                  time.Now(),
-		CreatedTime:               time.Now(),
-		UpdateTime:                time.Now(),
-	}, &rag.WechatArticleChatRecord{
-		WechatArticleChatRecordId: 0,
-		WechatArticleId:           articleId,
-		ChatUserType:              "assistant",
-		Content:                   originalAnswer,
-		SendTime:                  time.Now(),
-		CreatedTime:               time.Now(),
-		UpdateTime:                time.Now(),
-	})
+	// 提取标签
+	industryTags, varietyTags = extractLabels(answer)
+
+	//// 待入库的数据
+	//addArticleChatRecordList = append(addArticleChatRecordList, &rag.WechatArticleChatRecord{
+	//	WechatArticleChatRecordId: 0,
+	//	WechatArticleId:           articleId,
+	//	ChatUserType:              "user",
+	//	Content:                   questionStr,
+	//	SendTime:                  time.Now(),
+	//	CreatedTime:               time.Now(),
+	//	UpdateTime:                time.Now(),
+	//}, &rag.WechatArticleChatRecord{
+	//	WechatArticleChatRecordId: 0,
+	//	WechatArticleId:           articleId,
+	//	ChatUserType:              "assistant",
+	//	Content:                   originalAnswer,
+	//	SendTime:                  time.Now(),
+	//	CreatedTime:               time.Now(),
+	//	UpdateTime:                time.Now(),
+	//})
 
 	return
 }
@@ -1038,14 +1080,10 @@ func AddOrEditEsWechatArticleAbstract(articleAbstractId int) {
 	// 标签ID
 	tagIdList := make([]int, 0)
 	if abstractInfo.Tags != `` {
-		tagIdStrList := strings.Split(abstractInfo.Tags, ",")
-		for _, tagIdStr := range tagIdStrList {
-			tagId, tmpErr := strconv.Atoi(tagIdStr)
-			if tmpErr != nil {
-				err = fmt.Errorf("报告标签ID转int失败,Err:" + tmpErr.Error())
-				return
-			}
-			tagIdList = append(tagIdList, tagId)
+		err = json.Unmarshal([]byte(abstractInfo.Tags), &tagIdList)
+		if err != nil {
+			err = fmt.Errorf("报告标签ID转int失败,Err:" + err.Error())
+			utils.FileLog.Info(fmt.Sprintf("json.Unmarshal 报告标签ID转int失败,标签数据:%s,Err:%s", abstractInfo.Tags, err.Error()))
 		}
 	}
 
@@ -1146,3 +1184,83 @@ func DelEsRagQuestion(questionId int) {
 
 	err = elastic.RagQuestionEsDel(strconv.Itoa(questionId))
 }
+
+// extractLabels
+// @Description: 提取摘要中的标签
+// @author: Roc
+// @datetime 2025-04-18 17:16:05
+// @param text string
+// @return industryTags []string
+// @return varietyTags []string
+func extractLabels(text string) (industryTags []string, varietyTags []string) {
+	reIndustry := regexp.MustCompile(`行业标签((?:【[^】]*】)+)`)
+	industryMatch := reIndustry.FindStringSubmatch(text)
+	if len(industryMatch) > 1 {
+		industryContent := industryMatch[1]
+		reSplit := regexp.MustCompile(`【([^】]*)】`)
+		industryTags = make([]string, 0)
+		for _, m := range reSplit.FindAllStringSubmatch(industryContent, -1) {
+			if len(m) > 1 {
+				industryTags = append(industryTags, m[1])
+			}
+		}
+	}
+
+	reVariety := regexp.MustCompile(`品种标签((?:【[^】]*】)+)`)
+	varietyMatch := reVariety.FindStringSubmatch(text)
+	if len(varietyMatch) > 1 {
+		varietyContent := varietyMatch[1]
+		reSplit := regexp.MustCompile(`【([^】]*)】`)
+		varietyTags = make([]string, 0)
+		for _, m := range reSplit.FindAllStringSubmatch(varietyContent, -1) {
+			if len(m) > 1 {
+				varietyTags = append(varietyTags, m[1])
+			}
+		}
+	}
+	return
+}
+
+var aiAbstractTagMap = map[string]int{}
+
+// GetTagIdByName
+// @Description: 获取标签ID
+// @author: Roc
+// @datetime 2025-04-18 17:25:46
+// @param tagName string
+// @return tagId int
+// @return err error
+func GetTagIdByName(tagName string) (tagId int, err error) {
+	tagName = strings.TrimSpace(tagName)
+	tagId, ok := aiAbstractTagMap[tagName]
+	if ok {
+		return
+	}
+
+	obj := rag.Tag{}
+	item, err := obj.GetByCondition(fmt.Sprintf(` AND  %s = ? `, rag.TagColumns.TagName), []interface{}{tagName})
+	if err != nil {
+		if !utils.IsErrNoRow(err) {
+			err = fmt.Errorf("获取标签失败,Err:" + err.Error())
+			return
+		}
+
+		item = &rag.Tag{
+			TagId:      0,
+			TagName:    tagName,
+			Sort:       0,
+			ModifyTime: time.Now(),
+			CreateTime: time.Now(),
+		}
+		err = item.Create()
+		if err != nil {
+			err = fmt.Errorf("添加标签失败,Err:" + err.Error())
+			return
+		}
+	}
+
+	tagId = item.TagId
+	aiAbstractTagMap[tagName] = tagId
+
+	return
+}