eta_llm_client.go 8.9 KB


  1. package eta_llm
  2. import (
  3. "bufio"
  4. "bytes"
  5. "encoding/json"
  6. "errors"
  7. "eta/eta_api/models"
  8. "eta/eta_api/utils"
  9. "eta/eta_api/utils/llm"
  10. "eta/eta_api/utils/llm/eta_llm/eta_llm_http"
  11. "fmt"
  12. "io"
  13. "net/http"
  14. "os"
  15. "strings"
  16. "sync"
  17. )
  18. var (
  19. dsOnce sync.Once
  20. etaLlmClient *ETALLMClient
  21. )
  // Request constants for the ETA LLM HTTP service.
const (
	// KNOWLEDEG_CHAT_MODE is the "mode" value sent with knowledge-base chat
	// requests. (Identifier spelling kept for existing callers.)
	KNOWLEDEG_CHAT_MODE = "local_kb"
	// DEFALUT_PROMPT_NAME is the prompt template name sent with chat requests.
	// (Identifier spelling kept for existing callers.)
	DEFALUT_PROMPT_NAME = "default"
	// CONTENT_TYPE_JSON is the Content-Type used for every POST body.
	CONTENT_TYPE_JSON = "application/json"

	// API paths; each is appended to the client's BaseURL when posting.
	KNOWLEDGE_BASE_CHAT_API        = "/chat/kb_chat"
	DOCUMENT_CHAT_API              = "/chat/file_chat"
	KNOWLEDGE_BASE_SEARCH_DOCS_API = "/knowledge_base/search_docs"
	UPLOAD_TEMP_DOCS_API           = "/knowledge_base/upload_temp_docs"
)
  31. type ETALLMClient struct {
  32. *llm.LLMClient
  33. LlmModel string
  34. }
  // LLMConfig is the JSON document stored as the LLM init configuration and
// decoded by GetInstance.
type LLMConfig struct {
	// LlmAddress (JSON "llm_server") is the service address handed to
	// llm.NewLLMClient.
	LlmAddress string `json:"llm_server"`
	// LlmModel (JSON "llm_model") is the model name used for chat requests.
	LlmModel string `json:"llm_model"`
}
  39. func GetInstance() llm.LLMService {
  40. dsOnce.Do(func() {
  41. confStr := models.BusinessConfMap[models.LLMInitConfig]
  42. if confStr == "" {
  43. utils.FileLog.Error("LLM配置为空")
  44. return
  45. }
  46. var config LLMConfig
  47. err := json.Unmarshal([]byte(confStr), &config)
  48. if err != nil {
  49. utils.FileLog.Error("LLM配置错误")
  50. }
  51. if etaLlmClient == nil {
  52. etaLlmClient = &ETALLMClient{
  53. LLMClient: llm.NewLLMClient(config.LlmAddress, 120),
  54. LlmModel: config.LlmModel,
  55. }
  56. }
  57. })
  58. return etaLlmClient
  59. }
  60. func (ds *ETALLMClient) DocumentChat(query string, KnowledgeId string, history []json.RawMessage, stream bool) (llmRes *http.Response, err error) {
  61. ChatHistory := make([]eta_llm_http.HistoryContent, 0)
  62. for _, historyItemStr := range history {
  63. var historyItem eta_llm_http.HistoryContent
  64. parseErr := json.Unmarshal(historyItemStr, &historyItem)
  65. if parseErr != nil {
  66. continue
  67. }
  68. //str := strings.Split(historyItemStr, "-")
  69. //historyItem := eta_llm_http.HistoryContent{
  70. // Role: str[0],
  71. // Content: str[1],
  72. //}
  73. ChatHistory = append(ChatHistory, historyItem)
  74. }
  75. kbReq := eta_llm_http.DocumentChatRequest{
  76. Query: query,
  77. KnowledgeId: KnowledgeId,
  78. History: ChatHistory,
  79. TopK: 3,
  80. //ScoreThreshold: 0.5,
  81. ScoreThreshold: 2,
  82. Stream: stream,
  83. ModelName: ds.LlmModel,
  84. //Temperature: 0.7,
  85. Temperature: 0.01,
  86. MaxTokens: 0,
  87. //PromptName: DEFALUT_PROMPT_NAME,
  88. }
  89. //fmt.Printf("%v", kbReq.History)
  90. body, err := json.Marshal(kbReq)
  91. fmt.Println(string(body))
  92. if err != nil {
  93. return
  94. }
  95. return ds.DoStreamPost(DOCUMENT_CHAT_API, body)
  96. }
  97. func (ds *ETALLMClient) KnowledgeBaseChat(query string, KnowledgeBaseName string, history []json.RawMessage) (llmRes *http.Response, err error) {
  98. ChatHistory := make([]eta_llm_http.HistoryContent, 0)
  99. for _, historyItemStr := range history {
  100. var historyItem eta_llm_http.HistoryContentWeb
  101. parseErr := json.Unmarshal(historyItemStr, &historyItem)
  102. if parseErr != nil {
  103. continue
  104. }
  105. ChatHistory = append(ChatHistory, eta_llm_http.HistoryContent{
  106. Content: historyItem.Content,
  107. Role: historyItem.Role,
  108. })
  109. }
  110. kbReq := eta_llm_http.KbChatRequest{
  111. Query: query,
  112. Mode: KNOWLEDEG_CHAT_MODE,
  113. KbName: KnowledgeBaseName,
  114. History: ChatHistory,
  115. TopK: 3,
  116. ScoreThreshold: 0.5,
  117. Stream: true,
  118. Model: ds.LlmModel,
  119. Temperature: 0.7,
  120. MaxTokens: 0,
  121. PromptName: DEFALUT_PROMPT_NAME,
  122. ReturnDirect: false,
  123. }
  124. fmt.Printf("%v", kbReq.History)
  125. body, err := json.Marshal(kbReq)
  126. if err != nil {
  127. return
  128. }
  129. return ds.DoStreamPost(KNOWLEDGE_BASE_CHAT_API, body)
  130. }
  131. func (ds *ETALLMClient) FileChat(query string, KnowledgeId string, history []json.RawMessage) (resp eta_llm_http.BaseResponse, err error) {
  132. ChatHistory := make([]eta_llm_http.HistoryContent, 0)
  133. for _, historyItemStr := range history {
  134. var historyItem eta_llm_http.HistoryContentWeb
  135. parseErr := json.Unmarshal(historyItemStr, &historyItem)
  136. if parseErr != nil {
  137. continue
  138. }
  139. ChatHistory = append(ChatHistory, eta_llm_http.HistoryContent{
  140. Content: historyItem.Content,
  141. Role: historyItem.Role,
  142. })
  143. }
  144. kbReq := eta_llm_http.DocumentChatRequest{
  145. ModelName: ds.LlmModel,
  146. Query: query,
  147. KnowledgeId: KnowledgeId,
  148. History: ChatHistory,
  149. TopK: 3,
  150. ScoreThreshold: 0.5,
  151. Stream: false,
  152. Temperature: 0.7,
  153. MaxTokens: 0,
  154. PromptName: DEFALUT_PROMPT_NAME,
  155. }
  156. body, err := json.Marshal(kbReq)
  157. if err != nil {
  158. utils.FileLog.Error("内容生成失败,序列化请求参数失败,err", err.Error())
  159. err = fmt.Errorf("内容生成失败,序列化请求参数失败,err:%v", err)
  160. return
  161. }
  162. return ds.DoPost(DOCUMENT_CHAT_API, body)
  163. }
  164. func (ds *ETALLMClient) UploadFileToTemplate(files []*os.File, param map[string]interface{}) (data interface{}, err error) {
  165. pervId := param["PrevId"].(string)
  166. docReq := eta_llm_http.UploadTempDocsRequest{
  167. ChunkOverlap: 750,
  168. ChunkSize: 150,
  169. Files: files,
  170. PrevId: pervId,
  171. ZhTitleEnhance: false,
  172. }
  173. body, err := json.Marshal(docReq)
  174. if err != nil {
  175. return
  176. }
  177. resp, err := ds.DoPost(UPLOAD_TEMP_DOCS_API, body)
  178. if !resp.Success {
  179. err = errors.New(resp.Msg)
  180. return
  181. }
  182. if resp.Data != nil {
  183. var uploadDocsRes eta_llm_http.RagBaseResponse
  184. err = json.Unmarshal(resp.Data, &uploadDocsRes)
  185. if err != nil {
  186. err = errors.New("上传临时文件失败,err:" + err.Error())
  187. return
  188. }
  189. if uploadDocsRes.Code != 200 {
  190. err = errors.New("上传临时文件失败,err:" + uploadDocsRes.Msg)
  191. return
  192. }
  193. var uploadResult eta_llm_http.UploadDocsResponse
  194. err = json.Unmarshal(uploadDocsRes.Data, &uploadResult)
  195. if len(uploadResult.FiledFiles) > 0 {
  196. utils.FileLog.Warn("上传临时文件失败:", uploadResult.FiledFiles)
  197. }
  198. data = uploadResult
  199. return
  200. }
  201. return
  202. }
  203. func (ds *ETALLMClient) SearchKbDocs(query string, KnowledgeBaseName string) (content interface{}, err error) {
  204. kbReq := eta_llm_http.KbSearchDocsRequest{
  205. Query: query,
  206. KnowledgeBaseName: KnowledgeBaseName,
  207. TopK: 10,
  208. ScoreThreshold: 0.5,
  209. Metadata: struct{}{},
  210. }
  211. body, err := json.Marshal(kbReq)
  212. if err != nil {
  213. return
  214. }
  215. resp, err := ds.DoPost(KNOWLEDGE_BASE_SEARCH_DOCS_API, body)
  216. if !resp.Success {
  217. err = errors.New(resp.Msg)
  218. return
  219. }
  220. if resp.Data != nil {
  221. var kbSearchRes []eta_llm_http.SearchDocsResponse
  222. err = json.Unmarshal(resp.Data, &kbSearchRes)
  223. if err != nil {
  224. err = errors.New("搜索知识库失败")
  225. return
  226. }
  227. content = kbSearchRes
  228. return
  229. }
  230. err = errors.New("搜索知识库失败")
  231. return
  232. }
  233. func init() {
  234. err := llm.Register(llm.ETA_LLM_CLIENT, GetInstance())
  235. if err != nil {
  236. utils.FileLog.Error("注册eta_llm_server服务失败:", err)
  237. }
  238. }
  239. func (ds *ETALLMClient) DoPost(apiUrl string, body []byte) (baseResp eta_llm_http.BaseResponse, err error) {
  240. requestReader := bytes.NewReader(body)
  241. response, err := ds.HttpClient.Post(ds.BaseURL+apiUrl, CONTENT_TYPE_JSON, requestReader)
  242. if err != nil {
  243. return
  244. }
  245. return parseResponse(response)
  246. }
  247. func (ds *ETALLMClient) DoStreamPost(apiUrl string, body []byte) (baseResp *http.Response, err error) {
  248. requestReader := bytes.NewReader(body)
  249. return ds.HttpClient.Post(ds.BaseURL+apiUrl, CONTENT_TYPE_JSON, requestReader)
  250. }
  251. func parseResponse(response *http.Response) (baseResp eta_llm_http.BaseResponse, err error) {
  252. defer func() {
  253. _ = response.Body.Close()
  254. }()
  255. baseResp.Ret = response.StatusCode
  256. if response.StatusCode != http.StatusOK {
  257. baseResp.Msg = fmt.Sprintf("请求失败,状态码:%d, 状态信息:%s", response.StatusCode, http.StatusText(response.StatusCode))
  258. return
  259. }
  260. bodyBytes, err := io.ReadAll(response.Body)
  261. if err != nil {
  262. err = fmt.Errorf("读取响应体失败: %w", err)
  263. return
  264. }
  265. baseResp.Success = true
  266. baseResp.Data = bodyBytes
  267. return
  268. }
  269. func ParseStreamResponse(response *http.Response) (contentChan chan string, errChan chan error, closeChan chan struct{}) {
  270. contentChan = make(chan string, 10)
  271. errChan = make(chan error, 10)
  272. closeChan = make(chan struct{})
  273. go func() {
  274. defer close(contentChan)
  275. defer close(errChan)
  276. defer close(closeChan)
  277. scanner := bufio.NewScanner(response.Body)
  278. scanner.Split(bufio.ScanLines)
  279. for scanner.Scan() {
  280. line := scanner.Text()
  281. if line == "" {
  282. continue
  283. }
  284. // 忽略 "ping" 行
  285. if strings.HasPrefix(line, ": ping") {
  286. continue
  287. }
  288. // 去除 "data: " 前缀
  289. if strings.HasPrefix(line, "data: ") {
  290. line = strings.TrimPrefix(line, "data: ")
  291. }
  292. var chunk eta_llm_http.ChunkResponse
  293. if err := json.Unmarshal([]byte(line), &chunk); err != nil {
  294. fmt.Println("解析错误的line:" + line)
  295. errChan <- fmt.Errorf("解析 JSON 块失败: %w", err)
  296. return
  297. }
  298. // 处理每个 chunk
  299. if chunk.Choices != nil && len(chunk.Choices) > 0 {
  300. for _, choice := range chunk.Choices {
  301. if choice.Delta.Content != "" {
  302. contentChan <- choice.Delta.Content
  303. }
  304. }
  305. }
  306. }
  307. if err := scanner.Err(); err != nil {
  308. errChan <- fmt.Errorf("读取响应体失败: %w", err)
  309. return
  310. }
  311. }()
  312. return
  313. }