eta_llm_client.go 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193
  1. package eta_llm
  2. import (
  3. "bufio"
  4. "bytes"
  5. "encoding/json"
  6. "errors"
  7. "eta/eta_api/utils"
  8. "eta/eta_api/utils/llm"
  9. "eta/eta_api/utils/llm/eta_llm/eta_llm_http"
  10. "fmt"
  11. "io"
  12. "net/http"
  13. "strings"
  14. "sync"
  15. )
  16. var (
  17. dsOnce sync.Once
  18. etaLlmClient *ETALLMClient
  19. )
  // Fixed values placed in knowledge-base chat requests.
// NOTE: the misspellings (KNOWLEDEG, DEFALUT) are kept because these names
// are exported and may be referenced elsewhere.
const (
	// KNOWLEDEG_CHAT_MODE is the Mode value of a knowledge-base chat request.
	KNOWLEDEG_CHAT_MODE = "local_kb"
	// DEFALUT_PROMPT_NAME is the PromptName value of a knowledge-base chat request.
	DEFALUT_PROMPT_NAME = "default"
)

// CONTENT_TYPE_JSON is the Content-Type used for the POST requests issued by
// DoPost and DoStreamPost.
const CONTENT_TYPE_JSON = "application/json"

// Endpoint paths, appended to the client's BaseURL.
const (
	// KNOWLEDGE_BASE_CHAT_API is the streaming knowledge-base chat endpoint.
	KNOWLEDGE_BASE_CHAT_API = "/chat/kb_chat"
	// KNOWLEDGE_BASE_SEARCH_DOCS_API is the knowledge-base document search endpoint.
	KNOWLEDGE_BASE_SEARCH_DOCS_API = "/knowledge_base/search_docs"
)
  27. type ETALLMClient struct {
  28. *llm.LLMClient
  29. LlmModel string
  30. }
  31. func GetInstance() llm.LLMService {
  32. dsOnce.Do(func() {
  33. if etaLlmClient == nil {
  34. etaLlmClient = &ETALLMClient{
  35. LLMClient: llm.NewLLMClient(utils.LLM_SERVER, 120),
  36. LlmModel: utils.LLM_MODEL,
  37. }
  38. }
  39. })
  40. return etaLlmClient
  41. }
  // KnowledgeBaseChat posts query to the knowledge-base chat endpoint
// (KNOWLEDGE_BASE_CHAT_API) for the knowledge base KnowledgeBaseName and
// returns the raw streaming HTTP response. The caller must consume and close
// the response body.
//
// Each history entry must have the form "role-content". Only the first "-"
// separates the role from the content, so hyphens inside the content are
// preserved. An entry containing no "-" is rejected with an error.
func (ds *ETALLMClient) KnowledgeBaseChat(query string, KnowledgeBaseName string, history []string) (llmRes *http.Response, err error) {
	chatHistory := make([]eta_llm_http.HistoryContent, 0, len(history))
	for _, item := range history {
		// SplitN keeps everything after the first "-" as content; a plain
		// Split would drop it and index out of range when "-" is absent.
		parts := strings.SplitN(item, "-", 2)
		if len(parts) < 2 {
			return nil, fmt.Errorf("聊天历史格式错误, 期望 role-content 格式: %q", item)
		}
		chatHistory = append(chatHistory, eta_llm_http.HistoryContent{
			Role:    parts[0],
			Content: parts[1],
		})
	}

	kbReq := eta_llm_http.KbChatRequest{
		Query:          query,
		Mode:           KNOWLEDEG_CHAT_MODE,
		KbName:         KnowledgeBaseName,
		History:        chatHistory,
		TopK:           3,
		ScoreThreshold: 0.5,
		Stream:         true,
		Model:          ds.LlmModel,
		Temperature:    0.7,
		MaxTokens:      0,
		PromptName:     DEFALUT_PROMPT_NAME,
		ReturnDirect:   false,
	}
	body, err := json.Marshal(kbReq)
	if err != nil {
		return nil, err
	}
	return ds.DoStreamPost(KNOWLEDGE_BASE_CHAT_API, body)
}
  // SearchKbDocs searches the knowledge base KnowledgeBaseName for documents
// matching query (TopK 10, ScoreThreshold 0.5, empty metadata filter).
//
// On success content holds a []eta_llm_http.SearchDocsResponse, returned as
// interface{}. Errors:
//   - a marshal or transport/read error is returned unchanged;
//   - a non-200 reply yields an error carrying the status message;
//   - an empty or undecodable body yields "搜索知识库失败".
func (ds *ETALLMClient) SearchKbDocs(query string, KnowledgeBaseName string) (content interface{}, err error) {
	kbReq := eta_llm_http.KbSearchDocsRequest{
		Query:             query,
		KnowledgeBaseName: KnowledgeBaseName,
		TopK:              10,
		ScoreThreshold:    0.5,
		Metadata:          struct{}{},
	}
	body, err := json.Marshal(kbReq)
	if err != nil {
		return nil, err
	}

	resp, err := ds.DoPost(KNOWLEDGE_BASE_SEARCH_DOCS_API, body)
	if err != nil {
		// Check the transport error first: on failure resp is the zero value,
		// and resp.Msg would be an empty, uninformative message.
		return nil, err
	}
	if !resp.Success {
		return nil, errors.New(resp.Msg)
	}
	if resp.Data == nil {
		return nil, errors.New("搜索知识库失败")
	}

	var kbSearchRes []eta_llm_http.SearchDocsResponse
	if err = json.Unmarshal(resp.Data, &kbSearchRes); err != nil {
		return nil, errors.New("搜索知识库失败")
	}
	return kbSearchRes, nil
}
  103. func init() {
  104. err := llm.Register(llm.ETA_LLM_CLIENT, GetInstance())
  105. if err != nil {
  106. utils.FileLog.Error("注册eta_llm_server服务失败:", err)
  107. }
  108. }
  // DoPost sends body as a JSON POST to ds.BaseURL+apiUrl and converts the
// reply with parseResponse, which reads and closes the response body.
// On a transport error it returns a zero BaseResponse together with that error.
func (ds *ETALLMClient) DoPost(apiUrl string, body []byte) (baseResp eta_llm_http.BaseResponse, err error) {
	endpoint := ds.BaseURL + apiUrl
	payload := bytes.NewReader(body)
	httpResp, postErr := ds.HttpClient.Post(endpoint, CONTENT_TYPE_JSON, payload)
	if postErr != nil {
		return baseResp, postErr
	}
	return parseResponse(httpResp)
}
  // DoStreamPost sends body as a JSON POST to ds.BaseURL+apiUrl and returns the
// *http.Response unread, so a streamed body can be consumed incrementally.
// The status code is not inspected here, and the caller must close the body.
func (ds *ETALLMClient) DoStreamPost(apiUrl string, body []byte) (baseResp *http.Response, err error) {
	endpoint := ds.BaseURL + apiUrl
	payload := bytes.NewReader(body)
	baseResp, err = ds.HttpClient.Post(endpoint, CONTENT_TYPE_JSON, payload)
	return baseResp, err
}
  121. func parseResponse(response *http.Response) (baseResp eta_llm_http.BaseResponse, err error) {
  122. defer func() {
  123. _ = response.Body.Close()
  124. }()
  125. baseResp.Ret = response.StatusCode
  126. if response.StatusCode != http.StatusOK {
  127. baseResp.Msg = fmt.Sprintf("请求失败,状态码:%d, 状态信息:%s", response.StatusCode, http.StatusText(response.StatusCode))
  128. return
  129. }
  130. bodyBytes, err := io.ReadAll(response.Body)
  131. if err != nil {
  132. err = fmt.Errorf("读取响应体失败: %w", err)
  133. return
  134. }
  135. baseResp.Success = true
  136. baseResp.Data = bodyBytes
  137. return
  138. }
  // ParseStreamResponse consumes a streamed (Server-Sent-Events style) chat
// response in a background goroutine and fans its contents out over channels.
//
// Each data line is decoded as a JSON eta_llm_http.ChunkResponse; a leading
// "data:" field name (with or without a following space) is stripped first.
// The goroutine skips the following lines:
//   - blank lines;
//   - SSE comment lines starting with ":" (such as ": ping" keep-alives);
//   - the SSE fields "event:", "id:" and "retry:".
//
// A literal "[DONE]" payload, if the server sends one, ends the stream normally.
//
// Returned channels:
//   - contentChan receives every non-empty choice.Delta.Content, in order.
//   - errChan receives at most one error (undecodable JSON or a read failure);
//     the goroutine stops right after sending it.
//   - closeChan is closed when the goroutine finishes.
//
// When the goroutine exits it closes response.Body, then closeChan, errChan
// and contentChan, in that order.
//
// NOTE(review): closeChan is closed while values may still be buffered in
// contentChan/errChan, so a consumer that stops on closeChan can miss trailing
// output. There is no cancellation hook either: if the consumer stops reading
// contentChan once its buffer (10) is full, the goroutine blocks forever.
func ParseStreamResponse(response *http.Response) (contentChan chan string, errChan chan error, closeChan chan struct{}) {
	// A single event line can exceed bufio.Scanner's default 64 KiB token
	// limit, which would otherwise abort the stream with bufio.ErrTooLong.
	const (
		initialLineBuf = 64 * 1024
		maxLineSize    = 4 * 1024 * 1024
	)

	contentChan = make(chan string, 10)
	errChan = make(chan error, 10)
	closeChan = make(chan struct{})
	go func() {
		defer close(contentChan)
		defer close(errChan)
		defer close(closeChan)
		// Release the connection once parsing ends. Closing an already-closed
		// response body is harmless, so a caller that also closes it is fine.
		defer func() {
			_ = response.Body.Close()
		}()

		scanner := bufio.NewScanner(response.Body)
		scanner.Buffer(make([]byte, 0, initialLineBuf), maxLineSize)
		scanner.Split(bufio.ScanLines)
		for scanner.Scan() {
			line := scanner.Text()
			if line == "" {
				continue
			}
			// SSE comment lines (e.g. ": ping" keep-alives) carry no data.
			if strings.HasPrefix(line, ":") {
				continue
			}
			// Non-data SSE fields carry no chunk payload.
			if strings.HasPrefix(line, "event:") || strings.HasPrefix(line, "id:") || strings.HasPrefix(line, "retry:") {
				continue
			}
			// Strip the "data:" field name and the single optional space after it.
			if strings.HasPrefix(line, "data:") {
				line = strings.TrimPrefix(strings.TrimPrefix(line, "data:"), " ")
			}
			// OpenAI-compatible servers end the stream with a non-JSON
			// "[DONE]" sentinel; treat it as a normal end of stream.
			if line == "[DONE]" {
				return
			}

			var chunk eta_llm_http.ChunkResponse
			if err := json.Unmarshal([]byte(line), &chunk); err != nil {
				fmt.Println("解析错误的line:" + line)
				errChan <- fmt.Errorf("解析 JSON 块失败: %w", err)
				return
			}
			for _, choice := range chunk.Choices {
				if choice.Delta.Content != "" {
					contentChan <- choice.Delta.Content
				}
			}
		}
		if err := scanner.Err(); err != nil {
			errChan <- fmt.Errorf("读取响应体失败: %w", err)
		}
	}()
	return contentChan, errChan, closeChan
}