package ws

import (
	"errors"
	"github.com/gorilla/websocket"
	"sync"
	"time"
)

const (
	// maxMessageSize is 10 MiB (10 * 1024 * 1024 bytes).
	// NOTE(review): the previous comment claimed "1MB" while the value is 10 MiB;
	// confirm which limit is actually intended before relying on it.
	maxMessageSize = 1024 * 1024 * 10
	// basePingInterval is the initial latency estimate a new LatencyMeasurer
	// holds before any round trip has been measured.
	basePingInterval = 5 * time.Second
	// maxPingInterval is the upper clamp applied to the averaged latency estimate.
	maxPingInterval = 120 * time.Second
	// minPingInterval is the lower clamp applied to the averaged latency estimate.
	// NOTE(review): it is larger than basePingInterval, so the initial estimate
	// lies below the range that every later estimate is clamped into.
	minPingInterval = 15 * time.Second
)

// LatencyMeasurer estimates WebSocket round-trip latency from ping/pong
// exchanges, keeping a sliding window of recent samples.
//
// NOTE(review): the estimate is clamped to at least minPingInterval (15s), so
// it reads more like a ping-interval hint than a raw RTT; confirm the
// intended use with callers.
type LatencyMeasurer struct {
	// measurements holds the most recent round-trip samples, oldest first.
	measurements    []time.Duration
	// lastLatency is the current estimate reported by GetLatency. It starts at
	// basePingInterval (see NewLatencyMeasurer). After each sample it is the
	// mean of measurements, clamped to [minPingInterval, maxPingInterval].
	lastLatency     time.Duration
	// mu guards every other field of the struct.
	mu              sync.Mutex
	// lastPingTime is when the last ping was sent successfully. CalculateLatency
	// treats the zero value as "nothing to measure" and skips the sample.
	lastPingTime    time.Time
	// maxMeasurements is the sliding-window size: the maximum number of samples
	// kept in measurements.
	maxMeasurements int
}

// NewLatencyMeasurer returns a LatencyMeasurer whose estimate is the average
// of the most recent windowSize round-trip samples.
//
// A windowSize below 1 is treated as 1. A negative size would make the
// capacity passed to make invalid (a panic). A zero size would make the
// sliding-window update in CalculateLatency slice an empty window (also a
// panic).
//
// The latency estimate starts at basePingInterval until the first sample is
// recorded.
func NewLatencyMeasurer(windowSize int) *LatencyMeasurer {
	if windowSize < 1 {
		windowSize = 1
	}
	return &LatencyMeasurer{
		maxMeasurements: windowSize,
		measurements:    make([]time.Duration, 0, windowSize),
		lastLatency:     basePingInterval,
	}
}

// SendPing writes a WebSocket ping control frame on conn and records the send
// time. A later CalculateLatency call (typically from the pong handler) can
// then compute the round-trip time.
//
// It returns an error if conn is nil or if writing the control frame fails.
// On failure, the previously recorded ping time is restored. This leaves the
// measurer in the same state as before the call.
func (lm *LatencyMeasurer) SendPing(conn *websocket.Conn) error {
	if conn == nil {
		return errors.New("connection closed")
	}

	// Record the send time before writing: the peer cannot answer a ping it has
	// not received, so any pong for this ping observes this timestamp.
	// The lock is released before the network write. A slow write (bounded by
	// writeWaitTimeout, defined elsewhere in this package) therefore does not
	// block the pong handler, which gorilla/websocket runs from the connection's
	// read methods, or GetLatency.
	sent := time.Now()
	lm.mu.Lock()
	prev := lm.lastPingTime
	lm.lastPingTime = sent
	lm.mu.Unlock()

	if err := conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(writeWaitTimeout)); err != nil {
		lm.mu.Lock()
		// Roll back only if nothing else (a concurrent SendPing or a pong
		// consuming the timestamp) has touched it since we stored it.
		if lm.lastPingTime.Equal(sent) {
			lm.lastPingTime = prev
		}
		lm.mu.Unlock()
		return err
	}
	return nil
}

// CalculateLatency records one round-trip sample and refreshes the latency
// estimate. It is meant to be called when a pong arrives;
// SetupLatencyMeasurement installs it as the connection's pong handler.
//
// The sample is the time elapsed since the ping recorded by SendPing. The
// first pong that measures a recorded ping consumes it. Duplicate pongs,
// unsolicited pongs, and pongs arriving before any ping are therefore ignored
// instead of producing inflated samples from a stale start time.
//
// The estimate is the arithmetic mean of the most recent maxMeasurements
// samples, clamped to [minPingInterval, maxPingInterval].
func (lm *LatencyMeasurer) CalculateLatency() {
	lm.mu.Lock()
	defer lm.mu.Unlock()

	if lm.lastPingTime.IsZero() {
		return // no outstanding ping to measure against
	}
	rtt := time.Since(lm.lastPingTime)
	// Consume the ping so that it cannot be measured a second time.
	lm.lastPingTime = time.Time{}

	// Treat a non-positive window (e.g. a zero-value LatencyMeasurer) as a
	// window of one; otherwise trimming an empty window would panic.
	window := lm.maxMeasurements
	if window < 1 {
		window = 1
	}
	// Drop the oldest samples so that at most window remain after appending.
	// Shifting in place reuses the backing array instead of repeatedly slicing
	// its head off and forcing periodic reallocation.
	if excess := len(lm.measurements) - window + 1; excess > 0 {
		lm.measurements = append(lm.measurements[:0], lm.measurements[excess:]...)
	}
	lm.measurements = append(lm.measurements, rtt)

	// Arithmetic mean of the window (could be swapped for a median if needed).
	var sum time.Duration
	for _, d := range lm.measurements {
		sum += d
	}
	avg := sum / time.Duration(len(lm.measurements))

	switch {
	case avg > maxPingInterval:
		avg = maxPingInterval
	case avg < minPingInterval:
		avg = minPingInterval
	}
	lm.lastLatency = avg
}

// GetLatency reports the current latency estimate. Before any round trip has
// been measured, this is the value set at construction. Afterwards it is the
// clamped window average maintained by CalculateLatency. Access is guarded by
// the measurer's mutex, so it may be called from any goroutine.
func (lm *LatencyMeasurer) GetLatency() time.Duration {
	lm.mu.Lock()
	estimate := lm.lastLatency
	lm.mu.Unlock()
	return estimate
}

// SetupLatencyMeasurement is intended to be called while initializing a
// connection. It creates a LatencyMeasurer with a five-sample sliding window
// and registers a pong handler on conn that feeds every pong into
// CalculateLatency. The pong payload is ignored.
//
// NOTE(review): SetPongHandler replaces any pong handler previously set on
// conn (for example one that extends the read deadline); confirm no other
// handler is expected on this connection.
func SetupLatencyMeasurement(conn *websocket.Conn) *LatencyMeasurer {
	const windowSize = 5 // average over the five most recent round trips

	measurer := NewLatencyMeasurer(windowSize)
	onPong := func(string) error {
		measurer.CalculateLatency()
		return nil
	}
	conn.SetPongHandler(onPong)
	return measurer
}