|
@@ -1,19 +1,28 @@
|
|
package ws
|
|
package ws
|
|
|
|
|
|
import (
|
|
import (
|
|
|
|
+ "encoding/json"
|
|
"errors"
|
|
"errors"
|
|
"eta/eta_api/utils"
|
|
"eta/eta_api/utils"
|
|
|
|
+ "eta/eta_api/utils/llm"
|
|
|
|
+ "eta/eta_api/utils/llm/eta_llm"
|
|
"fmt"
|
|
"fmt"
|
|
"github.com/gorilla/websocket"
|
|
"github.com/gorilla/websocket"
|
|
|
|
+ "net/http"
|
|
"sync"
|
|
"sync"
|
|
"time"
|
|
"time"
|
|
)
|
|
)
|
|
|
|
|
|
|
|
+var (
|
|
|
|
+ llmService, _ = llm.GetInstance(llm.ETA_LLM_CLIENT)
|
|
|
|
+)
|
|
|
|
+
|
|
const (
|
|
const (
|
|
- defaultCheckInterval = 5 * time.Second // 检测间隔应小于心跳超时时间
|
|
|
|
- connectionTimeout = 20 * time.Second // 客户端超时时间
|
|
|
|
- ReadTimeout = 10 * time.Second // 客户端超时时间
|
|
|
|
- writeWaitTimeout = 5 * time.Second
|
|
|
|
|
|
+ defaultCheckInterval = 2 * time.Minute // 检测间隔应小于心跳超时时间
|
|
|
|
+ connectionTimeout = 10 * time.Minute // 客户端超时时间
|
|
|
|
+ TcpTimeout = 20 * time.Minute // TCP超时时间,保底关闭,覆盖会话超时时间
|
|
|
|
+ ReadTimeout = 15 * time.Minute // 读取超时时间,保底关闭,覆盖会话超时时间
|
|
|
|
+ writeWaitTimeout = 60 * time.Second //写入超时时间
|
|
)
|
|
)
|
|
|
|
|
|
type ConnectionManager struct {
|
|
type ConnectionManager struct {
|
|
@@ -51,14 +60,54 @@ func (manager *ConnectionManager) HandleMessage(userID int, sessionID string, me
|
|
if !exists {
|
|
if !exists {
|
|
return errors.New("session not found")
|
|
return errors.New("session not found")
|
|
}
|
|
}
|
|
-
|
|
|
|
|
|
+ var userMessage Message
|
|
|
|
+ err := json.Unmarshal(message, &userMessage)
|
|
|
|
+ if err != nil {
|
|
|
|
+ return errors.New("消息格式错误")
|
|
|
|
+ }
|
|
// 处理业务逻辑
|
|
// 处理业务逻辑
|
|
- session.History = append(session.History, string(message))
|
|
|
|
- response := "Processed: " + string(message)
|
|
|
|
|
|
+ session.History = append(session.History, userMessage.LastTopics...)
|
|
|
|
+ resp, err := llmService.KnowledgeBaseChat(userMessage.Query, userMessage.KbName, session.History)
|
|
|
|
+ defer func() {
|
|
|
|
+ _ = resp.Body.Close()
|
|
|
|
+ }()
|
|
|
|
+ if err != nil {
|
|
|
|
+ err = errors.New(fmt.Sprintf("知识库问答失败: httpCode:%d,错误信息:%s", resp.StatusCode, http.StatusText(resp.StatusCode)))
|
|
|
|
+ return err
|
|
|
|
+ }
|
|
|
|
+ if resp.StatusCode != http.StatusOK {
|
|
|
|
+ err = errors.New(fmt.Sprintf("知识库问答失败: httpCode:%d,错误信息:%s", resp.StatusCode, http.StatusText(resp.StatusCode)))
|
|
|
|
+ return err
|
|
|
|
+ }
|
|
|
|
+ // 解析流式响应
|
|
|
|
+ contentChan, errChan, closeChan := eta_llm.ParseStreamResponse(resp)
|
|
|
|
+ // 处理流式数据并发送到 WebSocket
|
|
|
|
+ for {
|
|
|
|
+ select {
|
|
|
|
+ case content, ok := <-contentChan:
|
|
|
|
+ if !ok {
|
|
|
|
+ err = errors.New("未知的错误异常")
|
|
|
|
+ return err
|
|
|
|
+ }
|
|
|
|
+ session.UpdateActivity()
|
|
|
|
+ // 发送消息到 WebSocket
|
|
|
|
+ _ = session.Conn.WriteMessage(websocket.TextMessage, []byte(content))
|
|
|
|
+ case chanErr, ok := <-errChan:
|
|
|
|
+ if !ok {
|
|
|
|
+ err = errors.New("未知的错误异常")
|
|
|
|
+ } else {
|
|
|
|
+ err = errors.New(chanErr.Error())
|
|
|
|
+ }
|
|
|
|
+ // 发送错误消息到 WebSocket
|
|
|
|
+ return err
|
|
|
|
+ case <-closeChan:
|
|
|
|
+ _ = session.Conn.WriteMessage(websocket.TextMessage, []byte("<EOF>"))
|
|
|
|
+ return nil
|
|
|
|
+ }
|
|
|
|
+ }
|
|
// 更新最后活跃时间
|
|
// 更新最后活跃时间
|
|
- session.LastActive = time.Now()
|
|
|
|
// 发送响应
|
|
// 发送响应
|
|
- return session.Conn.WriteMessage(websocket.TextMessage, []byte(response))
|
|
|
|
|
|
+ //return session.Conn.WriteMessage(websocket.TextMessage, []byte(response))
|
|
}
|
|
}
|
|
|
|
|
|
// AddSession Add 添加一个新的会话
|
|
// AddSession Add 添加一个新的会话
|
|
@@ -71,6 +120,7 @@ func (manager *ConnectionManager) GetSessionId(userId int, sessionId string) (se
|
|
|
|
|
|
// RemoveSession Remove 移除一个会话
|
|
// RemoveSession Remove 移除一个会话
|
|
func (manager *ConnectionManager) RemoveSession(sessionCode string) {
|
|
func (manager *ConnectionManager) RemoveSession(sessionCode string) {
|
|
|
|
+ fmt.Printf("移除会话: SessionID=%s, UserID=%s", sessionCode, sessionCode)
|
|
manager.Sessions.Delete(sessionCode)
|
|
manager.Sessions.Delete(sessionCode)
|
|
}
|
|
}
|
|
|
|
|
|
@@ -85,12 +135,10 @@ func (manager *ConnectionManager) GetSession(sessionCode string) (session *Sessi
|
|
|
|
|
|
// CheckAll 批量检测所有连接
|
|
// CheckAll 批量检测所有连接
|
|
func (manager *ConnectionManager) CheckAll() {
|
|
func (manager *ConnectionManager) CheckAll() {
|
|
- n := 0
|
|
|
|
manager.Sessions.Range(func(key, value interface{}) bool {
|
|
manager.Sessions.Range(func(key, value interface{}) bool {
|
|
- n++
|
|
|
|
session := value.(*Session)
|
|
session := value.(*Session)
|
|
// 判断超时
|
|
// 判断超时
|
|
- if time.Since(session.LastActive) > connectionTimeout {
|
|
|
|
|
|
+ if time.Since(session.LastActive) > 2*connectionTimeout {
|
|
fmt.Printf("连接超时关闭: SessionID=%s, UserID=%s", session.Id, session.UserId)
|
|
fmt.Printf("连接超时关闭: SessionID=%s, UserID=%s", session.Id, session.UserId)
|
|
utils.FileLog.Warn("连接超时关闭: SessionID=%s, UserID=%s", session.Id, session.UserId)
|
|
utils.FileLog.Warn("连接超时关闭: SessionID=%s, UserID=%s", session.Id, session.UserId)
|
|
session.Close()
|
|
session.Close()
|
|
@@ -99,18 +147,17 @@ func (manager *ConnectionManager) CheckAll() {
|
|
// 发送心跳
|
|
// 发送心跳
|
|
go func(s *Session) {
|
|
go func(s *Session) {
|
|
err := s.Conn.WriteControl(websocket.PingMessage,
|
|
err := s.Conn.WriteControl(websocket.PingMessage,
|
|
- nil, time.Now().Add(5*time.Second))
|
|
|
|
|
|
+ nil, time.Now().Add(writeWaitTimeout))
|
|
if err != nil {
|
|
if err != nil {
|
|
fmt.Printf("心跳发送失败: SessionID=%s, Error=%v", s.Id, err)
|
|
fmt.Printf("心跳发送失败: SessionID=%s, Error=%v", s.Id, err)
|
|
utils.FileLog.Warn("心跳发送失败: SessionID=%s, Error=%v",
|
|
utils.FileLog.Warn("心跳发送失败: SessionID=%s, Error=%v",
|
|
s.Id, err)
|
|
s.Id, err)
|
|
|
|
+ fmt.Println("心跳无响应,退出请求")
|
|
session.Close()
|
|
session.Close()
|
|
}
|
|
}
|
|
}(session)
|
|
}(session)
|
|
- fmt.Println("当前连接数:", n)
|
|
|
|
return true
|
|
return true
|
|
})
|
|
})
|
|
- fmt.Println("当前连接数:", n)
|
|
|
|
}
|
|
}
|
|
|
|
|
|
// Start 启动心跳检测
|
|
// Start 启动心跳检测
|
|
@@ -119,10 +166,8 @@ func (manager *ConnectionManager) Start() {
|
|
for {
|
|
for {
|
|
select {
|
|
select {
|
|
case <-manager.ticker.C:
|
|
case <-manager.ticker.C:
|
|
- fmt.Printf("开始检测连接超时")
|
|
|
|
manager.CheckAll()
|
|
manager.CheckAll()
|
|
case <-manager.stopChan:
|
|
case <-manager.stopChan:
|
|
- fmt.Printf("退出检测")
|
|
|
|
return
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}
|