chat_controller.go 6.2 KB


  1. package rag
  2. import (
  3. "encoding/json"
  4. "eta/eta_api/controllers"
  5. "eta/eta_api/models"
  6. "eta/eta_api/models/llm"
  7. "eta/eta_api/models/system"
  8. "eta/eta_api/services/llm/facade"
  9. "eta/eta_api/utils"
  10. "eta/eta_api/utils/ws"
  11. "fmt"
  12. "github.com/gorilla/websocket"
  13. "net"
  14. "net/http"
  15. "strings"
  16. "time"
  17. )
// ChatController serves the RAG chat WebSocket endpoint (ChatConnect). It embeds
// controllers.BaseAuthController, from which the fields used in this file such as
// Ctx and SysUser are presumably promoted, and defines its own GET-only Prepare
// for authenticating requests.
type ChatController struct {
	controllers.BaseAuthController
}
  21. func (cc *ChatController) Prepare() {
  22. method := cc.Ctx.Input.Method()
  23. uri := cc.Ctx.Input.URI()
  24. if method == "GET" {
  25. authorization := cc.Ctx.Input.Header("authorization")
  26. if authorization == "" {
  27. authorization = cc.Ctx.Input.Header("Authorization")
  28. }
  29. if strings.Contains(authorization, ";") {
  30. authorization = strings.Replace(authorization, ";", "$", 1)
  31. }
  32. if authorization == "" {
  33. strArr := strings.Split(uri, "?")
  34. for k, v := range strArr {
  35. fmt.Println(k, v)
  36. }
  37. if len(strArr) > 1 {
  38. authorization = strArr[1]
  39. authorization = strings.Replace(authorization, "Authorization", "authorization", -1)
  40. }
  41. }
  42. if authorization == "" {
  43. utils.FileLog.Error("authorization为空,未授权")
  44. cc.Ctx.ResponseWriter.WriteHeader(http.StatusUnauthorized)
  45. return
  46. }
  47. tokenStr := authorization
  48. tokenArr := strings.Split(tokenStr, "=")
  49. token := tokenArr[1]
  50. session, err := system.GetSysSessionByToken(token)
  51. if err != nil {
  52. if utils.IsErrNoRow(err) {
  53. utils.FileLog.Error("authorization已过期")
  54. cc.Ctx.ResponseWriter.WriteHeader(http.StatusUnauthorized)
  55. return
  56. }
  57. utils.FileLog.Error("authorization查询用户信息失败")
  58. cc.Ctx.ResponseWriter.WriteHeader(http.StatusBadRequest)
  59. return
  60. }
  61. if session == nil {
  62. utils.FileLog.Error("会话不存在")
  63. cc.Ctx.ResponseWriter.WriteHeader(http.StatusBadRequest)
  64. return
  65. }
  66. //校验token是否合法
  67. // JWT校验Token和Account
  68. account := utils.MD5(session.UserName)
  69. if !utils.CheckToken(account, token) {
  70. utils.FileLog.Error("authorization校验不合法")
  71. cc.Ctx.ResponseWriter.WriteHeader(http.StatusUnauthorized)
  72. return
  73. }
  74. if time.Now().After(session.ExpiredTime) {
  75. utils.FileLog.Error("authorization过期法")
  76. cc.Ctx.ResponseWriter.WriteHeader(http.StatusUnauthorized)
  77. return
  78. }
  79. admin, err := system.GetSysUserById(session.SysUserId)
  80. if err != nil {
  81. if utils.IsErrNoRow(err) {
  82. utils.FileLog.Error("权限不够")
  83. cc.Ctx.ResponseWriter.WriteHeader(http.StatusForbidden)
  84. return
  85. }
  86. utils.FileLog.Error("获取用户信息失败")
  87. cc.Ctx.ResponseWriter.WriteHeader(http.StatusBadRequest)
  88. return
  89. }
  90. if admin == nil {
  91. utils.FileLog.Error("权限不够")
  92. cc.Ctx.ResponseWriter.WriteHeader(http.StatusForbidden)
  93. return
  94. }
  95. //如果不是启用状态
  96. if admin.Enabled != 1 {
  97. utils.FileLog.Error("用户被禁用")
  98. cc.Ctx.ResponseWriter.WriteHeader(http.StatusForbidden)
  99. return
  100. }
  101. //接口权限校验
  102. roleId := admin.RoleId
  103. list, e := system.GetMenuButtonApisByRoleId(roleId)
  104. if e != nil {
  105. utils.FileLog.Error("接口权限查询出错", e)
  106. cc.Ctx.ResponseWriter.WriteHeader(http.StatusForbidden)
  107. return
  108. }
  109. var api string
  110. for _, v := range list {
  111. if v.Api != "" {
  112. api += v.Api + "&"
  113. }
  114. }
  115. api += "&" + models.BusinessConfMap["PublicApi"]
  116. //处理uri请求,去除前缀和参数
  117. api = strings.TrimRight(api, "&")
  118. uri = strings.Replace(uri, "/adminapi", "", 1)
  119. uris := strings.Split(uri, "?")
  120. uri = uris[0]
  121. //fmt.Println("uri:", uri)
  122. apis := strings.Split(api, "&")
  123. apiMap := make(map[string]bool, 0)
  124. for _, s := range apis {
  125. apiMap[s] = true
  126. }
  127. if !apiMap[uri] {
  128. utils.FileLog.Error("用户无权访问")
  129. cc.Ctx.ResponseWriter.WriteHeader(http.StatusForbidden)
  130. return
  131. }
  132. cc.SysUser = admin
  133. } else {
  134. utils.FileLog.Error("请求方法类型错误")
  135. cc.Ctx.ResponseWriter.WriteHeader(http.StatusBadRequest)
  136. return
  137. }
  138. }
  139. // NewChat @Title 新建对话框
  140. // @Description 新建对话框
  141. // @Success 101 {object} response.ListResp
  142. // @router /chat/new_chat [post]
  143. func (kbctrl *KbController) NewChat() {
  144. br := new(models.BaseResponse).Init()
  145. defer func() {
  146. kbctrl.Data["json"] = br
  147. kbctrl.ServeJSON()
  148. }()
  149. var req facade.LLMKnowledgeSearch
  150. err := json.Unmarshal(kbctrl.Ctx.Input.RequestBody, &req)
  151. if err != nil {
  152. br.Msg = "参数解析异常!"
  153. br.ErrMsg = "参数解析失败,Err:" + err.Error()
  154. return
  155. }
  156. sysUser := kbctrl.SysUser
  157. if sysUser == nil {
  158. br.Msg = "请登录"
  159. br.ErrMsg = "请登录,SysUser Is Empty"
  160. br.Ret = 408
  161. return
  162. }
  163. session := llm.UserLlmChat{
  164. UserId: sysUser.AdminId,
  165. CreatedTime: time.Now(),
  166. ChatTitle: "新会话",
  167. }
  168. err = session.CreateChatSession()
  169. if err != nil {
  170. br.Msg = "创建失败"
  171. br.ErrMsg = "创建失败,Err:" + err.Error()
  172. return
  173. }
  174. br.Ret = 200
  175. br.Success = true
  176. br.Msg = "创建成功"
  177. }
  178. // ChatConnect @Title 知识库问答创建对话连接
  179. // @Description 知识库问答创建对话连接
  180. // @Success 101 {object} response.ListResp
  181. // @router /chat/connect [get]
  182. func (cc *ChatController) ChatConnect() {
  183. if !ws.Allow(cc.SysUser.AdminId, ws.CONNECT_LIMITER) {
  184. utils.FileLog.Error("WebSocket连接太频繁,主动拒绝链接")
  185. cc.Ctx.ResponseWriter.WriteHeader(http.StatusTooManyRequests)
  186. return
  187. }
  188. wsCon, err := webSocketHandler(cc.Ctx.ResponseWriter, cc.Ctx.Request)
  189. if err != nil {
  190. utils.FileLog.Error("WebSocket连接失败:", err)
  191. cc.Ctx.ResponseWriter.WriteHeader(http.StatusBadRequest)
  192. return
  193. }
  194. facade.AddSession(cc.SysUser.AdminId, wsCon)
  195. }
  196. // upGrader 用于将HTTP连接升级为WebSocket连接
  197. var upGrader = websocket.Upgrader{
  198. ReadBufferSize: 1024,
  199. WriteBufferSize: 1024,
  200. CheckOrigin: func(r *http.Request) bool {
  201. return true
  202. },
  203. }
  204. // WebSocketHandler 处理WebSocket连接
  205. func webSocketHandler(w http.ResponseWriter, r *http.Request) (conn *websocket.Conn, err error) {
  206. conn, err = upGrader.Upgrade(w, r, nil)
  207. if err != nil {
  208. utils.FileLog.Error("升级协议失败:WebSocket:%s", err.Error())
  209. return
  210. }
  211. // 获取底层 TCP 连接并设置保活
  212. if tcpConn, ok := conn.NetConn().(*net.TCPConn); ok {
  213. _ = tcpConn.SetKeepAlive(true)
  214. _ = tcpConn.SetKeepAlivePeriod(ws.TcpTimeout)
  215. utils.FileLog.Info("TCP KeepAlive 已启用")
  216. }
  217. _ = conn.SetReadDeadline(time.Now().Add(ws.ReadTimeout))
  218. return
  219. }