eta_llm_client.go 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
  1. package eta_llm
  2. import (
  3. "bytes"
  4. "encoding/json"
  5. "errors"
  6. "eta/eta_api/utils"
  7. "eta/eta_api/utils/llm"
  8. "eta/eta_api/utils/llm/eta_llm/eta_llm_http"
  9. "fmt"
  10. "io"
  11. "net/http"
  12. "sync"
  13. )
  14. var (
  15. dsOnce sync.Once
  16. etaLlmClient *ETALLMClient
  17. )
  18. const (
  19. CONTENT_TYPE_JSON = "application/json"
  20. KNOWLEDGE_BASE_CHAT_API = "/chat/kb_chat"
  21. KNOWLEDGE_BASE_SEARCH_DOCS_API = "/knowledge_base/search_docs"
  22. )
  23. type ETALLMClient struct {
  24. *llm.LLMClient
  25. LlmModel string
  26. }
  27. func GetInstance() llm.LLMService {
  28. dsOnce.Do(func() {
  29. if etaLlmClient == nil {
  30. etaLlmClient = &ETALLMClient{
  31. LLMClient: llm.NewLLMClient(utils.LLM_SERVER, 10),
  32. LlmModel: utils.LLM_MODEL,
  33. }
  34. }
  35. })
  36. return etaLlmClient
  37. }
  38. func (ds *ETALLMClient) KnowledgeBaseChat() string {
  39. ds.HttpClient.Post(ds.BaseURL+KNOWLEDGE_BASE_CHAT_API, CONTENT_TYPE_JSON, nil)
  40. return ""
  41. }
  42. func (ds *ETALLMClient) SearchKbDocs(query string, KnowledgeBaseName string) (content string, err error) {
  43. // 类型断言
  44. kbReq := eta_llm_http.KbSearchDocsRequest{
  45. Query: query,
  46. KnowledgeBaseName: KnowledgeBaseName,
  47. Model: ds.LlmModel,
  48. TopK: 3,
  49. ScoreThreshold: 2,
  50. }
  51. body, err := json.Marshal(kbReq)
  52. if err != nil {
  53. return
  54. }
  55. resp, err := ds.DoPost(KNOWLEDGE_BASE_SEARCH_DOCS_API, body)
  56. if !resp.Success {
  57. err = errors.New(resp.Msg)
  58. return
  59. }
  60. return "", nil
  61. }
  62. func init() {
  63. err := llm.Register(llm.ETA_LLM_CLIENT, GetInstance())
  64. if err != nil {
  65. utils.FileLog.Error("注册eta_llm_server服务失败:", err)
  66. }
  67. }
  68. func (ds *ETALLMClient) DoPost(apiUrl string, body []byte) (baseResp eta_llm_http.BaseResponse, err error) {
  69. requestReader := bytes.NewReader(body)
  70. response, err := ds.HttpClient.Post(ds.BaseURL+apiUrl, CONTENT_TYPE_JSON, requestReader)
  71. if err != nil {
  72. return
  73. }
  74. return parseResponse(response)
  75. }
  76. func parseResponse(response *http.Response) (baseResp eta_llm_http.BaseResponse, err error) {
  77. defer func() {
  78. _ = response.Body.Close()
  79. }()
  80. baseResp.Ret = response.StatusCode
  81. if response.StatusCode != http.StatusOK {
  82. baseResp.Msg = fmt.Sprintf("请求失败,状态码:%d, 状态信息:%s", response.StatusCode, http.StatusText(response.StatusCode))
  83. return
  84. }
  85. bodyBytes, err := io.ReadAll(response.Body)
  86. if err != nil {
  87. err = fmt.Errorf("读取响应体失败: %w", err)
  88. return
  89. }
  90. baseResp.Data = bodyBytes
  91. return
  92. }