Quellcode durchsuchen

新增接口签名校验

tuoling805 vor 1 Jahr
Ursprung
Commit
aa042354a0
6 geänderte Dateien mit 108 neuen und 211 gelöschten Zeilen
  1. 41 135
      controllers/base_auth.go
  2. 0 58
      controllers/resource.go
  3. 6 9
      models/base.go
  4. 1 8
      routers/router.go
  5. 37 0
      utils/common.go
  6. 23 1
      utils/config.go

+ 41 - 135
controllers/base_auth.go

@@ -6,11 +6,9 @@ import (
 	"eta/eta_hub/utils"
 	"fmt"
 	"github.com/beego/beego/v2/server/web"
-	"github.com/shopspring/decimal"
 	"github.com/sirupsen/logrus"
 	"net/http"
 	"net/url"
-	"reflect"
 )
 
 type BaseAuthController struct {
@@ -23,105 +21,54 @@ func (this *BaseAuthController) Prepare() {
 	uri := this.Ctx.Input.URI()
 	fmt.Println("Url:", uri)
 	if method != "HEAD" {
-		//if method == "POST" {
-		//	ok, errMsg := checkSign(this)
-		//	if !ok {
-		//		this.JSON(models.BaseResponse{Ret: 408, Msg: "签名错误!", ErrMsg: errMsg}, false, false)
-		//		this.StopRun()
-		//		return
-		//	}
-		//} else {
-		//	this.JSON(models.BaseResponse{Ret: 408, Msg: "请求异常,请联系客服!", ErrMsg: "POST之外的请求,暂不支持"}, false, false)
-		//	this.StopRun()
-		//	return
-		//}
+		//校验签名
+		nonce := this.Ctx.Input.Header("nonce")
+		timestamp := this.Ctx.Input.Header("timestamp")
+		appid := this.Ctx.Input.Header("appid")
+		signature := this.Ctx.Input.Header("signature")
+
+		if nonce == "" {
+			errMsg := "随机字符串不能为空"
+			this.JSON(models.BaseResponse{Ret: 400, Msg: "", ErrMsg: errMsg}, false, false)
+			this.StopRun()
+			return
+		}
+
+		if timestamp == "" {
+			errMsg := "时间戳不能为空"
+			this.JSON(models.BaseResponse{Ret: 400, Msg: "", ErrMsg: errMsg}, false, false)
+			this.StopRun()
+			return
+		}
+
+		if appid != utils.AppId {
+			errMsg := "商家AppId错误,请核查"
+			this.JSON(models.BaseResponse{Ret: 400, Msg: "", ErrMsg: errMsg}, false, false)
+			this.StopRun()
+			return
+		}
+
+		checkSign := utils.GetSign(nonce, timestamp)
+		if signature != checkSign {
+			errMsg := "签名错误"
+			this.JSON(models.BaseResponse{Ret: 401, Msg: "", ErrMsg: errMsg}, false, false)
+			this.StopRun()
+			return
+		}
+		if method != "GET" && method != "POST" {
+			errMsg := "无效的请求方式"
+			this.JSON(models.BaseResponse{Ret: 501, Msg: "", ErrMsg: errMsg}, false, false)
+			this.StopRun()
+			return
+		}
 	} else {
-		this.JSON(models.BaseResponse{Ret: 408, Msg: "请求异常,请联系客服!", ErrMsg: "method:" + method}, false, false)
+		this.JSON(models.BaseResponse{Ret: 500, Msg: "系统异常,请联系客服!", ErrMsg: "method:" + method}, false, false)
 		this.StopRun()
 		return
 	}
 }
 
-//func checkSign(c *BaseAuthController) (ok bool, errMsg string) {
-//	method := c.Ctx.Input.Method()
-//	signData := make(map[string]string)
-//
-//	switch method {
-//	case "GET":
-//		//requestBody = c.Ctx.Request.RequestURI
-//		params := c.Ctx.Request.URL.Query()
-//		signData = convertParam(params)
-//	case "POST":
-//		//requestBody, _ = url.QueryUnescape(string(c.Ctx.Input.RequestBody))
-//
-//		//请求类型
-//		contentType := c.Ctx.Request.Header.Get("content-type")
-//		//fmt.Println("contentType:", contentType)
-//		//fmt.Println("c.Ctx.Input.RequestBody:", string(c.Ctx.Input.RequestBody))
-//
-//		switch contentType {
-//		case "multipart/form-data":
-//			//文件最大5M
-//			err := c.Ctx.Request.ParseMultipartForm(-int64(5 << 20))
-//			if err != nil {
-//				errMsg = fmt.Sprintf("获取参数失败,%v", err)
-//				return
-//			}
-//			params := c.Ctx.Request.Form
-//			signData = convertParam(params)
-//		case "application/x-www-form-urlencoded":
-//			err := c.Ctx.Request.ParseForm()
-//			if err != nil {
-//				errMsg = fmt.Sprintf("获取参数失败,%v", err)
-//				return
-//			}
-//			params := c.Ctx.Request.Form
-//			signData = convertParam(params)
-//		case "application/json":
-//			//var v interface{}
-//			params := make(map[string]interface{})
-//			err := json.Unmarshal(c.Ctx.Input.RequestBody, &params)
-//			if err != nil {
-//				errMsg = fmt.Sprintf("获取参数失败,%v", err)
-//				return
-//			}
-//			//fmt.Println("params:", params)
-//
-//			signData = convertParamInterface(params)
-//			//tmpV := v.(map[string]string)
-//			//fmt.Println("tmpV:", tmpV)
-//			//fmt.Sprintln("list type is v%", tmpV["list"])
-//		default: //正常应该是其他方式获取解析的,暂时这么处理吧
-//			err := c.Ctx.Request.ParseForm()
-//			if err != nil {
-//				errMsg = fmt.Sprintf("获取参数失败,%v", err)
-//				return
-//			}
-//			params := c.Ctx.Request.Form
-//			signData = convertParam(params)
-//		}
-//	}
-//
-//	// 开始校验数据
-//	ip := c.Ctx.Input.IP()
-//	err := checkSignData(signData, ip)
-//	if err != nil {
-//		errMsg = fmt.Sprintf("签名校验失败,%v", err)
-//		return
-//	}
-//
-//	ok = true
-//	return
-//}
-
 func (c *BaseAuthController) ServeJSON(encoding ...bool) {
-	// 方法处理完后,需要后置处理的业务逻辑
-	//if handlerList, ok := AfterHandlerUrlMap[c.Ctx.Request.URL.Path]; ok {
-	//	for _, handler := range handlerList {
-	//		handler(c.Ctx.Input.RequestBody)
-	//	}
-	//}
-
 	//所有请求都做这么个处理吧,目前这边都是做编辑、刷新逻辑处理(新增的话,并没有指标id,不会有影响)
 	var (
 		hasIndent   = false
@@ -183,47 +130,6 @@ func (c *BaseAuthController) JSON(data interface{}, hasIndent bool, coding bool)
 	return c.Ctx.Output.Body(content)
 }
 
-// 将请求传入的数据格式转换成签名需要的格式
-func convertParam(params map[string][]string) (signData map[string]string) {
-	signData = make(map[string]string)
-	for key := range params {
-		signData[key] = params[key][0]
-	}
-	return signData
-}
-
-// 将请求传入的数据格式转换成签名需要的格式(目前只能处理简单的类型,数组、对象暂不支持)
-func convertParamInterface(params map[string]interface{}) (signData map[string]string) {
-	signData = make(map[string]string)
-	for key := range params {
-		val := ``
-		//fmt.Println("key", key, ";val:", params[key], ";type:", reflect.TypeOf(params[key]))
-		//signData[key] = params[key][0]
-		tmpVal := params[key]
-		switch reflect.TypeOf(tmpVal).Kind() {
-		case reflect.String:
-			val = fmt.Sprint(tmpVal)
-		case reflect.Int, reflect.Int16, reflect.Int64, reflect.Int32, reflect.Int8:
-			val = fmt.Sprint(tmpVal)
-		case reflect.Uint, reflect.Uint32, reflect.Uint16, reflect.Uint8, reflect.Uint64:
-			val = fmt.Sprint(tmpVal)
-		case reflect.Bool:
-			val = fmt.Sprint(tmpVal)
-		case reflect.Float64:
-			decimalNum := decimal.NewFromFloat(tmpVal.(float64))
-			val = decimalNum.String()
-			//val = strconv.FormatFloat(tmpVal.(float64), 'E', -1, 64) //float64
-		case reflect.Float32:
-			decimalNum := decimal.NewFromFloat32(tmpVal.(float32))
-			val = decimalNum.String()
-		}
-		signData[key] = val
-	}
-	return signData
-}
-
-
-
 func (c *BaseAuthController) logUri(data interface{}, requestBody, ip string) {
 	var reqData interface{}
 	err := json.Unmarshal([]byte(requestBody), &reqData)

+ 0 - 58
controllers/resource.go

@@ -1,58 +0,0 @@
-package controllers
-
-import (
-	"eta/eta_hub/models"
-	"os"
-)
-type ResourceController struct {
-	BaseAuthController
-}
-
-// ResourceUpload 上传文件
-// @Title 上传文件
-// @Description 上传文件
-// @Param   MenuId  query  int  true  "目录ID"
-// @Param   File	query  file  true  "文件"
-// @Success 200 Ret=200 操作成功
-// @router /resource/upload [post]
-func (this *ResourceController) ResourceUpload() {
-	br := new(models.BaseResponse).Init()
-	defer func() {
-		this.Data["json"] = br
-		this.ServeJSON()
-	}()
-
-	f, h, e := this.GetFile("file")
-	if e != nil {
-		br.Msg = "获取资源信息失败"
-		br.ErrMsg = "获取资源信息失败, Err:" + e.Error()
-		return
-	}
-	defer func() {
-		_ = f.Close()
-	}()
-
-
-	uploadDir :=  "./static/"
-	//uploadDir :=  "/Users/xi/Desktop/file"
-	if e = os.MkdirAll(uploadDir, 766); e != nil {
-		br.Msg = "存储目录创建失败"
-		br.ErrMsg = "存储目录创建失败, Err:" + e.Error()
-		return
-	}
-	//ossFileName := utils.GetRandStringNoSpecialChar(28) + ext
-	filePath := uploadDir + "/" + h.Filename
-	if e = this.SaveToFile("file", filePath); e != nil {
-		br.Msg = "文件保存失败"
-		br.ErrMsg = "文件保存失败, Err:" + e.Error()
-		return
-	}
-	//defer func() {
-	//	_ = os.Remove(filePath)
-	//}()
-
-
-	br.Msg = "上传成功"
-	br.Ret = 200
-	br.Success = true
-}

+ 6 - 9
models/base.go

@@ -1,18 +1,15 @@
 package models
 
 type BaseResponse struct {
-	Ret         int
-	Msg         string
-	ErrMsg      string
-	ErrCode     string
-	Data        interface{}
-	Success     bool `description:"true 执行成功,false 执行失败"`
-	IsSendEmail bool `json:"-" description:"true 发送邮件,false 不发送邮件"`
-	IsAddLog    bool `json:"-" description:"true 新增操作日志,false 不新增操作日志" `
+	Ret     int
+	Msg     string
+	ErrMsg  string
+	ErrCode string
+	Data    interface{}
 }
 
 func (r *BaseResponse) Init() *BaseResponse {
-	return &BaseResponse{Ret: 403, IsSendEmail: true}
+	return &BaseResponse{Ret: 403}
 }
 
 type BaseRequest struct {

+ 1 - 8
routers/router.go

@@ -8,17 +8,10 @@
 package routers
 
 import (
-	"eta/eta_hub/controllers"
 	"github.com/beego/beego/v2/server/web"
 )
 
 func init() {
-	var ns = web.NewNamespace("/v1",
-		web.NSNamespace("/test",
-			web.NSInclude(
-				&controllers.ResourceController{},
-			),
-		),
-	)
+	var ns = web.NewNamespace("/v1")
 	web.AddNamespace(ns)
 }

+ 37 - 0
utils/common.go

@@ -2,9 +2,11 @@ package utils
 
 import (
 	"bufio"
+	"crypto/hmac"
 	"crypto/md5"
 	cryRand "crypto/rand"
 	"crypto/sha1"
+	"crypto/sha256"
 	"encoding/base64"
 	"encoding/hex"
 	"encoding/json"
@@ -22,6 +24,7 @@ import (
 	"os/exec"
 	"path"
 	"regexp"
+	"sort"
 	"strconv"
 	"strings"
 	"time"
@@ -986,3 +989,37 @@ func GetLocalIP() (ip string, err error) {
 	}
 	return
 }
+
+// HmacSha256 计算HmacSha256
+// key 是加密所使用的key
+// data 是加密的内容
+func HmacSha256(key string, data string) []byte {
+	mac := hmac.New(sha256.New, []byte(key))
+	_, _ = mac.Write([]byte(data))
+
+	return mac.Sum(nil)
+}
+
+// HmacSha256ToHex 将加密后的二进制转Base64字符串
+func HmacSha256ToBase64(key string, data string) string {
+	return base64.URLEncoding.EncodeToString(HmacSha256(key, data))
+}
+
+func GetSign(nonce, timestamp string) (sign string) {
+	signStrMap := map[string]string{
+		"nonce":     nonce,
+		"timestamp": timestamp,
+		"appid":     AppId,
+	}
+	keys := make([]string, 0, len(signStrMap))
+	for k := range signStrMap {
+		keys = append(keys, k)
+	}
+	sort.Strings(keys)
+	var signStr string
+	for _, k := range keys {
+		signStr += k + "&" + signStrMap[k]
+	}
+	sign = HmacSha256ToBase64(Secret, signStr)
+	return
+}

+ 23 - 1
utils/config.go

@@ -24,6 +24,12 @@ var (
 	LogMaxDays int //日志最大保留天数
 )
 
+var (
+	BusinessCode string //商家编码
+	AppId        string
+	Secret       string
+)
+
 func init() {
 	tmpRunMode, err := web.AppConfig.String("run_mode")
 	if err != nil {
@@ -52,12 +58,28 @@ func init() {
 	beeLogger.Log.Info(RunMode + " 模式")
 	MYSQL_URL = config["mysql_url"]
 	MYSQL_URL_DATA = config["mysql_url_data"]
-
 	if RunMode == "release" {
 
 	} else {
 
 	}
+
+	//商家编码
+	BusinessCode = config["business_code"]
+	if BusinessCode == "" {
+		panic("商家编码未配置,请先配置商家编码")
+	}
+
+	AppId = config["appid"]
+	if AppId == "" {
+		panic("appid未配置")
+	}
+
+	Secret = config["secret"]
+	if Secret == "" {
+		panic("secret未配置")
+	}
+
 	//日志配置
 	{
 		LogPath = config["log_path"]