package rag import ( "eta/eta_api/controllers" "eta/eta_api/models" "eta/eta_api/models/system" "eta/eta_api/services/llm/facade" "eta/eta_api/utils" "eta/eta_api/utils/ws" "fmt" "github.com/gorilla/websocket" "net" "net/http" "strings" "time" ) type ChatWsController struct { controllers.BaseAuthController } func (cc *ChatWsController) Prepare() { method := cc.Ctx.Input.Method() uri := cc.Ctx.Input.URI() if method == "GET" { authorization := cc.Ctx.Input.Header("authorization") if authorization == "" { authorization = cc.Ctx.Input.Header("Authorization") } if strings.Contains(authorization, ";") { authorization = strings.Replace(authorization, ";", "$", 1) } if authorization == "" { strArr := strings.Split(uri, "?") for k, v := range strArr { fmt.Println(k, v) } if len(strArr) > 1 { authorization = strArr[1] authorization = strings.Replace(authorization, "Authorization", "authorization", -1) } } if authorization == "" { utils.FileLog.Error("authorization为空,未授权") cc.Ctx.ResponseWriter.WriteHeader(http.StatusUnauthorized) return } tokenStr := authorization tokenArr := strings.Split(tokenStr, "=") token := tokenArr[1] session, err := system.GetSysSessionByToken(token) if err != nil { if utils.IsErrNoRow(err) { utils.FileLog.Error("authorization已过期") cc.Ctx.ResponseWriter.WriteHeader(http.StatusUnauthorized) return } utils.FileLog.Error("authorization查询用户信息失败") cc.Ctx.ResponseWriter.WriteHeader(http.StatusBadRequest) return } if session == nil { utils.FileLog.Error("会话不存在") cc.Ctx.ResponseWriter.WriteHeader(http.StatusBadRequest) return } //校验token是否合法 // JWT校验Token和Account account := utils.MD5(session.UserName) if !utils.CheckToken(account, token) { utils.FileLog.Error("authorization校验不合法") cc.Ctx.ResponseWriter.WriteHeader(http.StatusUnauthorized) return } if time.Now().After(session.ExpiredTime) { utils.FileLog.Error("authorization过期法") cc.Ctx.ResponseWriter.WriteHeader(http.StatusUnauthorized) return } admin, err := system.GetSysUserById(session.SysUserId) if err != nil { if utils.IsErrNoRow(err) { utils.FileLog.Error("权限不够") cc.Ctx.ResponseWriter.WriteHeader(http.StatusForbidden) return } utils.FileLog.Error("获取用户信息失败") cc.Ctx.ResponseWriter.WriteHeader(http.StatusBadRequest) return } if admin == nil { utils.FileLog.Error("权限不够") cc.Ctx.ResponseWriter.WriteHeader(http.StatusForbidden) return } //如果不是启用状态 if admin.Enabled != 1 { utils.FileLog.Error("用户被禁用") cc.Ctx.ResponseWriter.WriteHeader(http.StatusForbidden) return } //接口权限校验 roleId := admin.RoleId list, e := system.GetMenuButtonApisByRoleId(roleId) if e != nil { utils.FileLog.Error("接口权限查询出错", e) cc.Ctx.ResponseWriter.WriteHeader(http.StatusForbidden) return } var api string for _, v := range list { if v.Api != "" { api += v.Api + "&" } } api += "&" + models.BusinessConfMap["PublicApi"] //处理uri请求,去除前缀和参数 api = strings.TrimRight(api, "&") uri = strings.Replace(uri, "/adminapi", "", 1) uris := strings.Split(uri, "?") uri = uris[0] //fmt.Println("uri:", uri) apis := strings.Split(api, "&") apiMap := make(map[string]bool, 0) for _, s := range apis { apiMap[s] = true } if !apiMap[uri] { utils.FileLog.Error("用户无权访问") cc.Ctx.ResponseWriter.WriteHeader(http.StatusForbidden) return } cc.SysUser = admin } else { utils.FileLog.Error("请求方法类型错误") cc.Ctx.ResponseWriter.WriteHeader(http.StatusBadRequest) return } } // ChatConnect @Title 知识库问答创建对话连接 // @Description 知识库问答创建对话连接 // @Success 101 {object} response.ListResp // @router /chat/connect [get] func (cc *ChatWsController) ChatConnect() { if !ws.Allow(cc.SysUser.AdminId, ws.CONNECT_LIMITER) { utils.FileLog.Error("WebSocket连接太频繁,主动拒绝链接") cc.Ctx.ResponseWriter.WriteHeader(http.StatusTooManyRequests) return } wsCon, err := webSocketHandler(cc.Ctx.ResponseWriter, cc.Ctx.Request) if err != nil { utils.FileLog.Error("WebSocket连接失败:", err) cc.Ctx.ResponseWriter.WriteHeader(http.StatusBadRequest) return } facade.AddSession(cc.SysUser.AdminId, wsCon) } // upGrader 用于将HTTP连接升级为WebSocket连接 var upGrader = websocket.Upgrader{ ReadBufferSize: 1024, WriteBufferSize: 1024, CheckOrigin: func(r *http.Request) bool { return true }, } // WebSocketHandler 处理WebSocket连接 func webSocketHandler(w http.ResponseWriter, r *http.Request) (conn *websocket.Conn, err error) { conn, err = upGrader.Upgrade(w, r, nil) if err != nil { utils.FileLog.Error("升级协议失败:WebSocket:%s", err.Error()) return } // 获取底层 TCP 连接并设置保活 if tcpConn, ok := conn.NetConn().(*net.TCPConn); ok { _ = tcpConn.SetKeepAlive(true) _ = tcpConn.SetKeepAlivePeriod(ws.TcpTimeout) utils.FileLog.Info("TCP KeepAlive 已启用") } _ = conn.SetReadDeadline(time.Now().Add(ws.ReadTimeout)) return }