kobe6258 1 月之前
父節點
當前提交
46726f4a95

+ 72 - 0
controllers/rag/chat_controller.go

@@ -0,0 +1,72 @@
+package rag
+
+import (
+	"eta/eta_api/controllers"
+	"eta/eta_api/models/system"
+	"eta/eta_api/services"
+	"eta/eta_api/services/llm/facade"
+	"eta/eta_api/utils"
+	"eta/eta_api/utils/ws"
+	"github.com/gorilla/websocket"
+	"net/http"
+	"time"
+)
+
+type ChatController struct {
+	controllers.BaseAuthController
+}
+
+func (cc *ChatController) Prepare() {
+	cc.SysUser = cc.Ctx.Input.GetData("admin").(*system.Admin)
+	if cc.SysUser == nil || cc.SysUser.AdminId == 0 {
+		utils.FileLog.Error("用户信息不存在")
+		cc.Ctx.ResponseWriter.WriteHeader(http.StatusBadRequest)
+		return
+	}
+}
+
+// @Title 知识库问答接口
+// @Description 知识库问答接口
+// @Param	request	body aimod.ChatReq true "type json string"
+// @Success 200 {object} response.ListResp
+// @router /chat/connect [get]
+func (cc *ChatController) ChatConnect() {
+	//if !ws.Allow(cc.SysUser.AdminId) {
+	//	utils.FileLog.Error("WebSocket连接太频繁,主动拒绝链接")
+	//	cc.Ctx.ResponseWriter.WriteHeader(http.StatusTooManyRequests)
+	//	return
+	//}
+	wsCon, err := webSocketHandler(cc.Ctx.ResponseWriter, cc.Ctx.Request)
+	if err != nil {
+		utils.FileLog.Error("WebSocket连接失败:", err)
+		cc.Ctx.ResponseWriter.WriteHeader(http.StatusBadRequest)
+		return
+	}
+	session := &ws.Session{
+		UserID: cc.SysUser.AdminId,
+		ID:     facade.GenerateSessionCode(),
+		Conn:   wsCon}
+
+	facade.AddSession(session)
+}
+
+// upGrader 用于将HTTP连接升级为WebSocket连接
+var upGrader = websocket.Upgrader{
+	ReadBufferSize:  1024,
+	WriteBufferSize: 1024,
+	CheckOrigin: func(r *http.Request) bool {
+		return true
+	},
+}
+
+// WebSocketHandler 处理WebSocket连接
+func webSocketHandler(w http.ResponseWriter, r *http.Request) (conn *websocket.Conn, err error) {
+	conn, err = upGrader.Upgrade(w, r, nil)
+	if err != nil {
+		utils.FileLog.Error("升级协议失败:WebSocket:%s", err.Error())
+		return
+	}
+	_ = conn.SetReadDeadline(time.Now().Add(time.Second * 60))
+	services.HandleWebSocketConnection(conn)
+	return
+}

+ 7 - 0
controllers/rag/llm_http/request.go

@@ -0,0 +1,7 @@
+package llm_http
+
+type LLMQuestionReq struct {
+	Question      string `description:"提问"`
+	KnowledgeBase string `description:"知识库"`
+	SessionId     string `description:"会话ID"`
+}

+ 7 - 0
controllers/rag/llm_http/response.go

@@ -0,0 +1,7 @@
+package llm_http
+
+
+type LLMQuestionRes struct {
+	Answer      string `description:"回答"`
+	SessionId     string `description:"会话ID"`
+}

+ 2 - 2
main.go

@@ -12,7 +12,7 @@ import (
 	_ "eta/eta_api/routers"
 	"eta/eta_api/services"
 	"eta/eta_api/utils"
-
+	_ "eta/eta_api/utils/llm/deepseek"
 	"github.com/beego/beego/v2/adapter/logs"
 	"github.com/beego/beego/v2/server/web"
 	"github.com/beego/beego/v2/server/web/context"
@@ -23,7 +23,7 @@ func main() {
 		web.BConfig.WebConfig.DirectoryIndex = true
 		web.BConfig.WebConfig.StaticDir["/swagger"] = "swagger"
 	}
-
+	web.Router("/", &controllers.BaseCommonController{})
 	go services.Task()
 
 	// 异常处理

+ 9 - 0
routers/commentsRouter.go

@@ -8521,6 +8521,15 @@ func init() {
             Filters: nil,
             Params: nil})
 
+    beego.GlobalControllerRouter["eta/eta_api/controllers/rag:ChatController"] = append(beego.GlobalControllerRouter["eta/eta_api/controllers/rag:ChatController"],
+        beego.ControllerComments{
+            Method: "ChatConnect",
+            Router: `/chat/connect`,
+            AllowHTTPMethods: []string{"get"},
+            MethodParams: param.Make(),
+            Filters: nil,
+            Params: nil})
+
     beego.GlobalControllerRouter["eta/eta_api/controllers/report_approve:ReportApproveController"] = append(beego.GlobalControllerRouter["eta/eta_api/controllers/report_approve:ReportApproveController"],
         beego.ControllerComments{
             Method: "Approve",

+ 8 - 0
routers/router.go

@@ -31,6 +31,7 @@ import (
 	"eta/eta_api/controllers/eta_trial"
 	"eta/eta_api/controllers/fe_calendar"
 	"eta/eta_api/controllers/material"
+	"eta/eta_api/controllers/rag"
 	"eta/eta_api/controllers/report_approve"
 	"eta/eta_api/controllers/residual_analysis"
 	"eta/eta_api/controllers/roadshow"
@@ -39,6 +40,7 @@ import (
 	"eta/eta_api/controllers/smart_report"
 	"eta/eta_api/controllers/speech_recognition"
 	"eta/eta_api/controllers/trade_analysis"
+	"eta/eta_api/services"
 	"github.com/beego/beego/v2/server/web"
 	"github.com/beego/beego/v2/server/web/filter/cors"
 )
@@ -52,6 +54,7 @@ func init() {
 		ExposeHeaders:    []string{"Content-Length", "Access-Control-Allow-Origin", "Access-Control-Allow-Headers", "Content-Type"},
 		AllowCredentials: true,
 	}))
+	web.InsertFilter("/adminapi/llm/chat/connect", web.BeforeRouter,services.WsAuthenticate())
 	ns := web.NewNamespace("/adminapi",
 		web.NSNamespace("/sysuser",
 			web.NSInclude(
@@ -66,6 +69,11 @@ func init() {
 				&controllers.ClassifyController{},
 			),
 		),
+		web.NSNamespace("/llm",
+			web.NSInclude(
+				&rag.ChatController{},
+			),
+		),
 		web.NSNamespace("/banner",
 			web.NSInclude(
 				&controllers.BannerController{},

+ 27 - 0
services/llm/facade/llm_service.go

@@ -0,0 +1,27 @@
+package facade
+
+import (
+	"eta/eta_api/utils/llm"
+	"eta/eta_api/utils/ws"
+	"fmt"
+	"github.com/rdlucklib/rdluck_tools/uuid"
+)
+
+var (
+	deepseekService, _ = llm.GetInstance(llm.LLM_DEEPSEEK)
+)
+
+
+
+func GenerateSessionCode() (code string) {
+	return fmt.Sprintf("%s%s", "llm_session_", uuid.NewUUID().Hex32())
+}
+
+func GetSession(userId int, sessionId string) (session *ws.Session, ok bool) {
+	token := fmt.Sprintf("%d_%s", userId, sessionId)
+	return ws.Manager().Get(token)
+}
+
+func AddSession(session *ws.Session) {
+	ws.Manager().Add(session)
+}

+ 145 - 0
services/ws_service.go

@@ -0,0 +1,145 @@
+package services
+
+import (
+	"eta/eta_api/models"
+	"eta/eta_api/models/system"
+	"eta/eta_api/utils"
+	"eta/eta_api/utils/ws"
+	"fmt"
+	"github.com/beego/beego/v2/server/web"
+	"github.com/beego/beego/v2/server/web/context"
+	"github.com/gorilla/websocket"
+	"net/http"
+	"strings"
+	"time"
+)
+
+func HandleWebSocketConnection(conn *websocket.Conn) {
+	ws.Manager().HandleWebSocketConnection(conn)
+}
+func WsAuthenticate() web.FilterFunc {
+	return func(ctx *context.Context) {
+		method := ctx.Input.Method()
+		uri := ctx.Input.URI()
+		if method == "POST" || method == "GET" {
+			authorization := ctx.Input.Header("authorization")
+			if authorization == "" {
+				authorization = ctx.Input.Header("Authorization")
+			}
+			if strings.Contains(authorization, ";") {
+				authorization = strings.Replace(authorization, ";", "$", 1)
+			}
+			if authorization == "" {
+				strArr := strings.Split(uri, "?")
+				for k, v := range strArr {
+					fmt.Println(k, v)
+				}
+				if len(strArr) > 1 {
+					authorization = strArr[1]
+					authorization = strings.Replace(authorization, "Authorization", "authorization", -1)
+				}
+			}
+			if authorization == "" {
+				utils.FileLog.Error("authorization为空,未授权")
+				ctx.ResponseWriter.WriteHeader(http.StatusUnauthorized)
+				return
+			}
+			tokenStr := authorization
+			tokenArr := strings.Split(tokenStr, "=")
+			token := tokenArr[1]
+
+			//accountStr := authorizationArr[1]
+			//accountArr := strings.Split(accountStr, "=")
+			//account := accountArr[1]
+
+			session, err := system.GetSysSessionByToken(token)
+			if err != nil {
+				if utils.IsErrNoRow(err) {
+					utils.FileLog.Error("authorization已过期")
+					ctx.ResponseWriter.WriteHeader(http.StatusUnauthorized)
+					return
+				}
+				utils.FileLog.Error("authorization查询用户信息失败")
+				ctx.ResponseWriter.WriteHeader(http.StatusBadRequest)
+				return
+			}
+			if session == nil {
+				utils.FileLog.Error("会话不存在")
+				ctx.ResponseWriter.WriteHeader(http.StatusBadRequest)
+				return
+			}
+			//校验token是否合法
+			// JWT校验Token和Account
+			account := utils.MD5(session.UserName)
+			if !utils.CheckToken(account, token) {
+				utils.FileLog.Error("authorization校验不合法")
+				ctx.ResponseWriter.WriteHeader(http.StatusUnauthorized)
+				return
+			}
+			if time.Now().After(session.ExpiredTime) {
+				utils.FileLog.Error("authorization过期法")
+				ctx.ResponseWriter.WriteHeader(http.StatusUnauthorized)
+				return
+			}
+			admin, err := system.GetSysUserById(session.SysUserId)
+			if err != nil {
+				if utils.IsErrNoRow(err) {
+					utils.FileLog.Error("权限不够")
+					ctx.ResponseWriter.WriteHeader(http.StatusForbidden)
+					return
+				}
+				utils.FileLog.Error("获取用户信息失败")
+				ctx.ResponseWriter.WriteHeader(http.StatusBadRequest)
+				return
+			}
+			if admin == nil {
+				utils.FileLog.Error("权限不够")
+				ctx.ResponseWriter.WriteHeader(http.StatusForbidden)
+				return
+			}
+			//如果不是启用状态
+			if admin.Enabled != 1 {
+				utils.FileLog.Error("用户被禁用")
+				ctx.ResponseWriter.WriteHeader(http.StatusForbidden)
+				return
+			}
+
+			//接口权限校验
+			roleId := admin.RoleId
+			list, e := system.GetMenuButtonApisByRoleId(roleId)
+			if e != nil {
+				utils.FileLog.Error("接口权限查询出错", e)
+				ctx.ResponseWriter.WriteHeader(http.StatusForbidden)
+				return
+			}
+			var api string
+			for _, v := range list {
+				if v.Api != "" {
+					api += v.Api + "&"
+				}
+			}
+			api += "&" + models.BusinessConfMap["PublicApi"]
+			//处理uri请求,去除前缀和参数
+			api = strings.TrimRight(api, "&")
+			uri = strings.Replace(uri, "/adminapi", "", 1)
+			uris := strings.Split(uri, "?")
+			uri = uris[0]
+			//fmt.Println("uri:", uri)
+			apis := strings.Split(api, "&")
+			apiMap := make(map[string]bool, 0)
+			for _, s := range apis {
+				apiMap[s] = true
+			}
+			if !apiMap[uri] {
+				utils.FileLog.Error("用户无权访问")
+				ctx.ResponseWriter.WriteHeader(http.StatusForbidden)
+				return
+			}
+			ctx.Input.SetData("admin", admin)
+		} else {
+			utils.FileLog.Error("请求方法类型错误")
+			ctx.ResponseWriter.WriteHeader(http.StatusBadRequest)
+			return
+		}
+	}
+}

+ 6 - 0
utils/config.go

@@ -11,6 +11,10 @@ import (
 	"github.com/spf13/viper"
 )
 
+// 大模型配置
+var (
+	DS_LLM_SERVER string //模型服务地址
+)
 var (
 	RunMode          string //运行模式
 	MYSQL_URL        string //数据库连接
@@ -629,6 +633,8 @@ func init() {
 	// chrome配置
 	ChromePath = config["chrome_path"]
 
+	//模型服务器地址
+	DS_LLM_SERVER = config["llm_server"]
 	// 初始化ES
 	initEs()
 

+ 38 - 0
utils/llm/deepseek/deekseek.go

@@ -0,0 +1,38 @@
+package deepseek
+
+import (
+	"eta/eta_api/utils"
+	"eta/eta_api/utils/llm"
+	"sync"
+)
+
+var (
+	dsOnce sync.Once
+
+	deepseekClient *DeepSeekClient
+)
+
+type DeepSeekClient struct {
+	*llm.LLMClient
+}
+
+func Getinstance() llm.LLMService {
+	dsOnce.Do(func() {
+		if deepseekClient == nil {
+			deepseekClient = &DeepSeekClient{
+				LLMClient: llm.NewLLMClient(utils.DS_LLM_SERVER, 10),
+			}
+		}
+	})
+	return deepseekClient
+}
+
+func (ds *DeepSeekClient) AskQuestion() string {
+	return ""
+}
+func init() {
+	err := llm.Register(llm.LLM_DEEPSEEK, Getinstance())
+	if err != nil {
+		utils.FileLog.Error("注册deepseek服务失败:", err)
+	}
+}

+ 24 - 0
utils/llm/llm_client.go

@@ -0,0 +1,24 @@
+package llm
+
+import (
+	"net/http"
+	"time"
+)
+
+type LLMClient struct {
+	BaseURL    string
+	HttpClient *http.Client
+}
+
+func NewLLMClient(baseURL string, timeout time.Duration) *LLMClient {
+	return &LLMClient{
+		BaseURL: baseURL,
+		HttpClient: &http.Client{
+			Timeout: timeout * time.Second,
+		},
+	}
+}
+
+type LLMService interface {
+	AskQuestion() string
+}

+ 39 - 0
utils/llm/llm_factory.go

@@ -0,0 +1,39 @@
+package llm
+
+import (
+	"errors"
+)
+
+var (
+	llmInstanceMap map[string]LLMService = make(map[string]LLMService)
+)
+
+const (
+	LLM_DEEPSEEK = "deepseek"
+)
+
+func Register(name string, llmClient LLMService) (err error) {
+	if name == "" {
+		err = errors.New("模型实例名不能为空")
+		return
+	}
+	if _, ok := llmInstanceMap[name]; ok {
+		err = errors.New("模型实例已经存在")
+		return
+	}
+	llmInstanceMap[name] = llmClient
+	return
+}
+func GetInstance(name string) (llmClient LLMService, err error) {
+	if name == "" {
+		err = errors.New("模型实例名不能为空")
+		return
+	}
+	if _, ok := llmInstanceMap[name]; !ok {
+		err = errors.New("当前模型类型不支持")
+		return
+	}
+	llmClient = llmInstanceMap[name]
+	return
+}
+

+ 77 - 0
utils/ws/limiter.go

@@ -0,0 +1,77 @@
+package ws
+
+import (
+	"fmt"
+	"golang.org/x/time/rate"
+	"sync"
+	"time"
+)
+
+var (
+	limiterManager *QALimiterManger
+	limiterOnce    sync.Once
+)
+
+const (
+	LIMITER_KEY = "llm_chat_key_user_%d"
+)
+
+type QALimiterManger struct {
+	sync.RWMutex
+	limiterMap map[string]*QALimiter
+}
+
+type QALimiter struct {
+	LastRequest time.Time
+	*rate.Limiter
+}
+
+//func (qaLimiter *QALimiter) Allow() bool {
+//	return qaLimiter.Limiter.Allow()
+//}
+
+// GetLimiter 获取或创建用户的限流器
+func (qalm *QALimiterManger) GetLimiter(token string) *QALimiter {
+	qalm.Lock()
+	defer qalm.Unlock()
+
+	if limiter, exists := qalm.limiterMap[token]; exists {
+		return limiter
+	}
+
+	// 创建一个新的限流器,例如每10秒1个请求
+	limiter := &QALimiter{
+		Limiter: rate.NewLimiter(rate.Every(10*time.Second), 1),
+	}
+	qalm.limiterMap[token] = limiter
+	return limiter
+}
+
+func (qalm *QALimiterManger) Allow(token string) bool {
+
+	limiter := qalm.GetLimiter(token)
+	if limiter.LastRequest.IsZero() {
+		limiter.LastRequest = time.Now()
+		return limiter.Allow()
+	}
+	if time.Now().Sub(limiter.LastRequest) < 10*time.Second {
+		return false
+	}
+	limiter.LastRequest = time.Now()
+	return limiter.Allow()
+}
+func getInstance() *QALimiterManger {
+	limiterOnce.Do(func() {
+		if limiterManager == nil {
+			limiterManager = &QALimiterManger{
+				limiterMap: make(map[string]*QALimiter),
+			}
+		}
+	})
+	return limiterManager
+}
+
+func Allow(userId int) bool {
+	token := fmt.Sprintf(LIMITER_KEY, userId)
+	return getInstance().Allow(token)
+}

+ 42 - 0
utils/ws/session.go

@@ -0,0 +1,42 @@
+package ws
+
+import (
+	"github.com/gorilla/websocket"
+	"sync"
+	"time"
+)
+
+// Session 会话结构
+type Session struct {
+	ID          string
+	UserID      int
+	Conn        *websocket.Conn
+	LastActive  time.Time
+	qaLimiter   *QALimiter
+	Latency     *LatencyMeasurer
+	CloseChan   chan struct{}
+	MessageChan chan []byte
+	mu          sync.RWMutex
+}
+
+// HeartbeatManager 心跳管理器
+type HeartbeatManager struct {
+	interval  time.Duration
+	sessions  sync.Map
+	closeChan chan struct{}
+}
+
+// LatencyMeasurer 延迟测量器
+type LatencyMeasurer struct {
+	measurements []time.Duration
+	lastLatency  time.Duration
+	mu           sync.Mutex
+}
+
+// NewHeartbeatManager 创建心跳管理器
+func NewHeartbeatManager(interval time.Duration) *HeartbeatManager {
+	return &HeartbeatManager{
+		interval:  interval,
+		closeChan: make(chan struct{}),
+	}
+}

+ 141 - 0
utils/ws/session_manager.go

@@ -0,0 +1,141 @@
+package ws
+
+import (
+	"eta/eta_api/utils"
+	"fmt"
+	"github.com/gorilla/websocket"
+	"math/rand"
+	"net"
+	"time"
+)
+const (
+	maxMessageSize  = 1024 * 1024 // 1MB
+	basePingInterval = 30 * time.Second
+	maxPingInterval  = 120 * time.Second
+	minPingInterval  = 15 * time.Second
+)
+type ConnectionManager struct {
+	Sessions    map[string]*Session
+	heartbeat   *HeartbeatManager
+}
+
+var (
+	manager = &ConnectionManager{
+		Sessions:  make(map[string]*Session),
+		heartbeat: NewHeartbeatManager(30 * time.Second),
+	}
+)
+
+func Manager() *ConnectionManager {
+	return manager
+}
+
+// Add 添加一个新的会话
+func (manager *ConnectionManager) Add(session *Session) {
+	manager.Lock()
+	defer manager.Unlock()
+	manager.Sessions[manager.GetSessionId(session.UserId, session.SessionId)] = session
+}
+func (manager *ConnectionManager) GetSessionId(userId int, sessionId string) (sessionID string) {
+	return fmt.Sprintf("%d_%s", userId, sessionId)
+}
+
+// Remove 移除一个会话
+func (manager *ConnectionManager) Remove(sessionCode string) {
+	delete(manager.Sessions, sessionCode)
+}
+
+func (manager *ConnectionManager) Get(sessionID string) (session *Session, ok bool) {
+	session, ok = manager.Sessions[sessionID]
+	return
+}
+func (manager *ConnectionManager) HeartBeat(session *Session) {
+	fmt.Println("执行心跳")
+	if err := session.Conn.WriteMessage(websocket.PingMessage, nil); err != nil {
+		err = session.Conn.Close()
+		if err != nil {
+			utils.FileLog.Error("关闭长连接失败: %v", err)
+			return
+		}
+		delete(manager.Sessions, manager.GetSessionId(session.UserId, session.SessionId))
+	}
+}
+func (manager *ConnectionManager) HandleWebSocketConnection(conn *websocket.Conn) {
+	defer func() {
+		if err := conn.Close(); err != nil {
+			handleClose(err)
+		}
+	}()
+	// 获取底层 TCP 连接并设置保活
+	if tcpConn, ok := conn.NetConn().(*net.TCPConn); ok {
+		_ = tcpConn.SetKeepAlive(true)
+		_ = tcpConn.SetKeepAlivePeriod(90 * time.Second)
+		utils.FileLog.Info("TCP KeepAlive 已启用")
+	}
+	// 初始化心跳间隔(基础值)
+	baseInterval := 30 * time.Second
+	adjustHeartbeatInterval(conn, baseInterval)
+	// 设置心跳检测
+	conn.SetPongHandler(func(string) error {
+		err := conn.SetReadDeadline(time.Now().Add(60 * time.Second))
+		if err != nil {
+			utils.FileLog.Error("设置读取超时失败:WebSocket:", err)
+		}
+		return nil
+	})
+	// 消息处理循环
+	for {
+		messageType, message, err := conn.ReadMessage()
+		if err != nil {
+			utils.FileLog.Error("Read error:", err)
+			return
+		}
+		// 业务处理逻辑
+		response := processMessage(message)
+		// 返回响应
+		if err = conn.WriteMessage(messageType, response); err != nil {
+			utils.FileLog.Error("Write error:", err)
+			return
+		}
+	}
+}
+func handleClose(err error) {
+	if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway) {
+		if wsErr, ok := err.(*websocket.CloseError); !ok {
+			utils.FileLog.Error("未知错误 %s", err.Error())
+		} else {
+			switch wsErr.Code {
+			case websocket.CloseNormalClosure:
+				utils.FileLog.Info("正常关闭连接")
+			default:
+				utils.FileLog.Error("关闭代码:%d:%s", wsErr.Code, wsErr.Text)
+			}
+		}
+
+	}
+}
+
+// 动态调整心跳间隔(需配合业务逻辑调用)
+func adjustHeartbeatInterval(conn *websocket.Conn, baseInterval time.Duration) {
+	// 模拟网络延迟计算(实际应通过Ping-Pong测量)
+	latency := time.Duration(rand.Intn(100)) * time.Millisecond
+	newInterval := baseInterval + latency*2
+
+	// 创建新的心跳定时器
+	ticker := time.NewTicker(newInterval)
+	defer ticker.Stop()
+
+	go func() {
+		for range ticker.C {
+			if err := conn.WriteControl(websocket.PingMessage, nil, time.Now().Add(5*time.Second)); err != nil {
+				utils.FileLog.Error("发送心跳包失败:", err)
+				return
+			}
+		}
+	}()
+	utils.FileLog.Info("心跳间隔调整为: %v", newInterval)
+}
+func processMessage(msg []byte) []byte {
+	// 实现具体的业务逻辑
+	return []byte("Received: " + string(msg))
+}