|
@@ -5,7 +5,6 @@ import (
|
|
|
"bytes"
|
|
|
"encoding/json"
|
|
|
"errors"
|
|
|
- "eta/eta_api/services/llm/facade/bus_response"
|
|
|
"eta/eta_api/utils"
|
|
|
"eta/eta_api/utils/llm"
|
|
|
"eta/eta_api/utils/llm/eta_llm/eta_llm_http"
|
|
@@ -46,7 +45,7 @@ func GetInstance() llm.LLMService {
|
|
|
return etaLlmClient
|
|
|
}
|
|
|
|
|
|
-func (ds *ETALLMClient) KnowledgeBaseChat(query string, KnowledgeBaseName string, history []interface{}) (content interface{}, err error) {
|
|
|
+func (ds *ETALLMClient) KnowledgeBaseChat(query string, KnowledgeBaseName string, history []interface{}) (llmRes *http.Response, err error) {
|
|
|
ChatHistory := make([]eta_llm_http.HistoryContent, 0)
|
|
|
ChatHistory = append(ChatHistory, eta_llm_http.HistoryContent{
|
|
|
Content: query,
|
|
@@ -78,23 +77,7 @@ func (ds *ETALLMClient) KnowledgeBaseChat(query string, KnowledgeBaseName string
|
|
|
if err != nil {
|
|
|
return
|
|
|
}
|
|
|
- resp, err := ds.DoStreamPost(KNOWLEDGE_BASE_CHAT_API, body)
|
|
|
- if !resp.Success {
|
|
|
- err = errors.New(resp.Msg)
|
|
|
- return
|
|
|
- }
|
|
|
- if resp.Data != nil {
|
|
|
- var kbChatRes bus_response.KnowledgeBaseChatResponse
|
|
|
- err = json.Unmarshal(resp.Data, &kbChatRes)
|
|
|
- if err != nil {
|
|
|
- err = errors.New("搜索知识库失败")
|
|
|
- return
|
|
|
- }
|
|
|
- content = kbChatRes
|
|
|
- return
|
|
|
- }
|
|
|
- err = errors.New("搜索知识库失败")
|
|
|
- return
|
|
|
+ return ds.DoStreamPost(KNOWLEDGE_BASE_CHAT_API, body)
|
|
|
}
|
|
|
|
|
|
func (ds *ETALLMClient) SearchKbDocs(query string, KnowledgeBaseName string) (content interface{}, err error) {
|
|
@@ -116,7 +99,7 @@ func (ds *ETALLMClient) SearchKbDocs(query string, KnowledgeBaseName string) (co
|
|
|
return
|
|
|
}
|
|
|
if resp.Data != nil {
|
|
|
- var kbSearchRes []bus_response.SearchDocsResponse
|
|
|
+ var kbSearchRes []eta_llm_http.SearchDocsResponse
|
|
|
err = json.Unmarshal(resp.Data, &kbSearchRes)
|
|
|
if err != nil {
|
|
|
err = errors.New("搜索知识库失败")
|
|
@@ -143,13 +126,9 @@ func (ds *ETALLMClient) DoPost(apiUrl string, body []byte) (baseResp eta_llm_htt
|
|
|
}
|
|
|
return parseResponse(response)
|
|
|
}
|
|
|
-func (ds *ETALLMClient) DoStreamPost(apiUrl string, body []byte) (baseResp eta_llm_http.BaseResponse, err error) {
|
|
|
+func (ds *ETALLMClient) DoStreamPost(apiUrl string, body []byte) (baseResp *http.Response, err error) {
|
|
|
requestReader := bytes.NewReader(body)
|
|
|
- response, err := ds.HttpClient.Post(ds.BaseURL+apiUrl, CONTENT_TYPE_JSON, requestReader)
|
|
|
- if err != nil {
|
|
|
- return
|
|
|
- }
|
|
|
- return parseResponse(response)
|
|
|
+ return ds.HttpClient.Post(ds.BaseURL+apiUrl, CONTENT_TYPE_JSON, requestReader)
|
|
|
}
|
|
|
func parseResponse(response *http.Response) (baseResp eta_llm_http.BaseResponse, err error) {
|
|
|
defer func() {
|
|
@@ -169,30 +148,31 @@ func parseResponse(response *http.Response) (baseResp eta_llm_http.BaseResponse,
|
|
|
baseResp.Data = bodyBytes
|
|
|
return
|
|
|
}
|
|
|
-func parseStreamResponse(response *http.Response) (contentChan chan string, errChan chan error) {
|
|
|
+func ParseStreamResponse(response *http.Response) (contentChan chan string, errChan chan error, closeChan chan struct{}) {
|
|
|
defer func() {
|
|
|
_ = response.Body.Close()
|
|
|
+ defer close(contentChan)
|
|
|
+ defer close(errChan)
|
|
|
+ defer close(closeChan)
|
|
|
}()
|
|
|
- contentChan = make(chan string)
|
|
|
- errChan = make(chan error)
|
|
|
+ contentChan = make(chan string, 10)
|
|
|
+ errChan = make(chan error, 10)
|
|
|
+ closeChan = make(chan struct{})
|
|
|
go func() {
|
|
|
defer close(contentChan)
|
|
|
defer close(errChan)
|
|
|
scanner := bufio.NewScanner(response.Body)
|
|
|
scanner.Split(bufio.ScanLines)
|
|
|
-
|
|
|
for scanner.Scan() {
|
|
|
line := scanner.Text()
|
|
|
if line == "" {
|
|
|
continue
|
|
|
}
|
|
|
-
|
|
|
var chunk eta_llm_http.ChunkResponse
|
|
|
if err := json.Unmarshal([]byte(line), &chunk); err != nil {
|
|
|
errChan <- fmt.Errorf("解析 JSON 块失败: %w", err)
|
|
|
return
|
|
|
}
|
|
|
-
|
|
|
// 处理每个 chunk
|
|
|
if chunk.Choices != nil && len(chunk.Choices) > 0 {
|
|
|
for _, choice := range chunk.Choices {
|
|
@@ -202,11 +182,10 @@ func parseStreamResponse(response *http.Response) (contentChan chan string, errC
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
-
|
|
|
if err := scanner.Err(); err != nil {
|
|
|
errChan <- fmt.Errorf("读取响应体失败: %w", err)
|
|
|
return
|
|
|
}
|
|
|
}()
|
|
|
- return contentChan, errChan
|
|
|
+ return
|
|
|
}
|