session_manager.go 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134
  1. package ws
  2. import (
  3. "errors"
  4. "eta/eta_api/utils"
  5. "fmt"
  6. "github.com/gorilla/websocket"
  7. "sync"
  8. "time"
  9. )
  // Timing parameters for session liveness checks and socket I/O.
  const (
  	// connectionTimeout is the maximum idle time, measured from
  	// Session.LastActive, before CheckAll closes a session.
  	connectionTimeout = 20 * time.Second

  	// defaultCheckInterval is the period of the ticker that drives CheckAll.
  	// It must stay below connectionTimeout so idle sessions are noticed in time.
  	defaultCheckInterval = 5 * time.Second

  	// ReadTimeout is exported and is not referenced by the code in this file.
  	// NOTE(review): the original comment duplicated connectionTimeout's
  	// ("client timeout"); presumably this is a per-read deadline — confirm.
  	ReadTimeout = 10 * time.Second

  	// writeWaitTimeout is presumably the deadline for a single socket write
  	// (e.g. a heartbeat ping) — confirm against its users.
  	writeWaitTimeout = 5 * time.Second
  )
  // ConnectionManager tracks live WebSocket sessions and runs the periodic
  // liveness check over them.
  type ConnectionManager struct {
  	// Sessions maps a session key to its *Session (see AddSession/GetSession).
  	Sessions sync.Map
  	// ticker paces the liveness loop in Start; created by GetInstance.
  	ticker *time.Ticker
  	// stopChan is closed by Stop to make Start return.
  	stopChan chan struct{}
  }
  21. var (
  22. smOnce sync.Once
  23. manager *ConnectionManager
  24. )
  25. func GetInstance() *ConnectionManager {
  26. smOnce.Do(func() {
  27. if manager == nil {
  28. manager = &ConnectionManager{
  29. ticker: time.NewTicker(defaultCheckInterval),
  30. stopChan: make(chan struct{}),
  31. }
  32. }
  33. })
  34. return manager
  35. }
  36. func Manager() *ConnectionManager {
  37. return manager
  38. }
  39. // HandleMessage 消息处理核心逻辑
  40. func (manager *ConnectionManager) HandleMessage(userID int, sessionID string, message []byte) error {
  41. if !Allow(userID, QA_LIMITER) {
  42. return errors.New("您提问的太频繁了,请稍后再试")
  43. }
  44. session, exists := manager.GetSession(sessionID)
  45. if !exists {
  46. return errors.New("session not found")
  47. }
  48. // 处理业务逻辑
  49. session.History = append(session.History, string(message))
  50. response := "Processed: " + string(message)
  51. // 更新最后活跃时间
  52. session.LastActive = time.Now()
  53. // 发送响应
  54. return session.Conn.WriteMessage(websocket.TextMessage, []byte(response))
  55. }
  56. // AddSession Add 添加一个新的会话
  57. func (manager *ConnectionManager) AddSession(session *Session) {
  58. manager.Sessions.Store(session.Id, session)
  59. }
  60. func (manager *ConnectionManager) GetSessionId(userId int, sessionId string) (sessionID string) {
  61. return fmt.Sprintf("%d_%s", userId, sessionId)
  62. }
  63. // RemoveSession Remove 移除一个会话
  64. func (manager *ConnectionManager) RemoveSession(sessionCode string) {
  65. manager.Sessions.Delete(sessionCode)
  66. }
  67. // GetSession 获取一个会话
  68. func (manager *ConnectionManager) GetSession(sessionCode string) (session *Session, exists bool) {
  69. if data, ok := manager.Sessions.Load(sessionCode); ok {
  70. session = data.(*Session)
  71. exists = ok
  72. }
  73. return
  74. }
  75. // CheckAll 批量检测所有连接
  76. func (manager *ConnectionManager) CheckAll() {
  77. n := 0
  78. manager.Sessions.Range(func(key, value interface{}) bool {
  79. n++
  80. session := value.(*Session)
  81. // 判断超时
  82. if time.Since(session.LastActive) > connectionTimeout {
  83. fmt.Printf("连接超时关闭: SessionID=%s, UserID=%s", session.Id, session.UserId)
  84. utils.FileLog.Warn("连接超时关闭: SessionID=%s, UserID=%s", session.Id, session.UserId)
  85. session.Close()
  86. return true
  87. }
  88. // 发送心跳
  89. go func(s *Session) {
  90. err := s.Conn.WriteControl(websocket.PingMessage,
  91. nil, time.Now().Add(5*time.Second))
  92. if err != nil {
  93. fmt.Printf("心跳发送失败: SessionID=%s, Error=%v", s.Id, err)
  94. utils.FileLog.Warn("心跳发送失败: SessionID=%s, Error=%v",
  95. s.Id, err)
  96. session.Close()
  97. }
  98. }(session)
  99. fmt.Println("当前连接数:", n)
  100. return true
  101. })
  102. fmt.Println("当前连接数:", n)
  103. }
  104. // Start 启动心跳检测
  105. func (manager *ConnectionManager) Start() {
  106. defer manager.ticker.Stop()
  107. for {
  108. select {
  109. case <-manager.ticker.C:
  110. fmt.Printf("开始检测连接超时")
  111. manager.CheckAll()
  112. case <-manager.stopChan:
  113. fmt.Printf("退出检测")
  114. return
  115. }
  116. }
  117. }
  118. // Stop 停止心跳检测
  119. func (manager *ConnectionManager) Stop() {
  120. close(manager.stopChan)
  121. }