|
@@ -0,0 +1,102 @@
|
|
|
+package eta_llm
|
|
|
+
|
|
|
+import (
|
|
|
+ "bytes"
|
|
|
+ "encoding/json"
|
|
|
+ "errors"
|
|
|
+ "eta/eta_api/utils"
|
|
|
+ "eta/eta_api/utils/llm"
|
|
|
+ "eta/eta_api/utils/llm/eta_llm/eta_llm_http"
|
|
|
+ "fmt"
|
|
|
+ "io"
|
|
|
+ "net/http"
|
|
|
+ "sync"
|
|
|
+)
|
|
|
+
|
|
|
+var (
|
|
|
+ dsOnce sync.Once
|
|
|
+
|
|
|
+ etaLlmClient *ETALLMClient
|
|
|
+)
|
|
|
+
|
|
|
+const (
|
|
|
+ CONTENT_TYPE_JSON = "application/json"
|
|
|
+ KNOWLEDGE_BASE_CHAT_API = "/chat/kb_chat"
|
|
|
+ KNOWLEDGE_BASE_SEARCH_DOCS_API = "/knowledge_base/search_docs"
|
|
|
+)
|
|
|
+
|
|
|
+type ETALLMClient struct {
|
|
|
+ *llm.LLMClient
|
|
|
+ LlmModel string
|
|
|
+}
|
|
|
+
|
|
|
+func GetInstance() llm.LLMService {
|
|
|
+ dsOnce.Do(func() {
|
|
|
+ if etaLlmClient == nil {
|
|
|
+ etaLlmClient = &ETALLMClient{
|
|
|
+ LLMClient: llm.NewLLMClient(utils.LLM_SERVER, 10),
|
|
|
+ LlmModel: utils.LLM_MODEL,
|
|
|
+ }
|
|
|
+ }
|
|
|
+ })
|
|
|
+ return etaLlmClient
|
|
|
+}
|
|
|
+
|
|
|
+func (ds *ETALLMClient) KnowledgeBaseChat() string {
|
|
|
+ ds.HttpClient.Post(ds.BaseURL+KNOWLEDGE_BASE_CHAT_API, CONTENT_TYPE_JSON, nil)
|
|
|
+ return ""
|
|
|
+}
|
|
|
+
|
|
|
+func (ds *ETALLMClient) SearchKbDocs(query string, KnowledgeBaseName string) (content string, err error) {
|
|
|
+ // 类型断言
|
|
|
+ kbReq := eta_llm_http.KbSearchDocsRequest{
|
|
|
+ Query: query,
|
|
|
+ KnowledgeBaseName: KnowledgeBaseName,
|
|
|
+ Model: ds.LlmModel,
|
|
|
+ TopK: 3,
|
|
|
+ ScoreThreshold: 2,
|
|
|
+ }
|
|
|
+ body, err := json.Marshal(kbReq)
|
|
|
+ if err != nil {
|
|
|
+ return
|
|
|
+ }
|
|
|
+ resp, err := ds.DoPost(KNOWLEDGE_BASE_SEARCH_DOCS_API, body)
|
|
|
+ if !resp.Success {
|
|
|
+ err = errors.New(resp.Msg)
|
|
|
+ return
|
|
|
+ }
|
|
|
+ return "", nil
|
|
|
+}
|
|
|
+func init() {
|
|
|
+ err := llm.Register(llm.ETA_LLM_CLIENT, GetInstance())
|
|
|
+ if err != nil {
|
|
|
+ utils.FileLog.Error("注册eta_llm_server服务失败:", err)
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func (ds *ETALLMClient) DoPost(apiUrl string, body []byte) (baseResp eta_llm_http.BaseResponse, err error) {
|
|
|
+ requestReader := bytes.NewReader(body)
|
|
|
+ response, err := ds.HttpClient.Post(ds.BaseURL+apiUrl, CONTENT_TYPE_JSON, requestReader)
|
|
|
+ if err != nil {
|
|
|
+ return
|
|
|
+ }
|
|
|
+ return parseResponse(response)
|
|
|
+}
|
|
|
+
|
|
|
+func parseResponse(response *http.Response) (baseResp eta_llm_http.BaseResponse, err error) {
|
|
|
+ defer func() {
|
|
|
+ _ = response.Body.Close()
|
|
|
+ }()
|
|
|
+ baseResp.Ret = response.StatusCode
|
|
|
+ if response.StatusCode != http.StatusOK {
|
|
|
+ baseResp.Msg = fmt.Sprintf("请求失败,状态码:%d, 状态信息:%s", response.StatusCode, http.StatusText(response.StatusCode))
|
|
|
+ return
|
|
|
+ }
|
|
|
+ bodyBytes, err := io.ReadAll(response.Body)
|
|
|
+ if err != nil {
|
|
|
+ err = fmt.Errorf("读取响应体失败: %w", err)
|
|
|
+ return
|
|
|
+ }
|
|
|
+ baseResp.Data = bodyBytes
|
|
|
+ return
|
|
|
+}
|