package utils

import (
	"crypto/aes"
	"crypto/cipher"
	"crypto/rand"
	"errors"
	"io"
)

// GenerateAESKey returns a freshly generated 32-byte key, the size that
// selects AES-256 when passed to aes.NewCipher.
//
// The bytes are read from crypto/rand. If the random source cannot supply
// the full key, the read error is returned and the key is nil.
func GenerateAESKey() ([]byte, error) {
	const keyLen = 32

	buf := make([]byte, keyLen)
	if _, err := io.ReadFull(rand.Reader, buf); err != nil {
		return nil, err
	}
	return buf, nil
}

// EncryptWithAES encrypts plaintext with AES in CFB mode under key.
//
// The result is laid out as IV || ciphertext: the first aes.BlockSize bytes
// are a random IV read from crypto/rand, followed by len(plaintext) bytes of
// encrypted data. key must be a length accepted by aes.NewCipher; otherwise
// its error is returned.
//
// NOTE(review): no padding is added here, while DecryptWithAES in this file
// strips trailing padding after decrypting. Callers presumably pad the
// plaintext themselves before calling; confirm, because otherwise an
// encrypt/decrypt round trip will fail.
func EncryptWithAES(key []byte, plaintext []byte) ([]byte, error) {
	blk, err := aes.NewCipher(key)
	if err != nil {
		return nil, err
	}

	// A single allocation holds both the IV prefix and the encrypted body.
	out := make([]byte, aes.BlockSize+len(plaintext))
	iv, body := out[:aes.BlockSize], out[aes.BlockSize:]
	if _, err = io.ReadFull(rand.Reader, iv); err != nil {
		return nil, err
	}

	cipher.NewCFBEncrypter(blk, iv).XORKeyStream(body, plaintext)
	return out, nil
}

// DecryptWithAES decrypts data laid out as IV || AES-CFB ciphertext and
// strips the trailing padding from the recovered plaintext via unpad.
//
// key must be a length accepted by aes.NewCipher. ciphertext must be at
// least aes.BlockSize bytes long, since its first block is the IV. The input
// slice is only read: decryption writes to a newly allocated buffer, so the
// caller's ciphertext is left intact and the returned plaintext does not
// alias it.
//
// An error is returned when the key is invalid, when ciphertext is shorter
// than one block, or when unpad rejects the decrypted bytes.
//
// NOTE(review): CFB carries no integrity check, so tampering is only noticed
// if it happens to corrupt the padding.
func DecryptWithAES(key []byte, ciphertext []byte) ([]byte, error) {
	block, err := aes.NewCipher(key)
	if err != nil {
		return nil, err
	}
	if len(ciphertext) < aes.BlockSize {
		return nil, errors.New("ciphertext too short")
	}
	iv, body := ciphertext[:aes.BlockSize], ciphertext[aes.BlockSize:]

	// Decrypt into a separate buffer rather than in place, so the caller's
	// ciphertext is not overwritten and stays usable (for example for a retry
	// or for logging).
	plaintext := make([]byte, len(body))
	cipher.NewCFBDecrypter(block, iv).XORKeyStream(plaintext, body)

	// Remove the padding.
	unpadded, err := unpad(plaintext)
	if err != nil {
		return nil, err
	}
	return unpadded, nil
}
// unpad strips PKCS#7-style trailing padding from buf: the last byte gives
// the pad length, and that many trailing bytes must all carry the same value.
//
// The returned slice shares buf's backing array. An error is returned when
// buf is empty, when the pad length is zero or exceeds len(buf), or when any
// padding byte differs from the pad length. No upper bound tied to a cipher
// block size is enforced.
func unpad(buf []byte) ([]byte, error) {
	n := len(buf)
	if n == 0 {
		return nil, errors.New("输入缓冲区为空")
	}

	// The final byte encodes how many padding bytes were appended.
	padLen := int(buf[n-1])
	if padLen == 0 || padLen > n {
		return nil, errors.New("无效的填充")
	}

	// Every padding byte must repeat the pad length.
	data, tail := buf[:n-padLen], buf[n-padLen:]
	for _, b := range tail {
		if b != byte(padLen) {
			return nil, errors.New("无效的填充")
		}
	}

	return data, nil
}