package utils

import (
	"fmt"
	"math"
	"sort"
	"strings"

	"github.com/gonum/stat"
	"github.com/shopspring/decimal"
)

// Series is a slice of Coordinate values representing one data series.
type Series []Coordinate

// Coordinate is a single (X, Y) data point of a series.
type Coordinate struct {
	X float64 // first component of the point
	Y float64 // second component of the point
}

// GetLinearResult fits the straight line y = gradient*x + intercept to the
// points in s using ordinary least squares.
//
// It returns (0, 0) whenever the fit is undefined:
//   - s holds fewer than two points, or
//   - the least-squares denominator n*Σx² - (Σx)² is zero, which happens when
//     every point has the same X. Without this guard the division would
//     produce NaN or ±Inf.
func GetLinearResult(s []Coordinate) (gradient, intercept float64) {
	if len(s) <= 1 {
		return 0, 0
	}

	// Accumulate the sums needed by the closed-form least-squares solution.
	var sumX, sumY, sumXX, sumXY float64
	for _, p := range s {
		sumX += p.X
		sumY += p.Y
		sumXX += p.X * p.X
		sumXY += p.X * p.Y
	}

	n := float64(len(s))
	denominator := n*sumXX - sumX*sumX
	if denominator == 0 {
		// All X values coincide: the slope is undefined.
		return 0, 0
	}

	gradient = (n*sumXY - sumX*sumY) / denominator
	intercept = sumY/n - gradient*sumX/n
	return gradient, intercept
}

// CalculateCorrelationByIntArr computes the correlation coefficient of two
// series of equal length.
//
// Steps:
//  1. Compute the means Mx and My of both series.
//  2. Compute the standard deviations SDx and SDy as
//     sqrt(1/(n-1) * SUM[(Xi-Mx)^2]).
//  3. Compute the coefficient as
//     SUM[(Xi-Mx)*(Yi-My)] / [(n-1) * (SDx*SDy)].
//
// It returns 0 when xArr is empty, when the two slices differ in length, or
// when the computed coefficient is NaN.
func CalculateCorrelationByIntArr(xArr, yArr []float64) (ratio float64) {
	// Both series must be non-empty and hold the same number of elements.
	if len(xArr) == 0 || len(xArr) != len(yArr) {
		return 0
	}
	n := float64(len(xArr))

	// Means of both series.
	var totalX, totalY float64
	for i := range xArr {
		totalX += xArr[i]
		totalY += yArr[i]
	}
	meanX, meanY := totalX/n, totalY/n

	// Squared deviations (for the standard deviations) and the cross product
	// of deviations (for the numerator), gathered in a single pass.
	var sqDevX, sqDevY, crossDev float64
	for i := range xArr {
		dx := xArr[i] - meanX
		dy := yArr[i] - meanY
		sqDevX += dx * dx
		sqDevY += dy * dy
		crossDev += dx * dy
	}
	sdX := math.Sqrt(1 / (n - 1) * sqDevX)
	sdY := math.Sqrt(1 / (n - 1) * sqDevY)

	ratio = crossDev / ((n - 1) * (sdX * sdY))
	if math.IsNaN(ratio) {
		return 0
	}
	return ratio
}

// ComputeCorrelation returns the correlation coefficient R of the X and Y
// values in sList.
//
// Steps:
//  1. Compute the means of X and Y with decimal arithmetic, each rounded to
//     4 decimal places.
//  2. Accumulate SUM[(Xi-Mx)*(Yi-My)], SUM[(Xi-Mx)^2] and SUM[(Yi-My)^2].
//  3. Divide the first sum by sqrt(SUM[(Xi-Mx)^2] * SUM[(Yi-My)^2]), where the
//     product is rounded to 4 decimal places before the square root, and
//     round the quotient to 4 decimal places.
//
// Special cases:
//   - fewer than two points: returns 0;
//   - both sums of squared deviations are zero: returns 1;
//   - the square root in step 3 is zero: returns 0, since it cannot be used
//     as a divisor.
func ComputeCorrelation(sList []Coordinate) (r float64) {
	n := len(sList)
	// At least two points are needed to compute a correlation.
	if n < 2 {
		return 0
	}

	// Means of X and Y.
	sumX := decimal.NewFromFloat(0)
	sumY := decimal.NewFromFloat(0)
	for i := range sList {
		sumX = sumX.Add(decimal.NewFromFloat(sList[i].X))
		sumY = sumY.Add(decimal.NewFromFloat(sList[i].Y))
	}
	count := decimal.NewFromInt(int64(n))
	meanX, _ := sumX.Div(count).Round(4).Float64()
	meanY, _ := sumY.Div(count).Round(4).Float64()
	meanXDec := decimal.NewFromFloat(meanX)
	meanYDec := decimal.NewFromFloat(meanY)

	// Sums of squared deviations and of the deviation cross products.
	sqDevX := decimal.NewFromFloat(0)
	sqDevY := decimal.NewFromFloat(0)
	crossDev := decimal.NewFromFloat(0)
	for i := range sList {
		dx := decimal.NewFromFloat(sList[i].X).Sub(meanXDec)
		dy := decimal.NewFromFloat(sList[i].Y).Sub(meanYDec)
		crossDev = crossDev.Add(dx.Mul(dy))
		sqDevX = sqDevX.Add(dx.Mul(dx))
		sqDevY = sqDevY.Add(dy.Mul(dy))
	}

	// With no spread in either X or Y the division below would be 0/0;
	// the coefficient is defined as 1 for that case.
	if sqDevX.IsZero() && sqDevY.IsZero() {
		return 1
	}

	product, _ := sqDevX.Mul(sqDevY).Round(4).Float64()
	denominator := math.Sqrt(product)
	// Zero cannot be used as a divisor.
	if denominator == 0 {
		return 0
	}

	r, _ = crossDev.Div(decimal.NewFromFloat(denominator)).Round(4).Float64()
	return r
}

// CalculationDecisive returns the coefficient of determination R2 for sList:
// the square of the value returned by ComputeCorrelation, rounded to 4
// decimal places.
func CalculationDecisive(sList []Coordinate) (r2 float64) {
	correlation := decimal.NewFromFloat(ComputeCorrelation(sList))
	r2, _ = correlation.Mul(correlation).Round(4).Float64()
	return r2
}

// CalculateStandardDeviation returns the standard deviation of data.
//
// It delegates to gonum's stat.StdDev, passing nil as the weights argument.
// NOTE(review): per the gonum documentation, nil weights mean every sample
// has weight 1 and the result is the sample (n-1) standard deviation —
// confirm against the gonum version in use.
func CalculateStandardDeviation(data []float64) float64 {
	return stat.StdDev(data, nil)
}

// ReplaceFormula returns formulaStr with every tag found in valArr replaced
// by its value, formatted with fmt's %v verb.
//
// Function names listed by getFormulaMap (MAX, MIN, ABS, ...) are temporarily
// swapped for placeholders so that a tag such as "A" does not rewrite the
// letters inside a function name; they are restored before returning.
//
// Tags are substituted longest first (ties broken alphabetically). This keeps
// the result deterministic and prevents a short tag from corrupting a longer
// tag that starts with it, e.g. "A1" must not be replaced inside "A10".
func ReplaceFormula(valArr map[string]float64, formulaStr string) string {
	funMap := getFormulaMap()

	// Mask function names.
	for name, placeholder := range funMap {
		formulaStr = strings.Replace(formulaStr, name, placeholder, -1)
	}

	// Map iteration order is random, so fix the substitution order explicitly.
	tags := make([]string, 0, len(valArr))
	for tag := range valArr {
		tags = append(tags, tag)
	}
	sort.Slice(tags, func(i, j int) bool {
		if len(tags[i]) != len(tags[j]) {
			return len(tags[i]) > len(tags[j])
		}
		return tags[i] < tags[j]
	})
	for _, tag := range tags {
		formulaStr = strings.Replace(formulaStr, tag, fmt.Sprintf("%v", valArr[tag]), -1)
	}

	// Restore function names.
	for name, placeholder := range funMap {
		formulaStr = strings.Replace(formulaStr, placeholder, name, -1)
	}
	return formulaStr
}

// CellPosition
// @Description: position of a cell together with its value.
// ReplaceFormulaByCellList concatenates Tag and Row (Tag first) to build the
// cell reference it searches for in a formula, and substitutes Value for it.
type CellPosition struct {
	Tag   string  // presumably the column label of the cell — confirm against callers
	Row   int     // row number appended to Tag to form the cell reference
	Value float64 // value substituted for the cell reference
}

// ReplaceFormulaByCellList returns formulaStr with every cell reference
// replaced by the cell's value.
//
// A cell reference is the cell's Tag immediately followed by its Row (for
// example Tag "A" and Row 1 give "A1"); the value is formatted with fmt's %v
// verb.
//
// Function names listed by getFormulaMap (MAX, MIN, ABS, ...) are temporarily
// swapped for placeholders so that a reference cannot rewrite part of a
// function name; they are restored before returning.
//
// References are substituted longest first, so that "A1" is never replaced
// inside "A10". References of equal length keep the order of cellList. The
// caller's slice is not modified.
//
// @author: Roc
// @datetime 2023-11-14 16:16:12
// @param cellList []CellPosition
// @param formulaStr string
// @return string
func ReplaceFormulaByCellList(cellList []CellPosition, formulaStr string) string {
	funMap := getFormulaMap()

	// Mask function names.
	for name, placeholder := range funMap {
		formulaStr = strings.Replace(formulaStr, name, placeholder, -1)
	}

	// Build the substitutions separately so cellList itself is left untouched.
	type substitution struct {
		ref   string
		value string
	}
	subs := make([]substitution, 0, len(cellList))
	for _, cell := range cellList {
		subs = append(subs, substitution{
			ref:   fmt.Sprint(cell.Tag, cell.Row),
			value: fmt.Sprintf("%v", cell.Value),
		})
	}
	sort.SliceStable(subs, func(i, j int) bool {
		return len(subs[i].ref) > len(subs[j].ref)
	})
	for _, sub := range subs {
		formulaStr = strings.Replace(formulaStr, sub.ref, sub.value, -1)
	}

	// Restore function names.
	for name, placeholder := range funMap {
		formulaStr = strings.Replace(formulaStr, placeholder, name, -1)
	}
	return formulaStr
}

// ReplaceFormulaByTagMap returns formulaStr with every tag found in valTagMap
// replaced by its integer value, formatted with fmt's %v verb.
//
// Function names listed by getFormulaMap (MAX, MIN, ABS, ...) are temporarily
// swapped for placeholders so that a tag such as "A" does not rewrite the
// letters inside a function name; they are restored before returning.
//
// Tags are substituted longest first (ties broken alphabetically). This keeps
// the result deterministic and prevents a short tag from corrupting a longer
// tag that starts with it, e.g. "A1" must not be replaced inside "A10".
func ReplaceFormulaByTagMap(valTagMap map[string]int, formulaStr string) string {
	funMap := getFormulaMap()

	// Mask function names.
	for name, placeholder := range funMap {
		formulaStr = strings.Replace(formulaStr, name, placeholder, -1)
	}

	// Map iteration order is random, so fix the substitution order explicitly.
	tags := make([]string, 0, len(valTagMap))
	for tag := range valTagMap {
		tags = append(tags, tag)
	}
	sort.Slice(tags, func(i, j int) bool {
		if len(tags[i]) != len(tags[j]) {
			return len(tags[i]) > len(tags[j])
		}
		return tags[i] < tags[j]
	})
	for _, tag := range tags {
		formulaStr = strings.Replace(formulaStr, tag, fmt.Sprintf("%v", valTagMap[tag]), -1)
	}

	// Restore function names.
	for name, placeholder := range funMap {
		formulaStr = strings.Replace(formulaStr, placeholder, name, -1)
	}
	return formulaStr
}

// getFormulaMap returns a freshly allocated map from each supported formula
// function name to the placeholder token that stands in for it while tags are
// being substituted.
func getFormulaMap() map[string]string {
	return map[string]string{
		"MAX":   "[@@]",
		"MIN":   "[@!]",
		"ABS":   "[@#]",
		"CEIL":  "[@$]",
		"COS":   "[@%]",
		"FLOOR": "[@^]",
		"MOD":   "[@&]",
		"POW":   "[@*]",
		"ROUND": "[@(]",
	}
}