|
@@ -1,6 +1,7 @@
|
|
package ws
|
|
package ws
|
|
|
|
|
|
import (
|
|
import (
|
|
|
|
+ "encoding/json"
|
|
"errors"
|
|
"errors"
|
|
"eta/eta_api/utils"
|
|
"eta/eta_api/utils"
|
|
"eta/eta_api/utils/llm"
|
|
"eta/eta_api/utils/llm"
|
|
@@ -17,10 +18,10 @@ var (
|
|
)
|
|
)
|
|
|
|
|
|
const (
|
|
const (
|
|
- defaultCheckInterval = 5 * time.Second // 检测间隔应小于心跳超时时间
|
|
|
|
- connectionTimeout = 20 * time.Second // 客户端超时时间
|
|
|
|
- ReadTimeout = 10 * time.Second // 客户端超时时间
|
|
|
|
- writeWaitTimeout = 5 * time.Second
|
|
|
|
|
|
+ defaultCheckInterval = 2 * time.Minute // 检测间隔应小于心跳超时时间
|
|
|
|
+ connectionTimeout = 10 * time.Minute // 客户端超时时间
|
|
|
|
+ ReadTimeout = 60 * time.Second // 读取超时时间
|
|
|
|
+ writeWaitTimeout = 60 * time.Second //写入超时时间
|
|
)
|
|
)
|
|
|
|
|
|
type ConnectionManager struct {
|
|
type ConnectionManager struct {
|
|
@@ -58,10 +59,17 @@ func (manager *ConnectionManager) HandleMessage(userID int, sessionID string, me
|
|
if !exists {
|
|
if !exists {
|
|
return errors.New("session not found")
|
|
return errors.New("session not found")
|
|
}
|
|
}
|
|
-
|
|
|
|
|
|
+ var userMessage Message
|
|
|
|
+ err := json.Unmarshal(message, &userMessage)
|
|
|
|
+ if err != nil {
|
|
|
|
+ return errors.New("消息格式错误")
|
|
|
|
+ }
|
|
// 处理业务逻辑
|
|
// 处理业务逻辑
|
|
- session.History = append(session.History, message)
|
|
|
|
- resp, err := llmService.KnowledgeBaseChat("", "hz", nil)
|
|
|
|
|
|
+ session.History = append(session.History, userMessage.LastTopics...)
|
|
|
|
+ resp, err := llmService.KnowledgeBaseChat(userMessage.Query, userMessage.KbName, session.History)
|
|
|
|
+ defer func() {
|
|
|
|
+ _ = resp.Body.Close()
|
|
|
|
+ }()
|
|
if err != nil {
|
|
if err != nil {
|
|
err = errors.New(fmt.Sprintf("知识库问答失败: httpCode:%d,错误信息:%s", resp.StatusCode, http.StatusText(resp.StatusCode)))
|
|
err = errors.New(fmt.Sprintf("知识库问答失败: httpCode:%d,错误信息:%s", resp.StatusCode, http.StatusText(resp.StatusCode)))
|
|
return err
|
|
return err
|
|
@@ -80,6 +88,7 @@ func (manager *ConnectionManager) HandleMessage(userID int, sessionID string, me
|
|
err = errors.New("未知的错误异常")
|
|
err = errors.New("未知的错误异常")
|
|
return err
|
|
return err
|
|
}
|
|
}
|
|
|
|
+ session.UpdateActivity()
|
|
// 发送消息到 WebSocket
|
|
// 发送消息到 WebSocket
|
|
_ = session.Conn.WriteMessage(websocket.TextMessage, []byte(content))
|
|
_ = session.Conn.WriteMessage(websocket.TextMessage, []byte(content))
|
|
case chanErr, ok := <-errChan:
|
|
case chanErr, ok := <-errChan:
|
|
@@ -123,12 +132,10 @@ func (manager *ConnectionManager) GetSession(sessionCode string) (session *Sessi
|
|
|
|
|
|
// CheckAll 批量检测所有连接
|
|
// CheckAll 批量检测所有连接
|
|
func (manager *ConnectionManager) CheckAll() {
|
|
func (manager *ConnectionManager) CheckAll() {
|
|
- n := 0
|
|
|
|
manager.Sessions.Range(func(key, value interface{}) bool {
|
|
manager.Sessions.Range(func(key, value interface{}) bool {
|
|
- n++
|
|
|
|
session := value.(*Session)
|
|
session := value.(*Session)
|
|
// 判断超时
|
|
// 判断超时
|
|
- if time.Since(session.LastActive) > connectionTimeout {
|
|
|
|
|
|
+ if time.Since(session.LastActive) > 2*connectionTimeout {
|
|
fmt.Printf("连接超时关闭: SessionID=%s, UserID=%s", session.Id, session.UserId)
|
|
fmt.Printf("连接超时关闭: SessionID=%s, UserID=%s", session.Id, session.UserId)
|
|
utils.FileLog.Warn("连接超时关闭: SessionID=%s, UserID=%s", session.Id, session.UserId)
|
|
utils.FileLog.Warn("连接超时关闭: SessionID=%s, UserID=%s", session.Id, session.UserId)
|
|
session.Close()
|
|
session.Close()
|
|
@@ -142,13 +149,12 @@ func (manager *ConnectionManager) CheckAll() {
|
|
fmt.Printf("心跳发送失败: SessionID=%s, Error=%v", s.Id, err)
|
|
fmt.Printf("心跳发送失败: SessionID=%s, Error=%v", s.Id, err)
|
|
utils.FileLog.Warn("心跳发送失败: SessionID=%s, Error=%v",
|
|
utils.FileLog.Warn("心跳发送失败: SessionID=%s, Error=%v",
|
|
s.Id, err)
|
|
s.Id, err)
|
|
|
|
+ fmt.Println("心跳无响应,退出请求")
|
|
session.Close()
|
|
session.Close()
|
|
}
|
|
}
|
|
}(session)
|
|
}(session)
|
|
- fmt.Println("当前连接数:", n)
|
|
|
|
return true
|
|
return true
|
|
})
|
|
})
|
|
- fmt.Println("当前连接数:", n)
|
|
|
|
}
|
|
}
|
|
|
|
|
|
// Start 启动心跳检测
|
|
// Start 启动心跳检测
|
|
@@ -157,10 +163,8 @@ func (manager *ConnectionManager) Start() {
|
|
for {
|
|
for {
|
|
select {
|
|
select {
|
|
case <-manager.ticker.C:
|
|
case <-manager.ticker.C:
|
|
- fmt.Printf("开始检测连接超时")
|
|
|
|
manager.CheckAll()
|
|
manager.CheckAll()
|
|
case <-manager.stopChan:
|
|
case <-manager.stopChan:
|
|
- fmt.Printf("退出检测")
|
|
|
|
return
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}
|