package ws
import (
"encoding/json"
"errors"
"eta/eta_api/utils"
"fmt"
"github.com/gorilla/websocket"
"sync"
"time"
)
// Session holds the state of a single websocket connection owned by one user.
//
// NewSession creates a Session and starts its reader goroutine (readPump) and
// writer goroutine (writePump); Close tears it down.
type Session struct {
	// Id identifies this session.
	Id string
	// UserId is the ID of the user that owns the connection.
	UserId int
	// Conn is the underlying websocket connection. forceClose sets it to nil
	// (while holding mu) once the connection has been closed.
	Conn *websocket.Conn
	// LastActive is when the last inbound message was seen; updated under mu
	// by UpdateActivity.
	LastActive time.Time
	// Latency is used by writePump to send pings and to pick the next ping
	// interval.
	Latency *LatencyMeasurer
	// History holds raw JSON messages. NOTE(review): not referenced in this file.
	History []json.RawMessage
	// CloseChan is closed by Close to tell writePump to stop.
	CloseChan chan struct{}
	// MessageChan queues outbound text messages that writePump sends.
	MessageChan chan string

	// mu guards LastActive and Conn and serializes writeWithTimeout calls.
	mu sync.RWMutex
	// sessionOnce makes Close run its shutdown sequence only once.
	sessionOnce sync.Once
}
// Message is the JSON shape of a client request; field names are used
// verbatim as JSON keys.
// NOTE(review): presumably decoded from the raw frame by manager.HandleMessage — confirm.
type Message struct {
	KbName string `json:"KbName"`
	Query  string `json:"Query"`
	ChatId int    `json:"ChatId"`
	//LastTopics []json.RawMessage `json:"LastTopics"`
}
// readPump is the session's reader goroutine. It reads frames from the
// websocket until a read fails and passes each one to manager.HandleMessage.
// When handling fails, it writes the error text back to the client, framed by
// two empty messages. On exit it removes the session from the manager.
func (s *Session) readPump() {
	defer func() {
		fmt.Printf("读进程session %s closed\n", s.Id)
		manager.RemoveSession(s.Id)
	}()

	// Take a snapshot of the connection under the lock. forceClose sets
	// s.Conn to nil, so re-reading s.Conn on every iteration could dereference
	// nil and panic. A closed but non-nil conn just makes ReadMessage return
	// an error, which ends the loop normally.
	s.mu.RLock()
	conn := s.Conn
	s.mu.RUnlock()
	if conn == nil {
		return
	}

	conn.SetReadLimit(maxMessageSize)
	_ = conn.SetReadDeadline(time.Now().Add(ReadTimeout))
	for {
		_, message, err := conn.ReadMessage()
		if err != nil {
			fmt.Printf("websocket 错误关闭 %s closed\n", err.Error())
			handleCloseError(err)
			return
		}

		// Treat ReadTimeout as an idle timeout: move the deadline forward on
		// every inbound message. Otherwise the connection would expire a
		// fixed time after it opened, however active the client is.
		_ = conn.SetReadDeadline(time.Now().Add(ReadTimeout))
		s.UpdateActivity()

		if err = manager.HandleMessage(s.UserId, s.Id, message); err != nil {
			// Reply with the error text framed by empty messages. Writes are
			// best effort; a broken connection shows up on the next read.
			_ = s.writeWithTimeout("")
			_ = s.writeWithTimeout(err.Error())
			_ = s.writeWithTimeout("")
		}
	}
}
// UpdateActivity sets LastActive to the current time while holding s.mu.
// readPump calls it for every inbound message.
func (s *Session) UpdateActivity() {
	s.mu.Lock()
	s.LastActive = time.Now()
	s.mu.Unlock()
}
// Close shuts the session down; only the first call does anything.
// It closes CloseChan, then MessageChan (either one makes writePump return),
// and then force-closes the websocket connection.
// NOTE(review): closing MessageChan here makes any later send on it panic;
// confirm that every sender checks CloseChan (or otherwise stops) first.
func (s *Session) Close() {
	shutdown := func() {
		close(s.CloseChan)
		close(s.MessageChan)
		s.forceClose()
	}
	s.sessionOnce.Do(shutdown)
}
// writeWithTimeout sends msg as one text frame, with a write deadline of
// writeWaitTimeout from now. It holds s.mu throughout, so calls from readPump
// and writePump never interleave. If forceClose has already cleared the
// connection, it logs and returns an error.
func (s *Session) writeWithTimeout(msg string) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	conn := s.Conn
	if conn == nil {
		utils.FileLog.Error("写入消息失败,connection已关闭")
		return errors.New("connection closed")
	}

	deadline := time.Now().Add(writeWaitTimeout)
	if err := conn.SetWriteDeadline(deadline); err != nil {
		return err
	}
	payload := []byte(msg)
	return conn.WriteMessage(websocket.TextMessage, payload)
}
// writePump is the session's writer goroutine. It sends queued messages from
// MessageChan over the websocket and pings the peer through Latency on a
// ticker. It returns in any of these cases:
//   - MessageChan or CloseChan is closed;
//   - the connection has already been cleared by forceClose;
//   - a write fails.
//
// On exit it removes the session from the manager and stops the ticker.
func (s *Session) writePump() {
	ticker := time.NewTicker(basePingInterval)
	defer func() {
		fmt.Printf("写继进程:session %s closed\n", s.Id)
		manager.RemoveSession(s.Id)
		ticker.Stop()
	}()
	for {
		select {
		case message, ok := <-s.MessageChan:
			if !ok {
				return
			}
			if err := s.writeWithTimeout(message); err != nil {
				// A failed websocket write leaves the connection unusable.
				// Stop instead of pushing the rest of the queue into a dead socket.
				utils.FileLog.Error("写入消息失败 session %s: %s", s.Id, err.Error())
				return
			}
		case <-ticker.C:
			// forceClose sets s.Conn to nil under s.mu, so read it under the lock.
			s.mu.RLock()
			conn := s.Conn
			s.mu.RUnlock()
			if conn == nil {
				return
			}
			// Best effort; a dead connection presumably surfaces as a read
			// error in readPump.
			// NOTE(review): assumes SendPing uses a method that is safe to call
			// concurrently with writes (e.g. WriteControl) — confirm.
			_ = s.Latency.SendPing(conn)

			// time.Ticker.Reset panics on a non-positive duration. lastLatency
			// may still be zero if no round trip has been measured yet (TODO
			// confirm its initial value), so fall back to the base interval.
			// NOTE(review): lastLatency is read here without synchronization.
			next := s.Latency.lastLatency
			if next <= 0 {
				next = basePingInterval
			}
			ticker.Reset(next)
		case <-s.CloseChan:
			return
		}
	}
}
// handleCloseError logs why readPump's read loop ended.
//
// Every non-nil error is first logged at error level. After that:
//   - errors that are not a *websocket.CloseError (e.g. timeouts, network
//     errors) are logged as unknown errors;
//   - CloseGoingAway is expected and gets no further logging;
//   - CloseNormalClosure is logged at info level;
//   - any other close code is logged with its code and text.
func handleCloseError(err error) {
	if err == nil {
		return
	}
	utils.FileLog.Error("websocket错误关闭 %s closed", err.Error())

	// Classify with errors.As rather than websocket.IsUnexpectedCloseError.
	// The latter returns false for anything that is not a *CloseError, which
	// would make the "unknown error" branch unreachable and hide timeouts and
	// network failures.
	var wsErr *websocket.CloseError
	if !errors.As(err, &wsErr) {
		fmt.Printf("websocket未知错误 %s\n", err.Error())
		utils.FileLog.Error("未知错误 %s", err.Error())
		return
	}

	switch wsErr.Code {
	case websocket.CloseGoingAway:
		// Expected when the client navigates away or shuts down; nothing more to log.
	case websocket.CloseNormalClosure:
		fmt.Println("websocket正常关闭连接")
		utils.FileLog.Info("正常关闭连接")
	default:
		fmt.Printf("websocket关闭代码 %d:%s\n", wsErr.Code, wsErr.Text)
		utils.FileLog.Error(":%d:%s", wsErr.Code, wsErr.Text)
	}
}
// forceClose sends a close frame (ClosePolicyViolation, "heartbeat failed")
// and closes the underlying connection. It then sets s.Conn to nil so later
// writeWithTimeout calls fail fast. It holds s.mu for the whole sequence and
// does nothing if the connection has already been cleared.
func (s *Session) forceClose() {
	s.mu.Lock()
	defer s.mu.Unlock()
	if s.Conn == nil {
		// Already closed; calling methods on a nil *websocket.Conn would panic.
		return
	}
	// Best effort: the peer may already be gone, so errors from the close
	// frame and from Close are ignored.
	_ = s.Conn.WriteControl(websocket.CloseMessage,
		websocket.FormatCloseMessage(websocket.ClosePolicyViolation, "heartbeat failed"),
		time.Now().Add(writeWaitTimeout))
	_ = s.Conn.Close()
	s.Conn = nil
	utils.FileLog.Info("连接已强制关闭",
		"user", s.UserId,
		"session", s.Id)
}
// NewSession creates a Session for conn owned by userId and attaches a
// latency measurer. It then starts the reader (readPump) and writer
// (writePump) goroutines, so the returned session is already live.
// MessageChan is buffered to hold 10 pending outbound messages.
func NewSession(userId int, sessionId string, conn *websocket.Conn) (session *Session) {
	sess := &Session{
		Id:          sessionId,
		UserId:      userId,
		Conn:        conn,
		LastActive:  time.Now(),
		CloseChan:   make(chan struct{}),
		MessageChan: make(chan string, 10),
		// Evaluated after time.Now() above (left-to-right call order) and
		// before either goroutine starts.
		Latency: SetupLatencyMeasurement(conn),
	}
	go sess.readPump()
	go sess.writePump()
	return sess
}