eta_llm_client.go 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117
  1. package eta_llm
  2. import (
  3. "bytes"
  4. "encoding/json"
  5. "errors"
  6. "eta/eta_api/services/llm/facade/bus_response"
  7. "eta/eta_api/utils"
  8. "eta/eta_api/utils/llm"
  9. "eta/eta_api/utils/llm/eta_llm/eta_llm_http"
  10. "fmt"
  11. "io"
  12. "net/http"
  13. "sync"
  14. )
  15. var (
  16. dsOnce sync.Once
  17. etaLlmClient *ETALLMClient
  18. )
  // CONTENT_TYPE_JSON is the Content-Type used for the POST requests in this file.
const CONTENT_TYPE_JSON = "application/json"

// Endpoint paths, appended to the client's BaseURL.
const (
	// KNOWLEDGE_BASE_CHAT_API is the knowledge-base chat endpoint.
	KNOWLEDGE_BASE_CHAT_API = "/chat/kb_chat"
	// KNOWLEDGE_BASE_SEARCH_DOCS_API is the knowledge-base document search endpoint.
	KNOWLEDGE_BASE_SEARCH_DOCS_API = "/knowledge_base/search_docs"
)
  // ETALLMClient is the client for the ETA LLM server. GetInstance returns it
// as an llm.LLMService.
type ETALLMClient struct {
	// Embedded base client; the methods in this file use its HttpClient and BaseURL.
	*llm.LLMClient
	// LlmModel is populated from utils.LLM_MODEL in GetInstance. No method
	// visible in this file reads it.
	LlmModel string
}
  28. func GetInstance() llm.LLMService {
  29. dsOnce.Do(func() {
  30. if etaLlmClient == nil {
  31. etaLlmClient = &ETALLMClient{
  32. LLMClient: llm.NewLLMClient(utils.LLM_SERVER, 10),
  33. LlmModel: utils.LLM_MODEL,
  34. }
  35. }
  36. })
  37. return etaLlmClient
  38. }
  39. func (ds *ETALLMClient) KnowledgeBaseChat() string {
  40. ds.HttpClient.Post(ds.BaseURL+KNOWLEDGE_BASE_CHAT_API, CONTENT_TYPE_JSON, nil)
  41. return ""
  42. }
  43. func (ds *ETALLMClient) SearchKbDocs(query string, KnowledgeBaseName string) (content interface{}, err error) {
  44. // 类型断言
  45. kbReq := eta_llm_http.KbSearchDocsRequest{
  46. Query: query,
  47. KnowledgeBaseName: KnowledgeBaseName,
  48. TopK: 10,
  49. ScoreThreshold: 0.5,
  50. Metadata: struct{}{},
  51. }
  52. body, err := json.Marshal(kbReq)
  53. if err != nil {
  54. return
  55. }
  56. resp, err := ds.DoPost(KNOWLEDGE_BASE_SEARCH_DOCS_API, body)
  57. if !resp.Success {
  58. err = errors.New(resp.Msg)
  59. return
  60. }
  61. if resp.Data != nil {
  62. var kbSearchRes []bus_response.SearchDocsResponse
  63. err = json.Unmarshal(resp.Data, &kbSearchRes)
  64. if err != nil {
  65. err = errors.New("搜索知识库失败")
  66. return
  67. }
  68. content = kbSearchRes
  69. return
  70. }
  71. err = errors.New("搜索知识库失败")
  72. return
  73. }
  74. func init() {
  75. err := llm.Register(llm.ETA_LLM_CLIENT, GetInstance())
  76. if err != nil {
  77. utils.FileLog.Error("注册eta_llm_server服务失败:", err)
  78. }
  79. }
  80. func (ds *ETALLMClient) DoPost(apiUrl string, body []byte) (baseResp eta_llm_http.BaseResponse, err error) {
  81. requestReader := bytes.NewReader(body)
  82. response, err := ds.HttpClient.Post(ds.BaseURL+apiUrl, CONTENT_TYPE_JSON, requestReader)
  83. if err != nil {
  84. return
  85. }
  86. return parseResponse(response)
  87. }
  88. func parseResponse(response *http.Response) (baseResp eta_llm_http.BaseResponse, err error) {
  89. defer func() {
  90. _ = response.Body.Close()
  91. }()
  92. baseResp.Ret = response.StatusCode
  93. if response.StatusCode != http.StatusOK {
  94. baseResp.Msg = fmt.Sprintf("请求失败,状态码:%d, 状态信息:%s", response.StatusCode, http.StatusText(response.StatusCode))
  95. return
  96. }
  97. bodyBytes, err := io.ReadAll(response.Body)
  98. if err != nil {
  99. err = fmt.Errorf("读取响应体失败: %w", err)
  100. return
  101. }
  102. baseResp.Success = true
  103. baseResp.Data = bodyBytes
  104. return
  105. }