package ws

import (
	"encoding/json"
	"errors"
	chatService "eta/eta_api/services/llm"
	"eta/eta_api/utils"
	"eta/eta_api/utils/llm"
	"eta/eta_api/utils/llm/eta_llm"
	"eta/eta_api/utils/llm/eta_llm/eta_llm_http"
	"fmt"
	"github.com/gorilla/websocket"
	"net/http"
	"strings"
	"sync"
	"time"
)

var (
	// llmService is the shared ETA LLM client that HandleMessage uses for
	// knowledge-base chat.
	// NOTE(review): the error returned by llm.GetInstance is discarded. If it can
	// fail, llmService may be nil/unusable and HandleMessage would fail (or panic)
	// at call time — confirm.
	llmService, _ = llm.GetInstance(llm.ETA_LLM_CLIENT)
)

const (
	defaultCheckInterval = 2 * time.Minute  // heartbeat sweep interval (drives CheckAll); should be shorter than the heartbeat timeout
	connectionTimeout    = 10 * time.Minute // client idle timeout; CheckAll closes sessions idle for more than 2x this value
	TcpTimeout           = 20 * time.Minute // TCP timeout: last-resort close, longer than the session timeout (not referenced in this chunk)
	ReadTimeout          = 15 * time.Minute // read timeout: last-resort close, longer than the session timeout (not referenced in this chunk)
	writeWaitTimeout     = 60 * time.Second // write deadline; used by CheckAll for ping control frames
)

// ConnectionManager keeps track of live WebSocket sessions and drives the
// periodic heartbeat sweep (see Start / CheckAll / Stop).
type ConnectionManager struct {
	Sessions sync.Map      // session key (Session.Id, as stored by AddSession) -> *Session
	ticker   *time.Ticker  // fires every defaultCheckInterval; consumed by Start
	stopChan chan struct{} // closed by Stop to make Start return
}

var (
	smOnce  sync.Once          // guards the one-time construction in GetInstance
	manager *ConnectionManager // process-wide singleton, populated by GetInstance
)

// GetInstance returns the process-wide ConnectionManager, building it on the
// first call. Construction runs at most once (guarded by smOnce), so concurrent
// first callers all receive the same instance. The heartbeat ticker is created
// here; it is only consumed once Start is running.
func GetInstance() *ConnectionManager {
	smOnce.Do(func() {
		// Respect an instance that was already installed by other means.
		if manager != nil {
			return
		}
		manager = &ConnectionManager{
			ticker:   time.NewTicker(defaultCheckInterval),
			stopChan: make(chan struct{}),
		}
	})
	return manager
}
// Manager returns the shared ConnectionManager without creating it. It returns
// nil if the singleton has not been populated yet (see GetInstance).
func Manager() *ConnectionManager {
	current := manager
	return current
}

// HandleMessage 消息处理核心逻辑
// HandleMessage processes one text frame received from user userID on the
// session stored under sessionID.
//
//   - "pong" (any letter case) is a heartbeat reply: it only refreshes the
//     session's activity timestamp.
//   - If the per-user QA limiter (Allow with QA_LIMITER) rejects the request, a
//     "too frequent" notice framed by "<think></think>" and "<EOF/>" is sent.
//   - Otherwise the frame is decoded as a Message and the chat history for its
//     ChatId is loaded from Redis. The knowledge-base answer is then streamed to
//     the WebSocket chunk by chunk and terminated with "<EOF/>", preceded by
//     "<think></think>" when no chunk was streamed.
//
// It returns an error when:
//   - the session is unknown;
//   - the frame is not valid JSON;
//   - the LLM call fails or answers with a non-200 status;
//   - the stream reports an error;
//   - a streamed chunk cannot be written to the WebSocket.
func (manager *ConnectionManager) HandleMessage(userID int, sessionID string, message []byte) error {
	session, exists := manager.GetSession(sessionID)
	if !exists {
		return errors.New("session not found")
	}

	// Heartbeat reply from the client: renew the long-lived connection and stop.
	if strings.EqualFold(string(message), "pong") {
		session.UpdateActivity()
		fmt.Printf("收到心跳消息,续期长连接:%v\n", session.LastActive)
		return nil
	}

	// Rate limited: still emit the full frame sequence so the client can close the turn.
	if !Allow(userID, QA_LIMITER) {
		_ = session.Conn.WriteMessage(websocket.TextMessage, []byte("<think></think>"))
		_ = session.Conn.WriteMessage(websocket.TextMessage, []byte("您提问的太频繁了,请稍后再试"))
		_ = session.Conn.WriteMessage(websocket.TextMessage, []byte("<EOF/>"))
		return nil
	}

	var userMessage Message
	if err := json.Unmarshal(message, &userMessage); err != nil {
		utils.FileLog.Error(fmt.Sprintf("消息格式错误:%s", string(message)))
		fmt.Printf("消息格式错误:%s\n", string(message))
		return fmt.Errorf("消息格式错误:%w", err)
	}

	// Rebuild the history from Redis on every question. Appending to the previous
	// session.History would duplicate every earlier turn on each new message, and
	// would mix histories if the client switches ChatId on the same session.
	// On a Redis failure, the existing history is kept as a best effort.
	redisHisChat, err := chatService.GetChatRecordsFromRedis(userMessage.ChatId)
	if err != nil {
		utils.FileLog.Error("获取历史对话数据失败,err:%s", err.Error())
	} else {
		session.History = nil
		for _, chat := range redisHisChat {
			his := eta_llm_http.HistoryContent{
				Content: chat.Content,
				Role:    chat.ChatUserType,
			}
			hisMsg, marshalErr := json.Marshal(&his)
			if marshalErr != nil || len(hisMsg) == 0 {
				continue
			}
			session.History = append(session.History, hisMsg)
		}
	}

	resp, err := llmService.KnowledgeBaseChat(userMessage.Query, userMessage.KbName, session.History)
	// Release the body on every exit path: a failed call that still returned a
	// response, a non-200 status, stream errors and write failures included.
	if resp != nil && resp.Body != nil {
		defer func() { _ = resp.Body.Close() }()
	}
	// Check err before resp so the underlying cause is not replaced by "no response".
	if err != nil {
		if resp != nil {
			msg := fmt.Sprintf("知识库问答失败: httpCode:%d,错误信息:%s", resp.StatusCode, http.StatusText(resp.StatusCode))
			utils.FileLog.Error(msg)
			return fmt.Errorf("%s: %w", msg, err)
		}
		utils.FileLog.Error(fmt.Sprintf("知识库问答失败: %v", err))
		return fmt.Errorf("知识库问答失败: %w", err)
	}
	if resp == nil {
		utils.FileLog.Error("知识库问答失败: 无应答")
		return errors.New("知识库问答失败: 无应答")
	}
	if resp.StatusCode != http.StatusOK {
		msg := fmt.Sprintf("知识库问答失败: httpCode:%d,错误信息:%s", resp.StatusCode, http.StatusText(resp.StatusCode))
		utils.FileLog.Error(msg)
		return errors.New(msg)
	}

	// Parse the streaming response and relay it to the WebSocket.
	// NOTE(review): if ParseStreamResponse's producer blocks on unbuffered sends,
	// returning early may leave it blocked; closing the body (deferred above) is
	// the only signal it gets here — confirm.
	contentChan, errChan, closeChan := eta_llm.ParseStreamResponse(resp)
	emptyContent := true
	for {
		select {
		case content, ok := <-contentChan:
			if !ok {
				return errors.New("未知的错误异常")
			}
			session.UpdateActivity()
			emptyContent = false
			// A failed write means the client is gone; stop consuming the LLM stream.
			if writeErr := session.Conn.WriteMessage(websocket.TextMessage, []byte(content)); writeErr != nil {
				return fmt.Errorf("发送消息失败: %w", writeErr)
			}
		case chanErr, ok := <-errChan:
			if !ok || chanErr == nil {
				return errors.New("未知的错误异常")
			}
			return chanErr
		case <-closeChan:
			// Nothing was streamed: send an empty think block so the client's framing stays valid.
			if emptyContent {
				_ = session.Conn.WriteMessage(websocket.TextMessage, []byte("<think></think>"))
			}
			_ = session.Conn.WriteMessage(websocket.TextMessage, []byte("<EOF/>"))
			return nil
		}
	}
}

// AddSession Add 添加一个新的会话
// AddSession registers session in the manager under its Id. Any session
// already stored under the same Id is replaced.
func (manager *ConnectionManager) AddSession(session *Session) {
	key := session.Id
	manager.Sessions.Store(key, session)
}
// GetSessionId builds the composite session key "<userId>_<sessionId>".
// For example, user 42 with session "abc" yields "42_abc".
func (manager *ConnectionManager) GetSessionId(userId int, sessionId string) (sessionID string) {
	sessionID = fmt.Sprint(userId) + "_" + sessionId
	return sessionID
}

// RemoveSession Remove 移除一个会话
// RemoveSession deletes the session stored under sessionCode from the manager.
// It does not close the session's connection, and removing an unknown key is
// a no-op.
func (manager *ConnectionManager) RemoveSession(sessionCode string) {
	// Only the session key is known here. The old message printed that same key
	// a second time, labelled as the user id, which was misleading.
	fmt.Printf("移除会话: SessionID=%s\n", sessionCode)
	manager.Sessions.Delete(sessionCode)
}

// GetSession 获取一个会话
// GetSession looks up the session stored under sessionCode. When nothing is
// stored there, it returns (nil, false). Stored values are expected to be
// *Session; any other type panics on the type assertion.
func (manager *ConnectionManager) GetSession(sessionCode string) (session *Session, exists bool) {
	data, found := manager.Sessions.Load(sessionCode)
	if !found {
		return nil, false
	}
	return data.(*Session), true
}

// CheckAll 批量检测所有连接
// CheckAll sweeps every registered session once:
//   - Sessions idle for longer than 2*connectionTimeout are closed.
//   - Every other session is sent a WebSocket ping control frame from its own
//     goroutine, with a writeWaitTimeout deadline. A failed ping closes that
//     session.
//
// NOTE(review): session.LastActive is read here while HandleMessage may update
// it via UpdateActivity — confirm that Session synchronizes that field.
func (manager *ConnectionManager) CheckAll() {
	manager.Sessions.Range(func(_, value interface{}) bool {
		session := value.(*Session)
		// Idle timeout: close and move on to the next session.
		if time.Since(session.LastActive) > 2*connectionTimeout {
			// %v: the user id is not necessarily a string (GetSessionId takes it as an int).
			fmt.Printf("连接超时关闭: SessionID=%s, UserID=%v\n", session.Id, session.UserId)
			utils.FileLog.Warn("连接超时关闭: SessionID=%s, UserID=%v", session.Id, session.UserId)
			session.Close()
			return true
		}
		// Send the heartbeat without blocking the sweep on a slow peer.
		go func(s *Session) {
			err := s.Conn.WriteControl(websocket.PingMessage, nil, time.Now().Add(writeWaitTimeout))
			if err == nil {
				return
			}
			fmt.Printf("心跳发送失败: SessionID=%s, Error=%v\n", s.Id, err)
			utils.FileLog.Warn("心跳发送失败: SessionID=%s, Error=%v", s.Id, err)
			fmt.Println("心跳无响应,退出请求")
			s.Close()
		}(session)
		return true
	})
}

// Start 启动心跳检测
// Start runs the heartbeat loop: on every tick of the manager's ticker it calls
// CheckAll. It blocks until Stop closes stopChan, and it stops the ticker on
// the way out.
func (manager *ConnectionManager) Start() {
	tick := manager.ticker
	defer tick.Stop()
	for {
		select {
		case <-manager.stopChan:
			return
		case <-manager.ticker.C:
			manager.CheckAll()
		}
	}
}

// Stop 停止心跳检测
// Stop signals Start to return by closing stopChan. Calling Stop again after
// the channel has been closed is a no-op rather than a "close of closed
// channel" panic.
// NOTE: the check-then-close is not atomic, so truly concurrent Stop calls can
// still race; callers should not invoke Stop from several goroutines at once.
func (manager *ConnectionManager) Stop() {
	select {
	case <-manager.stopChan:
		// Already closed: nothing to do.
	default:
		close(manager.stopChan)
	}
}