llm_service.go 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176
  1. package facade
  2. import (
  3. "encoding/json"
  4. "errors"
  5. "eta/eta_api/models/rag"
  6. localService "eta/eta_api/services/llm"
  7. "eta/eta_api/services/llm/facade/bus_response"
  8. "eta/eta_api/utils"
  9. "eta/eta_api/utils/llm"
  10. "eta/eta_api/utils/llm/eta_llm/eta_llm_http"
  11. "eta/eta_api/utils/ws"
  12. "fmt"
  13. "github.com/gorilla/websocket"
  14. "github.com/rdlucklib/rdluck_tools/uuid"
  15. "gorm.io/gorm"
  16. "os"
  17. )
  18. var (
  19. llmService, _ = llm.GetInstance(llm.ETA_LLM_CLIENT)
  20. )
  21. func generateSessionCode() (code string) {
  22. return fmt.Sprintf("%s%s", "llm_session_", uuid.NewUUID().Hex32())
  23. }
  24. // AddSession 创建会话session
  25. func AddSession(userId int, conn *websocket.Conn) {
  26. sessionId := generateSessionCode()
  27. session := ws.NewSession(userId, sessionId, conn)
  28. ws.Manager().AddSession(session)
  29. }
  30. // LLMKnowledgeBaseSearchDocs 搜索知识库
  31. func LLMKnowledgeBaseSearchDocs(search LLMKnowledgeSearch) (resp bus_response.SearchDocsEtaResponse, err error) {
  32. docs, err := llmService.SearchKbDocs(search.Query, search.KnowledgeBaseName)
  33. if err != nil {
  34. return
  35. }
  36. for _, doc := range docs.([]eta_llm_http.SearchDocsResponse) {
  37. resp.Content = resp.Content + doc.PageContent
  38. }
  39. resp.Docs = docs.([]eta_llm_http.SearchDocsResponse)
  40. return
  41. }
  42. // AIGCBaseOnPromote aigc 生成内容
  43. func AIGCBaseOnPromote(aigc AIGC) (resp bus_response.AIGCEtaResponse, err error) {
  44. mapping, queryErr := rag.GetArticleKbMapping(aigc.ArticleId)
  45. if queryErr != nil && !errors.Is(queryErr, gorm.ErrRecordNotFound) {
  46. utils.FileLog.Error("获取文章知识库信息失败,err: %v", queryErr)
  47. err = fmt.Errorf("获取文章知识库信息失败,err: %v", queryErr)
  48. return
  49. } else {
  50. var kbId string
  51. var file *os.File
  52. var params map[string]interface{}
  53. if mapping.Id == 0 || mapping.KbId == "" {
  54. article, fileErr := rag.GetArticleById(aigc.ArticleId)
  55. if fileErr != nil {
  56. // 找不到就处理失败
  57. utils.FileLog.Error("公众号文章不存在")
  58. err = fmt.Errorf("公众号文章不存在")
  59. return
  60. }
  61. // 文章加入到知识库
  62. path, fileErr := localService.CreateArticleFile(article)
  63. if fileErr != nil {
  64. utils.FileLog.Error("创建文章文件失败,err: %v", fileErr)
  65. err = fmt.Errorf("创建文章文件失败,err: %v", fileErr)
  66. return
  67. }
  68. defer func() {
  69. _ = os.Remove(path)
  70. }()
  71. file, err = os.Open(path)
  72. if err != nil {
  73. utils.FileLog.Error("打开文件失败,err:", err)
  74. return
  75. }
  76. uploadResp, httpErr := llmService.UploadFileToTemplate([]*os.File{file}, params)
  77. if httpErr != nil {
  78. utils.FileLog.Error("上传文件失败,err:", err.Error())
  79. err = fmt.Errorf("上传文件失败,err:%v", httpErr)
  80. return
  81. }
  82. data := uploadResp.(eta_llm_http.UploadDocsResponse)
  83. //保存映射关系到数据库
  84. if data.Id == "" {
  85. utils.FileLog.Error("上传文件失败,向量库Id获取失败")
  86. err = fmt.Errorf("上传文件失败,向量库Id获取失败")
  87. return
  88. }
  89. err = rag.CreateArticleKbMapping(rag.ArticleKbMapping{
  90. WechatArticleId: aigc.ArticleId,
  91. KbId: data.Id,
  92. })
  93. if err != nil {
  94. utils.FileLog.Warn("创建文章知识库映射关系失败,err:", err.Error())
  95. }
  96. kbId = data.Id
  97. } else {
  98. kbId = mapping.KbId
  99. }
  100. //知识库对话
  101. response, httpErr := llmService.FileChat(aigc.Promote, kbId, nil)
  102. if httpErr != nil {
  103. utils.FileLog.Error("内容生成失败,err:", err.Error())
  104. err = fmt.Errorf("内容生成失败,err:%v", httpErr)
  105. return
  106. }
  107. if !response.Success {
  108. utils.FileLog.Error("内容生成失败,code:%v,msg:%v", response.Ret, response.Msg)
  109. err = fmt.Errorf("内容生成失败,code:%v,msg:%v", response.Ret, response.Msg)
  110. return
  111. } else {
  112. var baseResp eta_llm_http.RagBaseResponse
  113. parseErr := json.Unmarshal(response.Data, &baseResp)
  114. if parseErr != nil {
  115. utils.FileLog.Error("内容生成失败,code:%v,msg:%v", parseErr)
  116. err = fmt.Errorf("内容生成失败,err:%v", parseErr)
  117. return
  118. }
  119. if baseResp.Code != 200 {
  120. if baseResp.Code == 404 {
  121. params["PrevId"] = kbId
  122. _, putErr := llmService.UploadFileToTemplate([]*os.File{file}, params)
  123. if putErr != nil {
  124. utils.FileLog.Error("内容生成失败,err:", err.Error())
  125. err = fmt.Errorf("内容生成失败,err:%v", httpErr)
  126. return
  127. }
  128. } else {
  129. utils.FileLog.Error("内容生成失败,code:%v,msg:%v", baseResp.Code, baseResp.Msg)
  130. err = fmt.Errorf("内容生成失败,code:%v,msg:%v", baseResp.Code, baseResp.Msg)
  131. return
  132. }
  133. }
  134. gcResp, gcErr := llmService.FileChat(aigc.Promote, kbId, nil)
  135. if gcErr != nil {
  136. utils.FileLog.Error("内容生成失败,err:%v", gcErr.Error())
  137. err = fmt.Errorf("内容生成失败,err:%v", gcErr)
  138. return
  139. }
  140. if !gcResp.Success {
  141. utils.FileLog.Error("内容生成失败,code:%v,msg:%v", gcResp.Ret, gcResp.Msg)
  142. err = fmt.Errorf("内容生成失败,err:%v", gcResp.Msg)
  143. }
  144. var steamResp eta_llm_http.ContentResponse
  145. parseErr = json.Unmarshal(gcResp.Data, &steamResp)
  146. if parseErr != nil {
  147. utils.FileLog.Error("内容生成失败,code:%v,msg:%v", parseErr)
  148. err = fmt.Errorf("内容生成失败,err:%v", parseErr)
  149. return
  150. }
  151. parseErr = json.Unmarshal(steamResp.Data, &resp)
  152. if parseErr != nil {
  153. utils.FileLog.Error("内容生成失败,code:%v,msg:%v", parseErr)
  154. err = fmt.Errorf("内容生成失败,err:%v", parseErr)
  155. return
  156. }
  157. }
  158. }
  159. return
  160. }
  // LLMKnowledgeSearch is the request for a knowledge-base document search.
  type LLMKnowledgeSearch struct {
  	// Query is the search text passed to SearchKbDocs.
  	Query string `json:"Query"`
  	// KnowledgeBaseName names the knowledge base passed to SearchKbDocs.
  	KnowledgeBaseName string `json:"KnowledgeBaseName"`
  }
  // AIGC is the input of AIGCBaseOnPromote.
  type AIGC struct {
  	// Promote is the text sent to FileChat as its first argument
  	// (presumably the prompt; the field name looks like a misspelling — confirm).
  	Promote string
  	// ArticleId identifies the article whose knowledge-base mapping is looked up.
  	ArticleId int
  }