// Package eta_llm provides the llm.LLMService implementation that talks to the
// ETA LLM HTTP server: knowledge-base chat, document chat and knowledge-base
// document search.
package eta_llm

import (
	"bufio"
	"bytes"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"strings"
	"sync"

	"eta/eta_api/models"
	"eta/eta_api/utils"
	"eta/eta_api/utils/llm"
	"eta/eta_api/utils/llm/eta_llm/eta_llm_http"
)

var (
	// dsOnce guards the one-time construction of etaLlmClient.
	dsOnce sync.Once
	// etaLlmClient is the shared client built by GetInstance.
	etaLlmClient *ETALLMClient
)

// Request defaults and API paths, relative to the configured server address.
// The misspelled identifiers are exported and therefore kept as-is.
const (
	KNOWLEDEG_CHAT_MODE            = "local_kb"
	DEFALUT_PROMPT_NAME            = "default"
	CONTENT_TYPE_JSON              = "application/json"
	KNOWLEDGE_BASE_CHAT_API        = "/chat/kb_chat"
	DOCUMENT_CHAT_API              = "/chat/file_chat"
	KNOWLEDGE_BASE_SEARCH_DOCS_API = "/knowledge_base/search_docs"
)

// ETALLMClient is the ETA LLM server client. The HTTP client and base URL it
// uses are promoted from the embedded *llm.LLMClient.
type ETALLMClient struct {
	*llm.LLMClient
	// LlmModel is the model name sent with every chat request.
	LlmModel string
}

// LLMConfig is the JSON configuration stored under models.LLMInitConfig.
type LLMConfig struct {
	LlmAddress string `json:"llm_server"`
	LlmModel   string `json:"llm_model"`
}

// GetInstance returns the shared ETA LLM client, building it from the business
// configuration on the first call.
//
// The construction attempt happens exactly once (sync.Once): if the
// configuration is empty on that first call, no client is created and every
// call returns nil. An untyped nil is returned in that case so that callers'
// `== nil` checks behave as expected.
func GetInstance() llm.LLMService {
	dsOnce.Do(func() {
		confStr := models.BusinessConfMap[models.LLMInitConfig]
		if confStr == "" {
			utils.FileLog.Error("LLM配置为空")
			return
		}
		var config LLMConfig
		if err := json.Unmarshal([]byte(confStr), &config); err != nil {
			// Deliberately not fatal: a client is still built from whatever
			// was decoded, so requests fail with an error instead of callers
			// receiving a nil service.
			utils.FileLog.Error("LLM配置错误:", err)
		}
		if etaLlmClient == nil {
			etaLlmClient = &ETALLMClient{
				// NOTE(review): 120 is presumably the request timeout in
				// seconds - confirm against llm.NewLLMClient.
				LLMClient: llm.NewLLMClient(config.LlmAddress, 120),
				LlmModel:  config.LlmModel,
			}
		}
	})
	if etaLlmClient == nil {
		// Returning the nil *ETALLMClient directly would produce a non-nil
		// interface value holding a nil pointer.
		return nil
	}
	return etaLlmClient
}

// parseChatHistory converts "role-content" strings into history entries.
//
// Each item is split on the first "-" only, so hyphens inside the content are
// preserved. An item without any "-" yields an error. The returned slice is
// never nil, so it encodes as a JSON array even when history is empty.
func parseChatHistory(history []string) ([]eta_llm_http.HistoryContent, error) {
	chatHistory := make([]eta_llm_http.HistoryContent, 0, len(history))
	for _, item := range history {
		parts := strings.SplitN(item, "-", 2)
		if len(parts) != 2 {
			return nil, fmt.Errorf("历史对话格式错误: %q", item)
		}
		chatHistory = append(chatHistory, eta_llm_http.HistoryContent{
			Role:    parts[0],
			Content: parts[1],
		})
	}
	return chatHistory, nil
}

// DocumentChat posts a chat request about an uploaded document to
// DOCUMENT_CHAT_API and returns the raw HTTP response.
//
// history holds previous turns encoded as "role-content"; a malformed entry
// results in an error. The caller owns the returned response and must close
// its body (ParseStreamResponse does so when it finishes).
func (ds *ETALLMClient) DocumentChat(query string, KnowledgeId string, history []string, stream bool) (llmRes *http.Response, err error) {
	chatHistory, err := parseChatHistory(history)
	if err != nil {
		return nil, err
	}
	kbReq := eta_llm_http.DocumentChatRequest{
		Query:          query,
		KnowledgeId:    KnowledgeId,
		History:        chatHistory,
		TopK:           3,
		ScoreThreshold: 0.5,
		Stream:         stream,
		ModelName:      ds.LlmModel,
		Temperature:    0.7,
		MaxTokens:      0,
		PromptName:     DEFALUT_PROMPT_NAME,
	}
	body, err := json.Marshal(kbReq)
	if err != nil {
		return nil, err
	}
	return ds.DoStreamPost(DOCUMENT_CHAT_API, body)
}

// KnowledgeBaseChat posts a streaming chat request against the named knowledge
// base to KNOWLEDGE_BASE_CHAT_API and returns the raw HTTP response.
//
// history holds previous turns encoded as "role-content"; a malformed entry
// results in an error. The request always sets Stream to true. The caller owns
// the returned response and must close its body (ParseStreamResponse does so
// when it finishes).
func (ds *ETALLMClient) KnowledgeBaseChat(query string, KnowledgeBaseName string, history []string) (llmRes *http.Response, err error) {
	chatHistory, err := parseChatHistory(history)
	if err != nil {
		return nil, err
	}
	kbReq := eta_llm_http.KbChatRequest{
		Query:          query,
		Mode:           KNOWLEDEG_CHAT_MODE,
		KbName:         KnowledgeBaseName,
		History:        chatHistory,
		TopK:           3,
		ScoreThreshold: 0.5,
		Stream:         true,
		Model:          ds.LlmModel,
		Temperature:    0.7,
		MaxTokens:      0,
		PromptName:     DEFALUT_PROMPT_NAME,
		ReturnDirect:   false,
	}
	body, err := json.Marshal(kbReq)
	if err != nil {
		return nil, err
	}
	return ds.DoStreamPost(KNOWLEDGE_BASE_CHAT_API, body)
}

// SearchKbDocs searches the named knowledge base for documents matching query.
//
// On success content holds a []eta_llm_http.SearchDocsResponse. An error is
// returned when the request cannot be sent, when the server answers with a
// non-200 status, or when the response body cannot be decoded.
func (ds *ETALLMClient) SearchKbDocs(query string, KnowledgeBaseName string) (content interface{}, err error) {
	kbReq := eta_llm_http.KbSearchDocsRequest{
		Query:             query,
		KnowledgeBaseName: KnowledgeBaseName,
		TopK:              10,
		ScoreThreshold:    0.5,
		Metadata:          struct{}{},
	}
	body, err := json.Marshal(kbReq)
	if err != nil {
		return nil, err
	}
	resp, err := ds.DoPost(KNOWLEDGE_BASE_SEARCH_DOCS_API, body)
	if err != nil {
		// Surface the transport/read error itself; resp.Msg is empty here.
		return nil, err
	}
	if !resp.Success {
		return nil, errors.New(resp.Msg)
	}
	if resp.Data == nil {
		return nil, errors.New("搜索知识库失败")
	}
	var kbSearchRes []eta_llm_http.SearchDocsResponse
	if err = json.Unmarshal(resp.Data, &kbSearchRes); err != nil {
		return nil, errors.New("搜索知识库失败")
	}
	return kbSearchRes, nil
}

// init registers the client under llm.ETA_LLM_CLIENT at package load time.
// GetInstance may return nil here if the configuration is not available yet.
func init() {
	err := llm.Register(llm.ETA_LLM_CLIENT, GetInstance())
	if err != nil {
		utils.FileLog.Error("注册eta_llm_server服务失败:", err)
	}
}

// DoPost sends body as JSON to apiUrl (relative to the base URL), reads the
// whole response and returns it wrapped in a BaseResponse. The response body
// is closed before returning.
func (ds *ETALLMClient) DoPost(apiUrl string, body []byte) (baseResp eta_llm_http.BaseResponse, err error) {
	response, err := ds.HttpClient.Post(ds.BaseURL+apiUrl, CONTENT_TYPE_JSON, bytes.NewReader(body))
	if err != nil {
		return
	}
	return parseResponse(response)
}

// DoStreamPost sends body as JSON to apiUrl (relative to the base URL) and
// returns the unread HTTP response so the caller can stream it. The caller is
// responsible for closing the response body.
func (ds *ETALLMClient) DoStreamPost(apiUrl string, body []byte) (baseResp *http.Response, err error) {
	return ds.HttpClient.Post(ds.BaseURL+apiUrl, CONTENT_TYPE_JSON, bytes.NewReader(body))
}

// parseResponse converts an HTTP response into a BaseResponse and closes the
// response body.
//
// Ret always carries the HTTP status code. A non-200 status is reported
// through Msg with Success left false and a nil error; a 200 response sets
// Success and stores the raw body bytes in Data.
func parseResponse(response *http.Response) (baseResp eta_llm_http.BaseResponse, err error) {
	defer func() {
		_ = response.Body.Close()
	}()
	baseResp.Ret = response.StatusCode
	if response.StatusCode != http.StatusOK {
		baseResp.Msg = fmt.Sprintf("请求失败,状态码:%d, 状态信息:%s", response.StatusCode, http.StatusText(response.StatusCode))
		return
	}
	bodyBytes, err := io.ReadAll(response.Body)
	if err != nil {
		err = fmt.Errorf("读取响应体失败: %w", err)
		return
	}
	baseResp.Success = true
	baseResp.Data = bodyBytes
	return
}

// ParseStreamResponse reads a line-oriented streaming response in a background
// goroutine and forwards the non-empty delta content of every chunk.
//
// contentChan receives the content pieces in order, errChan receives at most
// one error (JSON decoding or reading failure), and closeChan is closed when
// the goroutine finishes. All three channels are closed on completion, and the
// response body is closed by the goroutine as well.
//
// contentChan is buffered (10); the goroutine blocks once the buffer is full,
// so the caller must keep draining it until it is closed.
func ParseStreamResponse(response *http.Response) (contentChan chan string, errChan chan error, closeChan chan struct{}) {
	// Upper bound for a single streamed line; the Scanner default of 64 KiB
	// would abort the stream on larger chunks.
	const maxLineSize = 1 << 20

	contentChan = make(chan string, 10)
	errChan = make(chan error, 10)
	closeChan = make(chan struct{})
	go func() {
		defer close(contentChan)
		defer close(errChan)
		defer close(closeChan)

		if response == nil || response.Body == nil {
			// A panic here would take the whole process down; report instead.
			errChan <- errors.New("读取响应体失败: 响应为空")
			return
		}
		// Release the connection once the stream has been consumed or abandoned.
		defer func() {
			_ = response.Body.Close()
		}()

		// bufio.Scanner splits on lines by default.
		scanner := bufio.NewScanner(response.Body)
		scanner.Buffer(make([]byte, 0, bufio.MaxScanTokenSize), maxLineSize)
		for scanner.Scan() {
			line := scanner.Text()
			// Skip blank separator lines and ": ping" lines.
			if line == "" || strings.HasPrefix(line, ": ping") {
				continue
			}
			// Strip the "data: " prefix when present.
			line = strings.TrimPrefix(line, "data: ")

			var chunk eta_llm_http.ChunkResponse
			if err := json.Unmarshal([]byte(line), &chunk); err != nil {
				fmt.Println("解析错误的line:" + line)
				errChan <- fmt.Errorf("解析 JSON 块失败: %w", err)
				return
			}
			// Forward every non-empty content delta of the chunk.
			for _, choice := range chunk.Choices {
				if choice.Delta.Content != "" {
					contentChan <- choice.Delta.Content
				}
			}
		}
		if err := scanner.Err(); err != nil {
			errChan <- fmt.Errorf("读取响应体失败: %w", err)
		}
	}()
	return
}