eta_llm_client.go 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212
  1. package eta_llm
  2. import (
  3. "bufio"
  4. "bytes"
  5. "encoding/json"
  6. "errors"
  7. "eta/eta_api/services/llm/facade/bus_response"
  8. "eta/eta_api/utils"
  9. "eta/eta_api/utils/llm"
  10. "eta/eta_api/utils/llm/eta_llm/eta_llm_http"
  11. "fmt"
  12. "io"
  13. "net/http"
  14. "sync"
  15. )
  16. var (
  17. dsOnce sync.Once
  18. etaLlmClient *ETALLMClient
  19. )
  // Request defaults sent to the ETA LLM server.
// NOTE: KNOWLEDEG_CHAT_MODE and DEFALUT_PROMPT_NAME are misspelled, but they
// are exported, so the names are kept to avoid breaking callers.
const (
	// KNOWLEDEG_CHAT_MODE is the "mode" value used for knowledge-base chat.
	KNOWLEDEG_CHAT_MODE = "local_kb"
	// DEFALUT_PROMPT_NAME is the prompt template name sent with chat requests.
	DEFALUT_PROMPT_NAME = "default"
	// CONTENT_TYPE_JSON is the Content-Type header used for every POST body.
	CONTENT_TYPE_JSON = "application/json"
)

// Server API paths, appended to the client's BaseURL.
const (
	// KNOWLEDGE_BASE_CHAT_API is the knowledge-base chat endpoint.
	KNOWLEDGE_BASE_CHAT_API = "/chat/kb_chat"
	// KNOWLEDGE_BASE_SEARCH_DOCS_API is the knowledge-base document search endpoint.
	KNOWLEDGE_BASE_SEARCH_DOCS_API = "/knowledge_base/search_docs"
)
  27. type ETALLMClient struct {
  28. *llm.LLMClient
  29. LlmModel string
  30. }
  31. func GetInstance() llm.LLMService {
  32. dsOnce.Do(func() {
  33. if etaLlmClient == nil {
  34. etaLlmClient = &ETALLMClient{
  35. LLMClient: llm.NewLLMClient(utils.LLM_SERVER, 10),
  36. LlmModel: utils.LLM_MODEL,
  37. }
  38. }
  39. })
  40. return etaLlmClient
  41. }
  42. func (ds *ETALLMClient) KnowledgeBaseChat(query string, KnowledgeBaseName string, history []interface{}) (content interface{}, err error) {
  43. ChatHistory := make([]eta_llm_http.HistoryContent, 0)
  44. ChatHistory = append(ChatHistory, eta_llm_http.HistoryContent{
  45. Content: query,
  46. Role: "user",
  47. })
  48. for _, historyItem := range history {
  49. historyItemMap := historyItem.(map[string]interface{})
  50. ChatHistory = append(ChatHistory, eta_llm_http.HistoryContent{
  51. Content: historyItemMap["content"].(string),
  52. Role: historyItemMap["role"].(string),
  53. })
  54. }
  55. kbReq := eta_llm_http.KbChatRequest{
  56. Query: query,
  57. Mode: KNOWLEDEG_CHAT_MODE,
  58. KbName: KnowledgeBaseName,
  59. History: ChatHistory,
  60. TopK: 3,
  61. ScoreThreshold: 0.5,
  62. Stream: true,
  63. Model: ds.LlmModel,
  64. Temperature: 0.7,
  65. MaxTokens: 0,
  66. PromptName: DEFALUT_PROMPT_NAME,
  67. ReturnDirect: false,
  68. }
  69. body, err := json.Marshal(kbReq)
  70. if err != nil {
  71. return
  72. }
  73. resp, err := ds.DoStreamPost(KNOWLEDGE_BASE_CHAT_API, body)
  74. if !resp.Success {
  75. err = errors.New(resp.Msg)
  76. return
  77. }
  78. if resp.Data != nil {
  79. var kbChatRes bus_response.KnowledgeBaseChatResponse
  80. err = json.Unmarshal(resp.Data, &kbChatRes)
  81. if err != nil {
  82. err = errors.New("搜索知识库失败")
  83. return
  84. }
  85. content = kbChatRes
  86. return
  87. }
  88. err = errors.New("搜索知识库失败")
  89. return
  90. }
  91. func (ds *ETALLMClient) SearchKbDocs(query string, KnowledgeBaseName string) (content interface{}, err error) {
  92. kbReq := eta_llm_http.KbSearchDocsRequest{
  93. Query: query,
  94. KnowledgeBaseName: KnowledgeBaseName,
  95. TopK: 10,
  96. ScoreThreshold: 0.5,
  97. Metadata: struct{}{},
  98. }
  99. body, err := json.Marshal(kbReq)
  100. if err != nil {
  101. return
  102. }
  103. resp, err := ds.DoPost(KNOWLEDGE_BASE_SEARCH_DOCS_API, body)
  104. if !resp.Success {
  105. err = errors.New(resp.Msg)
  106. return
  107. }
  108. if resp.Data != nil {
  109. var kbSearchRes []bus_response.SearchDocsResponse
  110. err = json.Unmarshal(resp.Data, &kbSearchRes)
  111. if err != nil {
  112. err = errors.New("搜索知识库失败")
  113. return
  114. }
  115. content = kbSearchRes
  116. return
  117. }
  118. err = errors.New("搜索知识库失败")
  119. return
  120. }
  121. func init() {
  122. err := llm.Register(llm.ETA_LLM_CLIENT, GetInstance())
  123. if err != nil {
  124. utils.FileLog.Error("注册eta_llm_server服务失败:", err)
  125. }
  126. }
  127. func (ds *ETALLMClient) DoPost(apiUrl string, body []byte) (baseResp eta_llm_http.BaseResponse, err error) {
  128. requestReader := bytes.NewReader(body)
  129. response, err := ds.HttpClient.Post(ds.BaseURL+apiUrl, CONTENT_TYPE_JSON, requestReader)
  130. if err != nil {
  131. return
  132. }
  133. return parseResponse(response)
  134. }
  135. func (ds *ETALLMClient) DoStreamPost(apiUrl string, body []byte) (baseResp eta_llm_http.BaseResponse, err error) {
  136. requestReader := bytes.NewReader(body)
  137. response, err := ds.HttpClient.Post(ds.BaseURL+apiUrl, CONTENT_TYPE_JSON, requestReader)
  138. if err != nil {
  139. return
  140. }
  141. return parseResponse(response)
  142. }
  143. func parseResponse(response *http.Response) (baseResp eta_llm_http.BaseResponse, err error) {
  144. defer func() {
  145. _ = response.Body.Close()
  146. }()
  147. baseResp.Ret = response.StatusCode
  148. if response.StatusCode != http.StatusOK {
  149. baseResp.Msg = fmt.Sprintf("请求失败,状态码:%d, 状态信息:%s", response.StatusCode, http.StatusText(response.StatusCode))
  150. return
  151. }
  152. bodyBytes, err := io.ReadAll(response.Body)
  153. if err != nil {
  154. err = fmt.Errorf("读取响应体失败: %w", err)
  155. return
  156. }
  157. baseResp.Success = true
  158. baseResp.Data = bodyBytes
  159. return
  160. }
  161. func parseStreamResponse(response *http.Response) (contentChan chan string, errChan chan error) {
  162. defer func() {
  163. _ = response.Body.Close()
  164. }()
  165. contentChan = make(chan string)
  166. errChan = make(chan error)
  167. go func() {
  168. defer close(contentChan)
  169. defer close(errChan)
  170. scanner := bufio.NewScanner(response.Body)
  171. scanner.Split(bufio.ScanLines)
  172. for scanner.Scan() {
  173. line := scanner.Text()
  174. if line == "" {
  175. continue
  176. }
  177. var chunk eta_llm_http.ChunkResponse
  178. if err := json.Unmarshal([]byte(line), &chunk); err != nil {
  179. errChan <- fmt.Errorf("解析 JSON 块失败: %w", err)
  180. return
  181. }
  182. // 处理每个 chunk
  183. if chunk.Choices != nil && len(chunk.Choices) > 0 {
  184. for _, choice := range chunk.Choices {
  185. if choice.Delta.Content != "" {
  186. contentChan <- choice.Delta.Content
  187. }
  188. }
  189. }
  190. }
  191. if err := scanner.Err(); err != nil {
  192. errChan <- fmt.Errorf("读取响应体失败: %w", err)
  193. return
  194. }
  195. }()
  196. return contentChan, errChan
  197. }