|
@@ -0,0 +1,193 @@
|
|
|
+package eta_llm
|
|
|
+
|
|
|
+import (
|
|
|
+ "bufio"
|
|
|
+ "bytes"
|
|
|
+ "encoding/json"
|
|
|
+ "errors"
|
|
|
+ "eta/eta_api/utils"
|
|
|
+ "eta/eta_api/utils/llm"
|
|
|
+ "eta/eta_api/utils/llm/eta_llm/eta_llm_http"
|
|
|
+ "fmt"
|
|
|
+ "io"
|
|
|
+ "net/http"
|
|
|
+ "strings"
|
|
|
+ "sync"
|
|
|
+)
|
|
|
+
|
|
|
+var (
|
|
|
+ dsOnce sync.Once
|
|
|
+
|
|
|
+ etaLlmClient *ETALLMClient
|
|
|
+)
|
|
|
+
|
|
|
+const (
|
|
|
+ KNOWLEDEG_CHAT_MODE = "local_kb"
|
|
|
+ DEFALUT_PROMPT_NAME = "default"
|
|
|
+ CONTENT_TYPE_JSON = "application/json"
|
|
|
+ KNOWLEDGE_BASE_CHAT_API = "/chat/kb_chat"
|
|
|
+ KNOWLEDGE_BASE_SEARCH_DOCS_API = "/knowledge_base/search_docs"
|
|
|
+)
|
|
|
+
|
|
|
+type ETALLMClient struct {
|
|
|
+ *llm.LLMClient
|
|
|
+ LlmModel string
|
|
|
+}
|
|
|
+
|
|
|
+func GetInstance() llm.LLMService {
|
|
|
+ dsOnce.Do(func() {
|
|
|
+ if etaLlmClient == nil {
|
|
|
+ etaLlmClient = &ETALLMClient{
|
|
|
+ LLMClient: llm.NewLLMClient(utils.LLM_SERVER, 120),
|
|
|
+ LlmModel: utils.LLM_MODEL,
|
|
|
+ }
|
|
|
+ }
|
|
|
+ })
|
|
|
+ return etaLlmClient
|
|
|
+}
|
|
|
+
|
|
|
+func (ds *ETALLMClient) KnowledgeBaseChat(query string, KnowledgeBaseName string, history []string) (llmRes *http.Response, err error) {
|
|
|
+ ChatHistory := make([]eta_llm_http.HistoryContent, 0)
|
|
|
+ for _, historyItemStr := range history {
|
|
|
+ str := strings.Split(historyItemStr, "-")
|
|
|
+ historyItem := eta_llm_http.HistoryContent{
|
|
|
+ Role: str[0],
|
|
|
+ Content: str[1],
|
|
|
+ }
|
|
|
+ ChatHistory = append(ChatHistory, historyItem)
|
|
|
+ }
|
|
|
+ kbReq := eta_llm_http.KbChatRequest{
|
|
|
+ Query: query,
|
|
|
+ Mode: KNOWLEDEG_CHAT_MODE,
|
|
|
+ KbName: KnowledgeBaseName,
|
|
|
+ History: ChatHistory,
|
|
|
+ TopK: 3,
|
|
|
+ ScoreThreshold: 0.5,
|
|
|
+ Stream: true,
|
|
|
+ Model: ds.LlmModel,
|
|
|
+ Temperature: 0.7,
|
|
|
+ MaxTokens: 0,
|
|
|
+ PromptName: DEFALUT_PROMPT_NAME,
|
|
|
+ ReturnDirect: false,
|
|
|
+ }
|
|
|
+ fmt.Printf("%v", kbReq.History)
|
|
|
+ body, err := json.Marshal(kbReq)
|
|
|
+ if err != nil {
|
|
|
+ return
|
|
|
+ }
|
|
|
+ return ds.DoStreamPost(KNOWLEDGE_BASE_CHAT_API, body)
|
|
|
+}
|
|
|
+
|
|
|
+func (ds *ETALLMClient) SearchKbDocs(query string, KnowledgeBaseName string) (content interface{}, err error) {
|
|
|
+ kbReq := eta_llm_http.KbSearchDocsRequest{
|
|
|
+ Query: query,
|
|
|
+ KnowledgeBaseName: KnowledgeBaseName,
|
|
|
+ TopK: 10,
|
|
|
+ ScoreThreshold: 0.5,
|
|
|
+ Metadata: struct{}{},
|
|
|
+ }
|
|
|
+
|
|
|
+ body, err := json.Marshal(kbReq)
|
|
|
+ if err != nil {
|
|
|
+ return
|
|
|
+ }
|
|
|
+ resp, err := ds.DoPost(KNOWLEDGE_BASE_SEARCH_DOCS_API, body)
|
|
|
+ if !resp.Success {
|
|
|
+ err = errors.New(resp.Msg)
|
|
|
+ return
|
|
|
+ }
|
|
|
+ if resp.Data != nil {
|
|
|
+ var kbSearchRes []eta_llm_http.SearchDocsResponse
|
|
|
+ err = json.Unmarshal(resp.Data, &kbSearchRes)
|
|
|
+ if err != nil {
|
|
|
+ err = errors.New("搜索知识库失败")
|
|
|
+ return
|
|
|
+ }
|
|
|
+ content = kbSearchRes
|
|
|
+ return
|
|
|
+ }
|
|
|
+ err = errors.New("搜索知识库失败")
|
|
|
+ return
|
|
|
+}
|
|
|
+func init() {
|
|
|
+ err := llm.Register(llm.ETA_LLM_CLIENT, GetInstance())
|
|
|
+ if err != nil {
|
|
|
+ utils.FileLog.Error("注册eta_llm_server服务失败:", err)
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func (ds *ETALLMClient) DoPost(apiUrl string, body []byte) (baseResp eta_llm_http.BaseResponse, err error) {
|
|
|
+ requestReader := bytes.NewReader(body)
|
|
|
+ response, err := ds.HttpClient.Post(ds.BaseURL+apiUrl, CONTENT_TYPE_JSON, requestReader)
|
|
|
+ if err != nil {
|
|
|
+ return
|
|
|
+ }
|
|
|
+ return parseResponse(response)
|
|
|
+}
|
|
|
+func (ds *ETALLMClient) DoStreamPost(apiUrl string, body []byte) (baseResp *http.Response, err error) {
|
|
|
+ requestReader := bytes.NewReader(body)
|
|
|
+ return ds.HttpClient.Post(ds.BaseURL+apiUrl, CONTENT_TYPE_JSON, requestReader)
|
|
|
+}
|
|
|
+func parseResponse(response *http.Response) (baseResp eta_llm_http.BaseResponse, err error) {
|
|
|
+ defer func() {
|
|
|
+ _ = response.Body.Close()
|
|
|
+ }()
|
|
|
+ baseResp.Ret = response.StatusCode
|
|
|
+ if response.StatusCode != http.StatusOK {
|
|
|
+ baseResp.Msg = fmt.Sprintf("请求失败,状态码:%d, 状态信息:%s", response.StatusCode, http.StatusText(response.StatusCode))
|
|
|
+ return
|
|
|
+ }
|
|
|
+ bodyBytes, err := io.ReadAll(response.Body)
|
|
|
+ if err != nil {
|
|
|
+ err = fmt.Errorf("读取响应体失败: %w", err)
|
|
|
+ return
|
|
|
+ }
|
|
|
+ baseResp.Success = true
|
|
|
+ baseResp.Data = bodyBytes
|
|
|
+ return
|
|
|
+}
|
|
|
+func ParseStreamResponse(response *http.Response) (contentChan chan string, errChan chan error, closeChan chan struct{}) {
|
|
|
+ contentChan = make(chan string, 10)
|
|
|
+ errChan = make(chan error, 10)
|
|
|
+ closeChan = make(chan struct{})
|
|
|
+ go func() {
|
|
|
+ defer close(contentChan)
|
|
|
+ defer close(errChan)
|
|
|
+ defer close(closeChan)
|
|
|
+ scanner := bufio.NewScanner(response.Body)
|
|
|
+ scanner.Split(bufio.ScanLines)
|
|
|
+ for scanner.Scan() {
|
|
|
+ line := scanner.Text()
|
|
|
+ if line == "" {
|
|
|
+ continue
|
|
|
+ }
|
|
|
+ // 忽略 "ping" 行
|
|
|
+ if strings.HasPrefix(line, ": ping") {
|
|
|
+ continue
|
|
|
+ }
|
|
|
+ // 去除 "data: " 前缀
|
|
|
+ if strings.HasPrefix(line, "data: ") {
|
|
|
+ line = strings.TrimPrefix(line, "data: ")
|
|
|
+ }
|
|
|
+ var chunk eta_llm_http.ChunkResponse
|
|
|
+ if err := json.Unmarshal([]byte(line), &chunk); err != nil {
|
|
|
+ fmt.Println("解析错误的line:" + line)
|
|
|
+ errChan <- fmt.Errorf("解析 JSON 块失败: %w", err)
|
|
|
+ return
|
|
|
+ }
|
|
|
+ // 处理每个 chunk
|
|
|
+ if chunk.Choices != nil && len(chunk.Choices) > 0 {
|
|
|
+ for _, choice := range chunk.Choices {
|
|
|
+ if choice.Delta.Content != "" {
|
|
|
+ contentChan <- choice.Delta.Content
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ if err := scanner.Err(); err != nil {
|
|
|
+ errChan <- fmt.Errorf("读取响应体失败: %w", err)
|
|
|
+ return
|
|
|
+ }
|
|
|
+ }()
|
|
|
+ return
|
|
|
+}
|