// Package utils provides encryption helpers: 3DES and DES (CBC and ECB modes) with base64 or hex encoding.
package utils

import (
	"bytes"
	"crypto/cipher"
	"crypto/des"
	"encoding/base64"
	"encoding/hex"
	"errors"
	"strings"
)

// DesBase64Encrypt encrypts origData with TripleDesEncrypt using the
// package-level key and returns the ciphertext as standard base64 text.
// It panics if encryption fails (for example when key is not 24 bytes long).
func DesBase64Encrypt(origData []byte) []byte {
	sealed, err := TripleDesEncrypt(origData, []byte(key))
	if err != nil {
		panic(err)
	}
	encoded := make([]byte, base64.StdEncoding.EncodedLen(len(sealed)))
	base64.StdEncoding.Encode(encoded, sealed)
	return encoded
}

// DesBase64Decrypt reverses DesBase64Encrypt. It base64-decodes crypted and
// decrypts the result with TripleDesDecrypt using the package-level key.
//
// The input is handled leniently:
//   - A base64 decoding error is ignored, and whatever bytes were decoded are
//     used. NOTE(review): confirm this leniency is intended.
//   - If the decoded length is not a multiple of 8, zero bytes are appended so
//     the block decrypter accepts it. The last block then decrypts to garbage.
//
// It panics if decryption returns an error.
func DesBase64Decrypt(crypted []byte) []byte {
	decoded, _ := base64.StdEncoding.DecodeString(string(crypted)) // error deliberately ignored
	if rem := len(decoded) % 8; rem != 0 {
		decoded = append(decoded, make([]byte, 8-rem)...)
	}
	plain, err := TripleDesDecrypt(decoded, []byte(key))
	if err != nil {
		panic(err)
	}
	return plain
}

// TripleDesEncrypt applies PKCS#5 padding to origData and encrypts it with
// 3DES in CBC mode. key must be 24 bytes long (des.NewTripleDESCipher rejects
// other sizes). Its first 8 bytes also serve as the IV.
//
// NOTE(review): a key-derived IV means identical plaintexts always encrypt to
// identical ciphertexts. TripleDesDecrypt uses the same IV, so changing it
// would break decryption of existing data.
func TripleDesEncrypt(origData, key []byte) ([]byte, error) {
	block, err := des.NewTripleDESCipher(key)
	if err != nil {
		return nil, err
	}
	iv := key[:des.BlockSize]
	padded := PKCS5Padding(origData, block.BlockSize())
	sealed := make([]byte, len(padded))
	cipher.NewCBCEncrypter(block, iv).CryptBlocks(sealed, padded)
	return sealed, nil
}

// TripleDesDecrypt decrypts crypted with 3DES in CBC mode and strips the
// trailing PKCS#5 padding. key must be 24 bytes long, and its first 8 bytes
// are used as the IV.
//
// Malformed input returns an error instead of panicking. This covers
// ciphertext that is empty, ciphertext that is not a whole number of blocks,
// and a final pad byte larger than the decrypted data. The pad bytes
// themselves are not checked further, so results for well-formed input are
// exactly as before.
func TripleDesDecrypt(crypted, key []byte) ([]byte, error) {
	block, err := des.NewTripleDESCipher(key)
	if err != nil {
		return nil, err
	}
	bs := block.BlockSize()
	// CryptBlocks panics on partial blocks. An empty input would make the
	// padding lookup below index out of range.
	if len(crypted) == 0 || len(crypted)%bs != 0 {
		return nil, errors.New("3des decrypt: ciphertext is not a positive multiple of the block size")
	}
	origData := make([]byte, len(crypted))
	cipher.NewCBCDecrypter(block, key[:8]).CryptBlocks(origData, crypted)

	// PKCS#5: the last byte holds the number of padding bytes to drop.
	unpadding := int(origData[len(origData)-1])
	if unpadding > len(origData) {
		return nil, errors.New("3des decrypt: invalid padding")
	}
	return origData[:len(origData)-unpadding], nil
}

// ZeroPadding returns ciphertext extended with zero bytes up to the next
// multiple of blockSize. A full block of zeros is appended when the length is
// already a multiple of blockSize, so at least one byte is always added.
//
// The result never shares storage with ciphertext. Capping the input's
// capacity forces append to allocate a new array instead of writing zeros into
// spare capacity that the caller may still be using.
func ZeroPadding(ciphertext []byte, blockSize int) []byte {
	padding := blockSize - len(ciphertext)%blockSize
	padtext := bytes.Repeat([]byte{0}, padding)
	return append(ciphertext[:len(ciphertext):len(ciphertext)], padtext...)
}

// ZeroUnPadding drops as many trailing bytes as the value of the final byte
// of origData.
//
// NOTE(review): despite its name, this is PKCS#5-style unpadding and not the
// inverse of ZeroPadding. For zero-padded data the final byte is 0, so nothing
// is removed. Callers may depend on this behavior, so it is left as-is.
// It panics if origData is empty or if the final byte exceeds len(origData).
func ZeroUnPadding(origData []byte) []byte {
	n := len(origData)
	return origData[:n-int(origData[n-1])]
}

// PKCS5Padding appends PKCS#5/PKCS#7 padding to ciphertext. It adds n bytes,
// each of value n, where n (1..blockSize) brings the length to a multiple of
// blockSize. An already-aligned input gets a full extra block.
//
// The result never shares storage with ciphertext. Capping the input's
// capacity forces append to allocate a new array instead of writing the
// padding into spare capacity that the caller may still be using.
func PKCS5Padding(ciphertext []byte, blockSize int) []byte {
	padding := blockSize - len(ciphertext)%blockSize
	padtext := bytes.Repeat([]byte{byte(padding)}, padding)
	return append(ciphertext[:len(ciphertext):len(ciphertext)], padtext...)
}

// PKCS5UnPadding strips PKCS#5 padding. It reads the last byte as a count and
// drops that many bytes from the end. The individual padding bytes are not
// verified.
// It panics if origData is empty or if the count exceeds len(origData).
func PKCS5UnPadding(origData []byte) []byte {
	last := origData[len(origData)-1]
	keep := len(origData) - int(last)
	return origData[:keep]
}

// DesEncrypt applies PKCS#5 padding to content, encrypts it with single DES
// in CBC mode, and returns the ciphertext as upper-case hex. key must be
// exactly 8 bytes and is also used as the IV. It returns "" if the cipher
// cannot be created.
//
// NOTE(review): single DES and a key-derived IV are both weak. They are
// presumably kept for compatibility with an existing peer; confirm before
// reusing this for new data.
func DesEncrypt(content string, key string) string {
	keyBytes := []byte(key)
	block, err := des.NewCipher(keyBytes)
	if err != nil {
		return ""
	}
	plain := PKCS5Padding([]byte(content), block.BlockSize())
	sealed := make([]byte, len(plain))
	cipher.NewCBCEncrypter(block, keyBytes).CryptBlocks(sealed, plain)
	return byteToHexString(sealed)
}

// byteToHexString returns the upper-case hexadecimal encoding of data, always
// two characters per byte (e.g. []byte{0x0a, 0xff} -> "0AFF").
//
// This replaces a per-byte string concatenation loop, which was quadratic and
// contained a dead branch that would have inserted a NUL rune rather than '0'.
func byteToHexString(data []byte) string {
	return strings.ToUpper(hex.EncodeToString(data))
}

// DesDecrypt reverses DesEncrypt. It hex-decodes content, decrypts it with
// single DES in CBC mode, and drops as many trailing bytes as the value of the
// last plaintext byte (PKCS#5-style). key must be exactly 8 bytes and is also
// used as the IV.
//
// Failures are returned in-band as a prefixed message string rather than as an
// error. Hex-decoding failures use one prefix; all decryption failures use
// another. Input that would otherwise panic also gets the decryption-failure
// message: empty input, input that is not a whole number of blocks, or a pad
// byte larger than the data.
func DesDecrypt(content string, key string) string {
	contentBytes, err := hex.DecodeString(content)
	if err != nil {
		return "字符串转换16进制数组失败" + err.Error()
	}
	keys := []byte(key)
	block, err := des.NewCipher(keys)
	if err != nil {
		return "解密失败" + err.Error()
	}
	// CryptBlocks panics on partial blocks. An empty input would make the
	// padding lookup below index out of range.
	if len(contentBytes) == 0 || len(contentBytes)%block.BlockSize() != 0 {
		return "解密失败" + "ciphertext is not a positive multiple of the block size"
	}
	// Decrypt in place: CryptBlocks allows dst and src to overlap exactly.
	origData := contentBytes
	cipher.NewCBCDecrypter(block, keys).CryptBlocks(origData, contentBytes)

	// The last byte holds the number of padding bytes to drop, matching the
	// PKCS#5 padding applied by DesEncrypt.
	unpadding := int(origData[len(origData)-1])
	if unpadding > len(origData) {
		return "解密失败" + "invalid padding"
	}
	return string(origData[:len(origData)-unpadding])
}

// EntryptDesECB applies PKCS#5 padding to data, encrypts it with single DES in
// ECB mode, and returns the result as standard base64. Only the first 8 bytes
// of key are used. A shorter key makes des.NewCipher fail. (The misspelled
// name is kept for existing callers.)
//
// NOTE(review): ECB encrypts identical plaintext blocks to identical
// ciphertext blocks and so leaks patterns. Avoid it for new data.
func EntryptDesECB(data, key []byte) (string, error) {
	if len(key) > 8 {
		key = key[:8]
	}
	block, err := des.NewCipher(key)
	if err != nil {
		return "", errors.New("des.NewCipher " + err.Error())
	}
	bs := block.BlockSize()
	// PKCS5Padding always returns a whole number of blocks, so no partial-block
	// check is needed before the loop.
	data = PKCS5Padding(data, bs)
	out := make([]byte, len(data))
	// ECB: every block is encrypted independently.
	for src, dst := data, out; len(src) > 0; src, dst = src[bs:], dst[bs:] {
		block.Encrypt(dst, src[:bs])
	}
	return base64.StdEncoding.EncodeToString(out), nil
}

// DecryptDESECB reverses EntryptDesECB. It base64-decodes d, decrypts each
// 8-byte block independently with single DES (ECB mode), and strips the
// trailing PKCS#5 padding. Only the first 8 bytes of key are used.
//
// It returns an error for any of the following:
//   - invalid base64
//   - a key des.NewCipher rejects
//   - empty input
//   - a length that is not a whole number of blocks
//   - a final pad byte larger than the decrypted data
//
// The empty-input and pad-byte cases previously panicked. The pad bytes are
// not checked further, so results for well-formed input are unchanged.
func DecryptDESECB(d string, key []byte) ([]byte, error) {
	data, err := base64.StdEncoding.DecodeString(d)
	if err != nil {
		return nil, errors.New("decodebase64 " + err.Error())
	}
	if len(key) > 8 {
		key = key[:8]
	}
	block, err := des.NewCipher(key)
	if err != nil {
		return nil, errors.New("des.NewCipher " + err.Error())
	}
	bs := block.BlockSize()
	if len(data)%bs != 0 {
		return nil, errors.New("DecryptDES crypto/cipher: input not full blocks")
	}
	if len(data) == 0 {
		return nil, errors.New("DecryptDES: empty input")
	}
	out := make([]byte, len(data))
	for src, dst := data, out; len(src) > 0; src, dst = src[bs:], dst[bs:] {
		block.Decrypt(dst, src[:bs])
	}
	// PKCS#5: the last byte holds the number of padding bytes to drop.
	unpadding := int(out[len(out)-1])
	if unpadding > len(out) {
		return nil, errors.New("DecryptDES: invalid padding")
	}
	return out[:len(out)-unpadding], nil
}