|
@@ -1,6 +1,7 @@
|
|
package eta_llm
|
|
package eta_llm
|
|
|
|
|
|
import (
|
|
import (
|
|
|
|
+ "bufio"
|
|
"bytes"
|
|
"bytes"
|
|
"encoding/json"
|
|
"encoding/json"
|
|
"errors"
|
|
"errors"
|
|
@@ -21,6 +22,8 @@ var (
|
|
)
|
|
)
|
|
|
|
|
|
const (
|
|
const (
|
|
|
|
+ KNOWLEDEG_CHAT_MODE = "local_kb"
|
|
|
|
+ DEFALUT_PROMPT_NAME = "default"
|
|
CONTENT_TYPE_JSON = "application/json"
|
|
CONTENT_TYPE_JSON = "application/json"
|
|
KNOWLEDGE_BASE_CHAT_API = "/chat/kb_chat"
|
|
KNOWLEDGE_BASE_CHAT_API = "/chat/kb_chat"
|
|
KNOWLEDGE_BASE_SEARCH_DOCS_API = "/knowledge_base/search_docs"
|
|
KNOWLEDGE_BASE_SEARCH_DOCS_API = "/knowledge_base/search_docs"
|
|
@@ -43,14 +46,58 @@ func GetInstance() llm.LLMService {
|
|
return etaLlmClient
|
|
return etaLlmClient
|
|
}
|
|
}
|
|
|
|
|
|
-func (ds *ETALLMClient) KnowledgeBaseChat() string {
|
|
|
|
- ds.HttpClient.Post(ds.BaseURL+KNOWLEDGE_BASE_CHAT_API, CONTENT_TYPE_JSON, nil)
|
|
|
|
- return ""
|
|
|
|
|
|
+func (ds *ETALLMClient) KnowledgeBaseChat(query string, KnowledgeBaseName string, history []interface{}) (content interface{}, err error) {
|
|
|
|
+ ChatHistory := make([]eta_llm_http.HistoryContent, 0)
|
|
|
|
+ ChatHistory = append(ChatHistory, eta_llm_http.HistoryContent{
|
|
|
|
+ Content: query,
|
|
|
|
+ Role: "user",
|
|
|
|
+ })
|
|
|
|
+ for _, historyItem := range history {
|
|
|
|
+ historyItemMap := historyItem.(map[string]interface{})
|
|
|
|
+ ChatHistory = append(ChatHistory, eta_llm_http.HistoryContent{
|
|
|
|
+ Content: historyItemMap["content"].(string),
|
|
|
|
+ Role: historyItemMap["role"].(string),
|
|
|
|
+ })
|
|
|
|
+ }
|
|
|
|
+ kbReq := eta_llm_http.KbChatRequest{
|
|
|
|
+ Query: query,
|
|
|
|
+ Mode: KNOWLEDEG_CHAT_MODE,
|
|
|
|
+ KbName: KnowledgeBaseName,
|
|
|
|
+ History: ChatHistory,
|
|
|
|
+ TopK: 3,
|
|
|
|
+ ScoreThreshold: 0.5,
|
|
|
|
+ Stream: true,
|
|
|
|
+ Model: ds.LlmModel,
|
|
|
|
+ Temperature: 0.7,
|
|
|
|
+ MaxTokens: 0,
|
|
|
|
+ PromptName: DEFALUT_PROMPT_NAME,
|
|
|
|
+ ReturnDirect: false,
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ body, err := json.Marshal(kbReq)
|
|
|
|
+ if err != nil {
|
|
|
|
+ return
|
|
|
|
+ }
|
|
|
|
+ resp, err := ds.DoStreamPost(KNOWLEDGE_BASE_CHAT_API, body)
|
|
|
|
+ if !resp.Success {
|
|
|
|
+ err = errors.New(resp.Msg)
|
|
|
|
+ return
|
|
|
|
+ }
|
|
|
|
+ if resp.Data != nil {
|
|
|
|
+ var kbChatRes bus_response.KnowledgeBaseChatResponse
|
|
|
|
+ err = json.Unmarshal(resp.Data, &kbChatRes)
|
|
|
|
+ if err != nil {
|
|
|
|
+ err = errors.New("搜索知识库失败")
|
|
|
|
+ return
|
|
|
|
+ }
|
|
|
|
+ content = kbChatRes
|
|
|
|
+ return
|
|
|
|
+ }
|
|
|
|
+ err = errors.New("搜索知识库失败")
|
|
|
|
+ return
|
|
}
|
|
}
|
|
|
|
|
|
func (ds *ETALLMClient) SearchKbDocs(query string, KnowledgeBaseName string) (content interface{}, err error) {
|
|
func (ds *ETALLMClient) SearchKbDocs(query string, KnowledgeBaseName string) (content interface{}, err error) {
|
|
- // 类型断言
|
|
|
|
-
|
|
|
|
kbReq := eta_llm_http.KbSearchDocsRequest{
|
|
kbReq := eta_llm_http.KbSearchDocsRequest{
|
|
Query: query,
|
|
Query: query,
|
|
KnowledgeBaseName: KnowledgeBaseName,
|
|
KnowledgeBaseName: KnowledgeBaseName,
|
|
@@ -96,7 +143,14 @@ func (ds *ETALLMClient) DoPost(apiUrl string, body []byte) (baseResp eta_llm_htt
|
|
}
|
|
}
|
|
return parseResponse(response)
|
|
return parseResponse(response)
|
|
}
|
|
}
|
|
-
|
|
|
|
|
|
+func (ds *ETALLMClient) DoStreamPost(apiUrl string, body []byte) (baseResp eta_llm_http.BaseResponse, err error) {
|
|
|
|
+ requestReader := bytes.NewReader(body)
|
|
|
|
+ response, err := ds.HttpClient.Post(ds.BaseURL+apiUrl, CONTENT_TYPE_JSON, requestReader)
|
|
|
|
+ if err != nil {
|
|
|
|
+ return
|
|
|
|
+ }
|
|
|
|
+ return parseResponse(response)
|
|
|
|
+}
|
|
func parseResponse(response *http.Response) (baseResp eta_llm_http.BaseResponse, err error) {
|
|
func parseResponse(response *http.Response) (baseResp eta_llm_http.BaseResponse, err error) {
|
|
defer func() {
|
|
defer func() {
|
|
_ = response.Body.Close()
|
|
_ = response.Body.Close()
|
|
@@ -115,3 +169,44 @@ func parseResponse(response *http.Response) (baseResp eta_llm_http.BaseResponse,
|
|
baseResp.Data = bodyBytes
|
|
baseResp.Data = bodyBytes
|
|
return
|
|
return
|
|
}
|
|
}
|
|
|
|
+func parseStreamResponse(response *http.Response) (contentChan chan string, errChan chan error) {
|
|
|
|
+ defer func() {
|
|
|
|
+ _ = response.Body.Close()
|
|
|
|
+ }()
|
|
|
|
+ contentChan = make(chan string)
|
|
|
|
+ errChan = make(chan error)
|
|
|
|
+ go func() {
|
|
|
|
+ defer close(contentChan)
|
|
|
|
+ defer close(errChan)
|
|
|
|
+ scanner := bufio.NewScanner(response.Body)
|
|
|
|
+ scanner.Split(bufio.ScanLines)
|
|
|
|
+
|
|
|
|
+ for scanner.Scan() {
|
|
|
|
+ line := scanner.Text()
|
|
|
|
+ if line == "" {
|
|
|
|
+ continue
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ var chunk eta_llm_http.ChunkResponse
|
|
|
|
+ if err := json.Unmarshal([]byte(line), &chunk); err != nil {
|
|
|
|
+ errChan <- fmt.Errorf("解析 JSON 块失败: %w", err)
|
|
|
|
+ return
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ // 处理每个 chunk
|
|
|
|
+ if chunk.Choices != nil && len(chunk.Choices) > 0 {
|
|
|
|
+ for _, choice := range chunk.Choices {
|
|
|
|
+ if choice.Delta.Content != "" {
|
|
|
|
+ contentChan <- choice.Delta.Content
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ if err := scanner.Err(); err != nil {
|
|
|
|
+ errChan <- fmt.Errorf("读取响应体失败: %w", err)
|
|
|
|
+ return
|
|
|
|
+ }
|
|
|
|
+ }()
|
|
|
|
+ return contentChan, errChan
|
|
|
|
+}
|