浏览代码

提示词内容生成

kobe6258 2 周之前
父节点
当前提交
e9537063f5

+ 6 - 1
controllers/llm/llm_http/request.go

@@ -17,4 +17,9 @@ type UserChatRecordReq struct {
 	Content      string `json:"Content" description:"会话名称"`
 	ChatUserType string `json:"ChatUserType" description:"用户类型"`
 	SendTime     string `json:"SendTime" description:"发送时间"`
-}
+}
+
+type GenerateContentReq struct {
+	WechatArticleId int    `json:"WechatArticleId" description:"公众号Id"`
+	Promote         string `json:"Promote" description:"提示词"`
+}

+ 106 - 0
controllers/llm/promote_controller.go

@@ -0,0 +1,106 @@
+package llm
+
+import (
+	"encoding/json"
+	"eta/eta_api/controllers"
+	"eta/eta_api/controllers/llm/llm_http"
+	"eta/eta_api/models"
+	"eta/eta_api/services/llm/facade"
+	"eta/eta_api/utils"
+)
+
+type PromoteController struct {
+	controllers.BaseAuthController
+}
+
+// PromoteTrainRecordList @Title 获取聊天记录
+// @Description 获取聊天记录
+// @Success 101 {object} response.ListResp
+// @router /promote/train_list [get]
+func (pCtrl *PromoteController) PromoteTrainRecordList() {
+	br := new(models.BaseResponse).Init()
+	defer func() {
+		pCtrl.Data["json"] = br
+		pCtrl.ServeJSON()
+	}()
+	sysUser := pCtrl.SysUser
+	if sysUser == nil {
+		br.Msg = "请登录"
+		br.ErrMsg = "请登录,SysUser Is Empty"
+		br.Ret = 408
+		return
+	}
+
+	pageSize, _ := pCtrl.GetInt("PageSize")
+	currentIndex, _ := pCtrl.GetInt("CurrentIndex")
+
+	//var total, startSize int
+	if pageSize <= 0 {
+		pageSize = utils.PageSize5
+	}
+	if currentIndex <= 0 {
+		currentIndex = 1
+	}
+	//startSize = paging.StartIndex(currentIndex, pageSize)
+	//page := paging.GetPaging(currentIndex, pageSize, total)
+	//total, err := rag.CountQuestionList()
+	//if err != nil {
+	//	br.Msg = "获取失败"
+	//	br.ErrMsg = "获取失败,Err:" + err.Error()
+	//	return
+	//}
+	//list, err := models.GetPptV2List(condition, pars, startSize, pageSize)
+	//if err != nil {
+	//	br.Msg = "获取失败"
+	//	br.ErrMsg = "获取失败,Err:" + err.Error()
+	//	return
+	//}
+	br.Data = nil
+	br.Ret = 200
+	br.Success = true
+	br.Msg = "获取聊天记录成功"
+}
+
+// GenerateContent @Title 生成问答内容
+// @Description 生成问答内容
+// @Success 101 {object} response.ListResp
+// @router /promote/generate_content [get]
+func (pCtrl *PromoteController) GenerateContent() {
+	br := new(models.BaseResponse).Init()
+	defer func() {
+		pCtrl.Data["json"] = br
+		pCtrl.ServeJSON()
+	}()
+	var gcReq llm_http.GenerateContentReq
+	err := json.Unmarshal(pCtrl.Ctx.Input.RequestBody, &gcReq)
+	if err != nil {
+		br.Msg = "参数解析异常!"
+		br.ErrMsg = "参数解析失败,Err:" + err.Error()
+		return
+	}
+	sysUser := pCtrl.SysUser
+	if sysUser == nil {
+		br.Msg = "请登录"
+		br.ErrMsg = "请登录,SysUser Is Empty"
+		br.Ret = 408
+		return
+	}
+	if gcReq.Promote == "" {
+		br.Msg = "提示词不能为空"
+		br.ErrMsg = "提示词不能为空"
+		return
+	}
+	if gcReq.WechatArticleId <= 0 {
+		br.Msg = "公众号文章编号非法"
+		br.ErrMsg = "公众号文章编号非法"
+		return
+	}
+	facade.AIGCBaseOnPromote(facade.AIGC{
+		Promote:     gcReq.Promote,
+		ArticleId: gcReq.WechatArticleId,
+	})
+	br.Data = nil
+	br.Ret = 200
+	br.Success = true
+	br.Msg = "获取聊天记录成功"
+}

+ 2 - 0
controllers/llm/user_chat_controller.go

@@ -330,3 +330,5 @@ func (ucCtrl *UserChatController) ChatRecordList() {
 	br.Success = true
 	br.Msg = "获取聊天记录成功"
 }
+
+

+ 37 - 0
models/rag/article_kb_mapping.go

@@ -0,0 +1,37 @@
+package rag
+
+import (
+	"eta/eta_api/global"
+	"eta/eta_api/utils"
+	"gorm.io/gorm/clause"
+	"time"
+)
+
+type ArticleKbMapping struct {
+	Id              int `gorm:"id;primaryKey"`
+	WechatArticleId int
+	KbId            string
+	CreatedTime     time.Time
+	UpdateTime      time.Time
+}
+
+func (a *ArticleKbMapping) TableName() string {
+	return "article_kb_mapping"
+}
+
+func GetArticleKbMapping(articleId int) (articleKbMapping *ArticleKbMapping, err error) {
+	err = global.DbMap[utils.DbNameAI].Where("wechat_article_id = ?", articleId).First(&articleKbMapping).Error
+	return
+}
+
+func CreateArticleKbMapping(articleKbMapping ArticleKbMapping) (err error) {
+	db := global.DbMap[utils.DbNameAI]
+	db.Clauses(
+		clause.OnConflict{
+			DoNothing: true,
+			Columns:   []clause.Column{{Name: "wechat_article_id"}},
+		},
+	)
+	err = global.DbMap[utils.DbNameAI].Create(&articleKbMapping).Error
+	return
+}

+ 4 - 1
models/rag/wechat_article.go

@@ -155,7 +155,10 @@ func (m *WechatArticle) GetById(id int) (item *WechatArticle, err error) {
 
 	return
 }
-
+func GetArticleById(id int) (item *WechatArticle, err error) {
+	err = global.DbMap[utils.DbNameAI].Where(fmt.Sprintf("%s = ?", WechatArticleColumns.WechatArticleID), id).First(&item).Error
+	return
+}
 func (m *WechatArticle) GetByLink(link string) (item *WechatArticle, err error) {
 	err = global.DbMap[utils.DbNameAI].Where(fmt.Sprintf("%s = ?", WechatArticleColumns.Link), link).First(&item).Error
 

+ 18 - 0
routers/commentsRouter.go

@@ -8404,6 +8404,24 @@ func init() {
             Filters: nil,
             Params: nil})
 
+    beego.GlobalControllerRouter["eta/eta_api/controllers/llm:PromoteController"] = append(beego.GlobalControllerRouter["eta/eta_api/controllers/llm:PromoteController"],
+        beego.ControllerComments{
+            Method: "GenerateContent",
+            Router: `/promote/generate_content`,
+            AllowHTTPMethods: []string{"get"},
+            MethodParams: param.Make(),
+            Filters: nil,
+            Params: nil})
+
+    beego.GlobalControllerRouter["eta/eta_api/controllers/llm:PromoteController"] = append(beego.GlobalControllerRouter["eta/eta_api/controllers/llm:PromoteController"],
+        beego.ControllerComments{
+            Method: "PromoteTrainRecordList",
+            Router: `/promote/train_list`,
+            AllowHTTPMethods: []string{"get"},
+            MethodParams: param.Make(),
+            Filters: nil,
+            Params: nil})
+
     beego.GlobalControllerRouter["eta/eta_api/controllers/llm:QuestionController"] = append(beego.GlobalControllerRouter["eta/eta_api/controllers/llm:QuestionController"],
         beego.ControllerComments{
             Method: "Add",

+ 1 - 0
routers/router.go

@@ -75,6 +75,7 @@ func init() {
 				&llm.WechatPlatformController{},
 				&llm.QuestionController{},
 				&llm.AbstractController{},
+				&llm.PromoteController{},
 			),
 		),
 		web.NSNamespace("/banner",

+ 4 - 0
services/llm/facade/bus_response/eta_response.go

@@ -6,3 +6,7 @@ type SearchDocsEtaResponse struct {
 	Content string
 	Docs    []eta_llm_http.SearchDocsResponse
 }
+type AIGCEtaResponse struct {
+	Answer string   `json:"answer"`
+	Docs   []string `json:"docs"`
+}

+ 132 - 0
services/llm/facade/llm_service.go

@@ -1,13 +1,20 @@
 package facade
 
 import (
+	"encoding/json"
+	"errors"
+	"eta/eta_api/models/rag"
+	localService "eta/eta_api/services/llm"
 	"eta/eta_api/services/llm/facade/bus_response"
+	"eta/eta_api/utils"
 	"eta/eta_api/utils/llm"
 	"eta/eta_api/utils/llm/eta_llm/eta_llm_http"
 	"eta/eta_api/utils/ws"
 	"fmt"
 	"github.com/gorilla/websocket"
 	"github.com/rdlucklib/rdluck_tools/uuid"
+	"gorm.io/gorm"
+	"os"
 )
 
 var (
@@ -38,7 +45,132 @@ func LLMKnowledgeBaseSearchDocs(search LLMKnowledgeSearch) (resp bus_response.Se
 	return
 }
 
+// AIGCBaseOnPromote aigc 生成内容
+func AIGCBaseOnPromote(aigc AIGC) (resp bus_response.AIGCEtaResponse, err error) {
+	mapping, queryErr := rag.GetArticleKbMapping(aigc.ArticleId)
+	if queryErr != nil && !errors.Is(queryErr, gorm.ErrRecordNotFound) {
+		utils.FileLog.Error("获取文章知识库信息失败,err: %v", queryErr)
+		err = fmt.Errorf("获取文章知识库信息失败,err: %v", queryErr)
+		return
+	} else {
+		var kbId string
+		var file *os.File
+		var params map[string]interface{}
+		if mapping.Id == 0 || mapping.KbId == "" {
+			article, fileErr := rag.GetArticleById(aigc.ArticleId)
+			if fileErr != nil {
+				// 找不到就处理失败
+				utils.FileLog.Error("公众号文章不存在")
+				err = fmt.Errorf("公众号文章不存在")
+				return
+			}
+			// 文章加入到知识库
+			path, fileErr := localService.CreateArticleFile(article)
+			if fileErr != nil {
+				utils.FileLog.Error("创建文章文件失败,err: %v", fileErr)
+				err = fmt.Errorf("创建文章文件失败,err: %v", fileErr)
+				return
+			}
+			defer func() {
+				_ = os.Remove(path)
+			}()
+			file, err = os.Open(path)
+			if err != nil {
+				utils.FileLog.Error("打开文件失败,err:", err)
+				return
+			}
+			uploadResp, httpErr := llmService.UploadFileToTemplate([]*os.File{file}, params)
+			if httpErr != nil {
+				utils.FileLog.Error("上传文件失败,err:", err.Error())
+				err = fmt.Errorf("上传文件失败,err:%v", httpErr)
+				return
+			}
+			data := uploadResp.(eta_llm_http.UploadDocsResponse)
+			//保存映射关系到数据库
+			if data.Id == "" {
+				utils.FileLog.Error("上传文件失败,向量库Id获取失败")
+				err = fmt.Errorf("上传文件失败,向量库Id获取失败")
+				return
+			}
+			err = rag.CreateArticleKbMapping(rag.ArticleKbMapping{
+				WechatArticleId: aigc.ArticleId,
+				KbId:            data.Id,
+			})
+			if err != nil {
+				utils.FileLog.Warn("创建文章知识库映射关系失败,err:", err.Error())
+			}
+			kbId = data.Id
+		} else {
+			kbId = mapping.KbId
+		}
+		//知识库对话
+		response, httpErr := llmService.FileChat(aigc.Promote, kbId, nil)
+		if httpErr != nil {
+			utils.FileLog.Error("内容生成失败,err:", err.Error())
+			err = fmt.Errorf("内容生成失败,err:%v", httpErr)
+			return
+		}
+		if !response.Success {
+			utils.FileLog.Error("内容生成失败,code:%v,msg:%v", response.Ret, response.Msg)
+			err = fmt.Errorf("内容生成失败,code:%v,msg:%v", response.Ret, response.Msg)
+			return
+		} else {
+			var baseResp eta_llm_http.RagBaseResponse
+			parseErr := json.Unmarshal(response.Data, &baseResp)
+			if parseErr != nil {
+				utils.FileLog.Error("内容生成失败,code:%v,msg:%v", parseErr)
+				err = fmt.Errorf("内容生成失败,err:%v", parseErr)
+				return
+			}
+			if baseResp.Code != 200 {
+				if baseResp.Code == 404 {
+					params["PrevId"] = kbId
+					_, putErr := llmService.UploadFileToTemplate([]*os.File{file}, params)
+					if putErr != nil {
+						utils.FileLog.Error("内容生成失败,err:", err.Error())
+						err = fmt.Errorf("内容生成失败,err:%v", httpErr)
+						return
+					}
+				} else {
+					utils.FileLog.Error("内容生成失败,code:%v,msg:%v", baseResp.Code, baseResp.Msg)
+					err = fmt.Errorf("内容生成失败,code:%v,msg:%v", baseResp.Code, baseResp.Msg)
+					return
+				}
+			}
+			gcResp, gcErr := llmService.FileChat(aigc.Promote, kbId, nil)
+			if gcErr != nil {
+				utils.FileLog.Error("内容生成失败,err:%v", gcErr.Error())
+				err = fmt.Errorf("内容生成失败,err:%v", gcErr)
+				return
+			}
+			if !gcResp.Success {
+				utils.FileLog.Error("内容生成失败,code:%v,msg:%v", gcResp.Ret, gcResp.Msg)
+				err = fmt.Errorf("内容生成失败,err:%v", gcResp.Msg)
+			}
+			var steamResp eta_llm_http.ContentResponse
+			parseErr = json.Unmarshal(gcResp.Data, &steamResp)
+			if parseErr != nil {
+				utils.FileLog.Error("内容生成失败,code:%v,msg:%v", parseErr)
+				err = fmt.Errorf("内容生成失败,err:%v", parseErr)
+				return
+			}
+			parseErr = json.Unmarshal(steamResp.Data, &resp)
+			if parseErr != nil {
+				utils.FileLog.Error("内容生成失败,code:%v,msg:%v", parseErr)
+				err = fmt.Errorf("内容生成失败,err:%v", parseErr)
+				return
+			}
+		}
+	}
+	return
+}
+
 type LLMKnowledgeSearch struct {
 	Query             string `json:"Query"`
 	KnowledgeBaseName string `json:"KnowledgeBaseName"`
 }
+
+type AIGC struct {
+	Promote   string
+	ArticleId int
+}

+ 40 - 0
services/llm/promote_service.go

@@ -0,0 +1,40 @@
+package llm
+
+import (
+	"eta/eta_api/models/rag"
+	"eta/eta_api/utils"
+	"fmt"
+	"os"
+)
+
+func CreateArticleFile(item *rag.WechatArticle) (filePath string, err error) {
+	if item.TextContent == `` {
+		err = fmt.Errorf("生成文章原文文本失败,文章内容为空")
+		return
+	}
+	defer func() {
+		if err != nil {
+			utils.FileLog.Error("上传文章原文到知识库失败,err:%v", err)
+			fmt.Println("上传文章原文到知识库失败,err:", err)
+		}
+	}()
+
+	// 生成临时文件
+	uploadDir := utils.STATIC_DIR + "ai/article"
+	err = os.MkdirAll(uploadDir, utils.DIR_MOD)
+	if err != nil {
+		err = fmt.Errorf("存储目录创建失败,Err:" + err.Error())
+		return
+	}
+	fileName := utils.RemoveSpecialChars(item.Title) + `.md`
+	tmpFilePath := uploadDir + "/" + fileName
+	err = utils.SaveToFile(item.TextContent, tmpFilePath)
+	if err != nil {
+		err = fmt.Errorf("生成临时文件失败,Err:" + err.Error())
+		return
+	}
+	defer func() {
+		_ = os.Remove(tmpFilePath)
+	}()
+	return
+}

+ 76 - 0
utils/llm/eta_llm/eta_llm_client.go

@@ -12,6 +12,7 @@ import (
 	"fmt"
 	"io"
 	"net/http"
+	"os"
 	"strings"
 	"sync"
 )
@@ -29,6 +30,7 @@ const (
 	KNOWLEDGE_BASE_CHAT_API        = "/chat/kb_chat"
 	DOCUMENT_CHAT_API              = "/chat/file_chat"
 	KNOWLEDGE_BASE_SEARCH_DOCS_API = "/knowledge_base/search_docs"
+	UPLOAD_TEMP_DOCS_API           = "/knowledge_base/upload_temp_docs"
 )
 
 type ETALLMClient struct {
@@ -137,6 +139,80 @@ func (ds *ETALLMClient) KnowledgeBaseChat(query string, KnowledgeBaseName string
 	return ds.DoStreamPost(KNOWLEDGE_BASE_CHAT_API, body)
 }
 
+func (ds *ETALLMClient) FileChat(query string, KnowledgeId string, history []json.RawMessage) (resp eta_llm_http.BaseResponse, err error) {
+	ChatHistory := make([]eta_llm_http.HistoryContent, 0)
+	for _, historyItemStr := range history {
+		var historyItem eta_llm_http.HistoryContentWeb
+		parseErr := json.Unmarshal(historyItemStr, &historyItem)
+		if parseErr != nil {
+			continue
+		}
+		ChatHistory = append(ChatHistory, eta_llm_http.HistoryContent{
+			Content: historyItem.Content,
+			Role:    historyItem.Role,
+		})
+	}
+	kbReq := eta_llm_http.DocumentChatRequest{
+		ModelName:      ds.LlmModel,
+		Query:          query,
+		KnowledgeId:    KnowledgeId,
+		History:        ChatHistory,
+		TopK:           3,
+		ScoreThreshold: 0.5,
+		Stream:         false,
+		Temperature:    0.7,
+		MaxTokens:      0,
+		PromptName:     DEFALUT_PROMPT_NAME,
+	}
+	body, err := json.Marshal(kbReq)
+	if err != nil {
+		utils.FileLog.Error("内容生成失败,序列化请求参数失败,err", err.Error())
+		err = fmt.Errorf("内容生成失败,序列化请求参数失败,err:%v", err)
+		return
+	}
+	return ds.DoPost(DOCUMENT_CHAT_API, body)
+}
+
+func (ds *ETALLMClient) UploadFileToTemplate(files []*os.File, param map[string]interface{}) (data interface{}, err error) {
+	pervId := param["PrevId"].(string)
+	docReq := eta_llm_http.UploadTempDocsRequest{
+		ChunkOverlap:   750,
+		ChunkSize:      150,
+		Files:          files,
+		PrevId:         pervId,
+		ZhTitleEnhance: false,
+	}
+	body, err := json.Marshal(docReq)
+	if err != nil {
+		return
+	}
+	resp, err := ds.DoPost(UPLOAD_TEMP_DOCS_API, body)
+	if !resp.Success {
+		err = errors.New(resp.Msg)
+		return
+	}
+	if resp.Data != nil {
+		var uploadDocsRes eta_llm_http.RagBaseResponse
+		err = json.Unmarshal(resp.Data, &uploadDocsRes)
+		if err != nil {
+			err = errors.New("上传临时文件失败,err:" + err.Error())
+			return
+		}
+		if uploadDocsRes.Code != 200 {
+			err = errors.New("上传临时文件失败,err:" + uploadDocsRes.Msg)
+			return
+		}
+		var uploadResult eta_llm_http.UploadDocsResponse
+		err = json.Unmarshal(uploadDocsRes.Data, &uploadResult)
+		if len(uploadResult.FiledFiles) > 0 {
+			utils.FileLog.Warn("上传临时文件失败:", uploadResult.FiledFiles)
+		}
+		data = uploadResult
+		return
+	}
+	return
+}
+
 func (ds *ETALLMClient) SearchKbDocs(query string, KnowledgeBaseName string) (content interface{}, err error) {
 	kbReq := eta_llm_http.KbSearchDocsRequest{
 		Query:             query,

+ 10 - 0
utils/llm/eta_llm/eta_llm_http/request.go

@@ -1,5 +1,7 @@
 package eta_llm_http
 
+import "os"
+
 type KbChatRequest struct {
 	Query          string           `json:"query"`
 	Mode           string           `json:"mode"`
@@ -42,3 +44,11 @@ type KbSearchDocsRequest struct {
 	FileName          string      `json:"file_name"`
 	Metadata          interface{} `json:"metadata"`
 }
+
+type UploadTempDocsRequest struct {
+	Files          []*os.File `json:"files"`
+	PrevId         string     `json:"prev_id"`
+	ChunkSize      int        `json:"chunk_size"`
+	ChunkOverlap   int        `json:"chunk_overlap"`
+	ZhTitleEnhance bool       `json:"zh_title_enhance"`
+}

+ 15 - 1
utils/llm/eta_llm/eta_llm_http/response.go

@@ -9,8 +9,13 @@ type BaseResponse struct {
 	Data    json.RawMessage `json:"data"`
 }
 type SteamResponse struct {
-	Data    ChunkResponse `json:"data"`
+	Data ChunkResponse `json:"data"`
 }
+
+type ContentResponse struct {
+	Data json.RawMessage `json:"data"`
+}
+
 // ChunkResponse 定义流式响应的结构体
 type ChunkResponse struct {
 	ID          string   `json:"id"`
@@ -24,6 +29,11 @@ type ChunkResponse struct {
 	Docs        []string `json:"docs"`
 	Choices     []Choice `json:"choices"`
 }
+type RagBaseResponse struct {
+	Data json.RawMessage `json:"data"`
+	Msg  string          `json:"msg"`
+	Code int             `json:"code"`
+}
 
 // Choice 定义选择的结构体
 type Choice struct {
@@ -61,3 +71,7 @@ type Metadata struct {
 	Source string `json:"source"`
 	Id     string `json:"id"`
 }
+type UploadDocsResponse struct {
+	Id         string   `json:"id"`
+	FiledFiles []string `json:"filed_files"`
+}

+ 4 - 0
utils/llm/llm_client.go

@@ -2,7 +2,9 @@ package llm
 
 import (
 	"encoding/json"
+	"eta/eta_api/utils/llm/eta_llm/eta_llm_http"
 	"net/http"
+	"os"
 	"time"
 )
 
@@ -24,4 +26,6 @@ type LLMService interface {
 	KnowledgeBaseChat(query string, KnowledgeBaseName string, history []json.RawMessage) (llmRes *http.Response, err error)
 	DocumentChat(query string, KnowledgeId string, history []json.RawMessage, stream bool) (llmRes *http.Response, err error)
 	SearchKbDocs(query string, KnowledgeBaseName string) (data interface{}, err error)
+	UploadFileToTemplate(files []*os.File, param map[string]interface{}) (data interface{}, err error)
+	FileChat(query string, KnowledgeId string, history []json.RawMessage) (resp eta_llm_http.BaseResponse, err error)
 }