kobe6258 2 månader sedan
förälder
incheckning
021a4c6b4f

+ 51 - 0
controllers/rag/kb_controller.go

@@ -0,0 +1,51 @@
+package rag
+
+import (
+	"encoding/json"
+	"eta/eta_api/controllers"
+	"eta/eta_api/models"
+	"eta/eta_api/services/llm/facade"
+)
+
+type KbController struct {
+	controllers.BaseAuthController
+}
+
+// SearchDocs  @Title 搜索知识库文档
+// @Description 搜索知识库文档
+// @Success 101 {object} response.ListResp
+// @router /knowledge_base/searchDocs [post]
+func (kbctrl *KbController) SearchDocs() {
+	br := new(models.BaseResponse).Init()
+	defer func() {
+		if br.ErrMsg == "" {
+			br.IsSendEmail = false
+		}
+		kbctrl.Data["json"] = br
+		kbctrl.ServeJSON()
+	}()
+	sysUser := kbctrl.SysUser
+	if sysUser == nil {
+		br.Msg = "请登录"
+		br.ErrMsg = "请登录,SysUser Is Empty"
+		br.Ret = 408
+		return
+	}
+	var req facade.LLMKnowledgeSearch
+	err := json.Unmarshal(kbctrl.Ctx.Input.RequestBody, &req)
+	if err != nil {
+		br.Msg = "参数解析异常!"
+		br.ErrMsg = "参数解析失败,Err:" + err.Error()
+		return
+	}
+	docs, err := facade.LLMKnowledgeBaseSearchDocs(req)
+	if err != nil {
+		br.Msg = "搜索知识库失败"
+		br.ErrMsg = "搜索知识库失败:" + err.Error()
+		return
+	}
+	br.Data = docs
+	br.Ret = 200
+	br.Success = true
+	br.Msg = "获取成功"
+}

+ 1 - 1
main.go

@@ -12,7 +12,7 @@ import (
 	_ "eta/eta_api/routers"
 	"eta/eta_api/services"
 	"eta/eta_api/utils"
-	_ "eta/eta_api/utils/llm/deepseek"
+	_ "eta/eta_api/utils/llm/eta_llm"
 	"github.com/beego/beego/v2/adapter/logs"
 	"github.com/beego/beego/v2/server/web"
 	"github.com/beego/beego/v2/server/web/context"

+ 9 - 0
routers/commentsRouter.go

@@ -8530,6 +8530,15 @@ func init() {
             Filters: nil,
             Params: nil})
 
+    beego.GlobalControllerRouter["eta/eta_api/controllers/rag:KbController"] = append(beego.GlobalControllerRouter["eta/eta_api/controllers/rag:KbController"],
+        beego.ControllerComments{
+            Method: "SearchDocs",
+            Router: `/knowledge_base/searchDocs`,
+            AllowHTTPMethods: []string{"post"},
+            MethodParams: param.Make(),
+            Filters: nil,
+            Params: nil})
+
     beego.GlobalControllerRouter["eta/eta_api/controllers/report_approve:ReportApproveController"] = append(beego.GlobalControllerRouter["eta/eta_api/controllers/report_approve:ReportApproveController"],
         beego.ControllerComments{
             Method: "Approve",

+ 4 - 1
routers/router.go

@@ -54,7 +54,7 @@ func init() {
 		ExposeHeaders:    []string{"Content-Length", "Access-Control-Allow-Origin", "Access-Control-Allow-Headers", "Content-Type"},
 		AllowCredentials: true,
 	}))
-	web.InsertFilter("/adminapi/llm/chat/connect", web.BeforeRouter,services.WsAuthenticate())
+	web.InsertFilter("/adminapi/llm/chat/connect", web.BeforeRouter, services.WsAuthenticate())
 	ns := web.NewNamespace("/adminapi",
 		web.NSNamespace("/sysuser",
 			web.NSInclude(
@@ -73,6 +73,9 @@ func init() {
 			web.NSInclude(
 				&rag.ChatController{},
 			),
+			web.NSInclude(
+				&rag.KbController{},
+			),
 		),
 		web.NSNamespace("/banner",
 			web.NSInclude(

+ 10 - 1
services/llm/facade/llm_service.go

@@ -9,7 +9,7 @@ import (
 )
 
 var (
-	deepseekService, _ = llm.GetInstance(llm.LLM_DEEPSEEK)
+	llmService, _ = llm.GetInstance(llm.ETA_LLM_CLIENT)
 )
 
 func generateSessionCode() (code string) {
@@ -21,3 +21,12 @@ func AddSession(userId int, conn *websocket.Conn) {
 	session := ws.NewSession(userId, sessionId, conn)
 	ws.Manager().AddSession(session)
 }
+
+func LLMKnowledgeBaseSearchDocs(search LLMKnowledgeSearch) (resp string, err error) {
+	return llmService.SearchKbDocs(search.Query, search.KnowledgeBaseName)
+}
+
+type LLMKnowledgeSearch struct {
+	Query             string `json:"Query"`
+	KnowledgeBaseName string `json:"KnowledgeBaseName"`
+}

+ 4 - 2
utils/config.go

@@ -13,7 +13,8 @@ import (
 
 // 大模型配置
 var (
-	DS_LLM_SERVER string //模型服务地址
+	LLM_SERVER string //模型服务地址
+	LLM_MODEL  string
 )
 var (
 	RunMode          string //运行模式
@@ -634,7 +635,8 @@ func init() {
 	ChromePath = config["chrome_path"]
 
 	//模型服务器地址
-	DS_LLM_SERVER = config["llm_server"]
+	LLM_SERVER = config["llm_server"]
+	LLM_MODEL = config["llm_model"]
 	// 初始化ES
 	initEs()
 

+ 0 - 38
utils/llm/deepseek/deekseek.go

@@ -1,38 +0,0 @@
-package deepseek
-
-import (
-	"eta/eta_api/utils"
-	"eta/eta_api/utils/llm"
-	"sync"
-)
-
-var (
-	dsOnce sync.Once
-
-	deepseekClient *DeepSeekClient
-)
-
-type DeepSeekClient struct {
-	*llm.LLMClient
-}
-
-func Getinstance() llm.LLMService {
-	dsOnce.Do(func() {
-		if deepseekClient == nil {
-			deepseekClient = &DeepSeekClient{
-				LLMClient: llm.NewLLMClient(utils.DS_LLM_SERVER, 10),
-			}
-		}
-	})
-	return deepseekClient
-}
-
-func (ds *DeepSeekClient) AskQuestion() string {
-	return ""
-}
-func init() {
-	err := llm.Register(llm.LLM_DEEPSEEK, Getinstance())
-	if err != nil {
-		utils.FileLog.Error("注册deepseek服务失败:", err)
-	}
-}

+ 102 - 0
utils/llm/eta_llm/eta_llm_client.go

@@ -0,0 +1,102 @@
+package eta_llm
+
+import (
+	"bytes"
+	"encoding/json"
+	"errors"
+	"eta/eta_api/utils"
+	"eta/eta_api/utils/llm"
+	"eta/eta_api/utils/llm/eta_llm/eta_llm_http"
+	"fmt"
+	"io"
+	"net/http"
+	"sync"
+)
+
+var (
+	dsOnce sync.Once
+
+	etaLlmClient *ETALLMClient
+)
+
+const (
+	CONTENT_TYPE_JSON              = "application/json"
+	KNOWLEDGE_BASE_CHAT_API        = "/chat/kb_chat"
+	KNOWLEDGE_BASE_SEARCH_DOCS_API = "/knowledge_base/search_docs"
+)
+
+type ETALLMClient struct {
+	*llm.LLMClient
+	LlmModel string
+}
+
+func GetInstance() llm.LLMService {
+	dsOnce.Do(func() {
+		if etaLlmClient == nil {
+			etaLlmClient = &ETALLMClient{
+				LLMClient: llm.NewLLMClient(utils.LLM_SERVER, 10),
+				LlmModel:  utils.LLM_MODEL,
+			}
+		}
+	})
+	return etaLlmClient
+}
+
+func (ds *ETALLMClient) KnowledgeBaseChat() string {
+	ds.HttpClient.Post(ds.BaseURL+KNOWLEDGE_BASE_CHAT_API, CONTENT_TYPE_JSON, nil)
+	return ""
+}
+
+func (ds *ETALLMClient) SearchKbDocs(query string, KnowledgeBaseName string) (content string, err error) {
+	// 类型断言
+	kbReq := eta_llm_http.KbSearchDocsRequest{
+		Query:             query,
+		KnowledgeBaseName: KnowledgeBaseName,
+		Model:             ds.LlmModel,
+		TopK:              3,
+		ScoreThreshold:    2,
+	}
+	body, err := json.Marshal(kbReq)
+	if err != nil {
+		return
+	}
+	resp, err := ds.DoPost(KNOWLEDGE_BASE_SEARCH_DOCS_API, body)
+	if !resp.Success {
+		err = errors.New(resp.Msg)
+		return
+	}
+	return "", nil
+}
+func init() {
+	err := llm.Register(llm.ETA_LLM_CLIENT, GetInstance())
+	if err != nil {
+		utils.FileLog.Error("注册eta_llm_server服务失败:", err)
+	}
+}
+
+func (ds *ETALLMClient) DoPost(apiUrl string, body []byte) (baseResp eta_llm_http.BaseResponse, err error) {
+	requestReader := bytes.NewReader(body)
+	response, err := ds.HttpClient.Post(ds.BaseURL+apiUrl, CONTENT_TYPE_JSON, requestReader)
+	if err != nil {
+		return
+	}
+	return parseResponse(response)
+}
+
+func parseResponse(response *http.Response) (baseResp eta_llm_http.BaseResponse, err error) {
+	defer func() {
+		_ = response.Body.Close()
+	}()
+	baseResp.Ret = response.StatusCode
+	if response.StatusCode != http.StatusOK {
+		baseResp.Msg = fmt.Sprintf("请求失败,状态码:%d, 状态信息:%s", response.StatusCode, http.StatusText(response.StatusCode))
+		return
+	}
+	bodyBytes, err := io.ReadAll(response.Body)
+	if err != nil {
+		err = fmt.Errorf("读取响应体失败: %w", err)
+		return
+	}
+	baseResp.Data = bodyBytes
+	return
+}

+ 33 - 0
utils/llm/eta_llm/eta_llm_http/request.go

@@ -0,0 +1,33 @@
+package eta_llm_http
+
+import "encoding/json"
+
+type KbChatRequest struct {
+	Query          string           `json:"query"`
+	Mode           string           `json:"mode"`
+	KbName         string           `json:"kb_name"`
+	TopK           int              `json:"top_k"`
+	ScoreThreshold int              `json:"score_threshold"`
+	History        []HistoryContent `json:"history"`
+	Stream         bool             `json:"stream"`
+	Model          string           `json:"model"`
+	Temperature    float32          `json:"temperature"`
+	MaxTokens      int              `json:"max_tokens"`
+	PromptName     string           `json:"prompt_name"`
+	ReturnDirect   bool             `json:"return_direct"`
+}
+
+type HistoryContent struct {
+	Content string `json:"content"`
+	Role    string `json:"role"`
+}
+
+type KbSearchDocsRequest struct {
+	Query             string          `json:"query"`
+	KnowledgeBaseName string          `json:"Knowledge_base_name"`
+	TopK              int             `json:"top_k"`
+	ScoreThreshold    int             `json:"score_threshold"`
+	FileName          string            `json:"file_name"`
+	Model             string          `json:"model"`
+	Metadata          json.RawMessage `json:"metadata"`
+}

+ 10 - 0
utils/llm/eta_llm/eta_llm_http/response.go

@@ -0,0 +1,10 @@
+package eta_llm_http
+
+import "encoding/json"
+
+type BaseResponse struct {
+	Ret     int             `json:"ret"`
+	Msg     string          `json:"msg"`
+	Success bool            `json:"success"`
+	Data    json.RawMessage `json:"data"`
+}

+ 2 - 1
utils/llm/llm_client.go

@@ -20,5 +20,6 @@ func NewLLMClient(baseURL string, timeout time.Duration) *LLMClient {
 }
 
 type LLMService interface {
-	AskQuestion() string
+	KnowledgeBaseChat() string
+	SearchKbDocs(query string, KnowledgeBaseName string) (string, error)
 }

+ 2 - 3
utils/llm/llm_factory.go

@@ -5,11 +5,11 @@ import (
 )
 
 var (
-	llmInstanceMap map[string]LLMService = make(map[string]LLMService)
+	llmInstanceMap = make(map[string]LLMService)
 )
 
 const (
-	LLM_DEEPSEEK = "deepseek"
+	ETA_LLM_CLIENT = "eta_llm"
 )
 
 func Register(name string, llmClient LLMService) (err error) {
@@ -36,4 +36,3 @@ func GetInstance(name string) (llmClient LLMService, err error) {
 	llmClient = llmInstanceMap[name]
 	return
 }
-