kobe6258 vor 1 Monat
Ursprung
Commit
b99dbafbca

+ 4 - 10
controllers/rag/chat_controller.go

@@ -25,10 +25,9 @@ func (cc *ChatController) Prepare() {
 	}
 }
 
-// @Title 知识库问答接口
-// @Description 知识库问答接口
-// @Param	request	body aimod.ChatReq true "type json string"
-// @Success 200 {object} response.ListResp
+// ChatConnect @Title 知识库问答创建对话连接
+// @Description 知识库问答创建对话连接
+// @Success 101 {object} response.ListResp
 // @router /chat/connect [get]
 func (cc *ChatController) ChatConnect() {
 	if !ws.Allow(cc.SysUser.AdminId, ws.CONNECT_LIMITER) {
@@ -42,12 +41,7 @@ func (cc *ChatController) ChatConnect() {
 		cc.Ctx.ResponseWriter.WriteHeader(http.StatusBadRequest)
 		return
 	}
-	session := &ws.Session{
-		UserID: cc.SysUser.AdminId,
-		ID:     facade.GenerateSessionCode(),
-		Conn:   wsCon}
-
-	facade.AddSession(session)
+	facade.AddSession(cc.SysUser.AdminId, wsCon)
 }
 
 // upGrader 用于将HTTP连接升级为WebSocket连接

+ 5 - 2
services/llm/facade/llm_service.go

@@ -4,6 +4,7 @@ import (
 	"eta/eta_api/utils/llm"
 	"eta/eta_api/utils/ws"
 	"fmt"
+	"github.com/gorilla/websocket"
 	"github.com/rdlucklib/rdluck_tools/uuid"
 )
 
@@ -11,10 +12,12 @@ var (
 	deepseekService, _ = llm.GetInstance(llm.LLM_DEEPSEEK)
 )
 
-func GenerateSessionCode() (code string) {
+func generateSessionCode() (code string) {
 	return fmt.Sprintf("%s%s", "llm_session_", uuid.NewUUID().Hex32())
 }
 
-func AddSession(session *ws.Session) {
+func AddSession(userId int, conn *websocket.Conn) {
+	sessionId := generateSessionCode()
+	session := ws.NewSession(userId, sessionId, conn)
 	ws.Manager().AddSession(session)
 }

+ 10 - 8
utils/ws/heart_beat_manager.go

@@ -59,13 +59,16 @@ func (hm *HeartbeatManager) Stop() {
 func (hm *HeartbeatManager) Register(session *Session) {
 	session.LastActive = time.Now()
 	session.CloseChan = make(chan struct{})
-	hm.sessions.Store(session.ID, session)
+	hm.sessions.Store(session.Id, session)
 }
 
 // 注销会话
 func (hm *HeartbeatManager) Unregister(sessionID string) {
-	if session, ok := hm.sessions.LoadAndDelete(sessionID); ok {
-		close(session.(*Session).CloseChan)
+	if session, ok := hm.sessions.Load(sessionID); ok {
+		if session.(*Session).Conn != nil {
+			_ = session.(*Session).Conn.Close()
+			close(session.(*Session).CloseChan)
+		}
 		hm.sessions.Delete(sessionID)
 	}
 }
@@ -76,21 +79,20 @@ func (hm *HeartbeatManager) CheckAll() {
 		session := value.(*Session)
 		// 判断超时
 		if time.Since(session.LastActive) > connectionTimeout {
-			utils.FileLog.Warn("连接超时关闭: SessionID=%s, UserID=%s", session.ID, session.UserID)
+			utils.FileLog.Warn("连接超时关闭: SessionID=%s, UserID=%s", session.Id, session.UserId)
 			_ = session.Conn.Close()
-			hm.Unregister(session.ID)
+			hm.Unregister(session.Id)
 			return true
 		}
-
 		// 发送心跳
 		go func(s *Session) {
 			err := s.Conn.WriteControl(websocket.PingMessage,
 				nil, time.Now().Add(5*time.Second))
 			if err != nil {
 				utils.FileLog.Warn("心跳发送失败: SessionID=%s, Error=%v",
-					s.ID, err)
+					s.Id, err)
 				_ = s.Conn.Close()
-				hm.Unregister(s.ID)
+				hm.Unregister(s.Id)
 			}
 		}(session)
 

+ 15 - 0
utils/ws/latency_measurer.go

@@ -6,6 +6,13 @@ import (
 	"time"
 )
 
+const (
+	maxMessageSize   = 1024 * 1024 // 1MB
+	basePingInterval = 5 * time.Second
+	maxPingInterval  = 120 * time.Second
+	minPingInterval  = 15 * time.Second
+)
+
 // LatencyMeasurer 延迟测量器
 type LatencyMeasurer struct {
 	measurements    []time.Duration
@@ -18,6 +25,8 @@ type LatencyMeasurer struct {
 func NewLatencyMeasurer(windowSize int) *LatencyMeasurer {
 	return &LatencyMeasurer{
 		maxMeasurements: windowSize,
+		measurements:    make([]time.Duration, 0, windowSize),
+		lastLatency:     basePingInterval,
 	}
 }
 
@@ -54,6 +63,12 @@ func (lm *LatencyMeasurer) CalculateLatency() {
 		sum += d
 	}
 	lm.lastLatency = sum / time.Duration(len(lm.measurements))
+	if lm.lastLatency > maxPingInterval {
+		lm.lastLatency = maxPingInterval
+	}
+	if lm.lastLatency < minPingInterval {
+		lm.lastLatency = minPingInterval
+	}
 }
 
 // 获取当前网络延迟估值

+ 32 - 10
utils/ws/session.go

@@ -3,6 +3,7 @@ package ws
 import (
 	"errors"
 	"eta/eta_api/utils"
+	"fmt"
 	"github.com/gorilla/websocket"
 	"sync"
 	"time"
@@ -10,20 +11,24 @@ import (
 
 // Session 会话结构
 type Session struct {
-	ID          string
-	UserID      int
+	Id          string
+	UserId      int
 	Conn        *websocket.Conn
 	LastActive  time.Time
 	Latency     *LatencyMeasurer
 	History     []string
 	CloseChan   chan struct{}
-	MessageChan chan []byte
+	MessageChan chan *Message
 	mu          sync.RWMutex
 }
+type Message struct {
+	MessageType string
+	message     []byte
+}
 
 // readPump 处理读操作
 func (s *Session) readPump() {
-	defer manager.RemoveSession(s.ID)
+	defer manager.RemoveSession(s.Id)
 	s.Conn.SetReadLimit(maxMessageSize)
 	_ = s.Conn.SetReadDeadline(time.Now().Add(60 * time.Second))
 	for {
@@ -32,19 +37,22 @@ func (s *Session) readPump() {
 			handleCloseError(err)
 			return
 		}
+		fmt.Printf("用户读取数据:%s", string(message))
 		// 更新活跃时间
 		s.mu.Lock()
 		s.LastActive = time.Now()
 		s.mu.Unlock()
 		// 处理消息
-		if err = manager.HandleMessage(s.UserID, s.ID, message); err != nil {
-
+		if err = manager.HandleMessage(s.UserId, s.Id, message); err != nil {
+			//写应答
+			s.Conn.WriteMessage(websocket.TextMessage, []byte(err.Error()))
 		}
 	}
 }
 
 // writePump 处理写操作
 func (s *Session) writePump() {
+	fmt.Printf("用户写入数据")
 	ticker := time.NewTicker(basePingInterval)
 	defer ticker.Stop()
 	for {
@@ -53,7 +61,7 @@ func (s *Session) writePump() {
 			if !ok {
 				return
 			}
-			_ = s.Conn.WriteMessage(websocket.TextMessage, message)
+			_ = s.Conn.WriteMessage(websocket.TextMessage, message.message)
 		case <-ticker.C:
 			_ = s.Latency.SendPing(s.Conn)
 			ticker.Reset(s.Latency.lastLatency)
@@ -85,9 +93,23 @@ func (s *Session) forceClose() {
 	s.Conn.WriteControl(websocket.CloseMessage,
 		websocket.FormatCloseMessage(websocket.ClosePolicyViolation, "heartbeat failed"),
 		time.Now().Add(5*time.Second))
-
 	s.Conn.Close()
 	utils.FileLog.Info("连接已强制关闭",
-		"user", s.UserID,
-		"session", s.ID)
+		"user", s.UserId,
+		"session", s.Id)
+}
+
+func NewSession(userId int, sessionId string, conn *websocket.Conn) (session *Session) {
+	session = &Session{
+		UserId:      userId,
+		Id:          sessionId,
+		Conn:        conn,
+		History:     []string{},
+		CloseChan:   make(chan struct{}),
+		MessageChan: make(chan *Message),
+	}
+	session.Latency = SetupLatencyMeasurement(conn)
+	go session.readPump()
+	go session.writePump()
+	return
 }

+ 4 - 10
utils/ws/session_manager.go

@@ -8,13 +8,6 @@ import (
 	"time"
 )
 
-const (
-	maxMessageSize   = 1024 * 1024 // 1MB
-	basePingInterval = 30 * time.Second
-	maxPingInterval  = 120 * time.Second
-	minPingInterval  = 15 * time.Second
-)
-
 type ConnectionManager struct {
 	Sessions  sync.Map
 	heartbeat *HeartbeatManager
@@ -33,7 +26,7 @@ func Manager() *ConnectionManager {
 // 消息处理核心逻辑
 func (manager *ConnectionManager) HandleMessage(userID int, sessionID string, message []byte) error {
 	if !Allow(userID, QA_LIMITER) {
-		return errors.New("request too frequent")
+		return errors.New("您提问的太频繁了,请稍后再试")
 	}
 	session, exists := manager.GetSession(sessionID)
 	if !exists {
@@ -75,7 +68,7 @@ func (manager *ConnectionManager) checkSessions() {
 
 // AddSession Add 添加一个新的会话
 func (manager *ConnectionManager) AddSession(session *Session) {
-	manager.Sessions.Store(session.ID, session)
+	manager.Sessions.Store(session.Id, session)
 	manager.heartbeat.Register(session)
 }
 func (manager *ConnectionManager) GetSessionId(userId int, sessionId string) (sessionID string) {
@@ -92,9 +85,10 @@ func (manager *ConnectionManager) RemoveSession(sessionCode string) {
 }
 
 // GetSession 获取一个会话
-func (manager *ConnectionManager) GetSession(sessionCode string) (session *Session, ok bool) {
+func (manager *ConnectionManager) GetSession(sessionCode string) (session *Session, exists bool) {
 	if data, ok := manager.Sessions.Load(sessionCode); ok {
 		session = data.(*Session)
+		exists = ok
 	}
 	return
 }