Browse Source

refactor(models): 使用GORM替换ORM库

在`eta_gn/eta_api`项目中,已将`github.com/beego/beego/v2/client/orm`库替换为`gorm`库,以实现ORM功能的现代化和统一。此次改动涉及多个模型文件,包括`base_auth.go`、`supply_analysis`包下的文件、`trade_analysis`包下的文件以及系统相关模型如`admin_operate_record.go`等。改动内容包括使用GORM的API进行数据库操作,如查询、插入、更新和删除,以替代原先的ORM方法。
Roc 5 months ago
parent
commit
0b0f7e73e6

+ 1 - 1
controllers/base_auth.go

@@ -214,7 +214,7 @@ func (c *BaseAuthController) Prepare() {
 				}
 			}
 
-			fmt.Println(api)
+			//fmt.Println(api)
 			//处理uri请求,去除前缀和参数
 			api = strings.TrimRight(api, "&")
 			uri = strings.Replace(uri, "/adminapi", "", 1)

+ 7 - 7
models/data_manage/supply_analysis/base_from_stock_plant_data.go

@@ -1,8 +1,8 @@
 package supply_analysis
 
 import (
+	"eta_gn/eta_api/global"
 	"eta_gn/eta_api/utils"
-	"github.com/beego/beego/v2/client/orm"
 	"time"
 )
 
@@ -19,9 +19,9 @@ type BaseFromStockPlantData struct {
 
 // GetDataVarietyEdbInfoByVarietyId 根据品种指标id获取所有的数据
 func GetDataVarietyEdbInfoByVarietyId(varietyEdbId int) (items []*BaseFromStockPlantData, err error) {
-	o := orm.NewOrmUsingDB("data")
 	sql := `SELECT * FROM base_from_stock_plant_data AS a WHERE a.variety_edb_id = ? ORDER BY data_time desc `
-	_, err = o.Raw(sql, varietyEdbId).QueryRows(&items)
+	err = global.DmSQL["data"].Raw(sql, varietyEdbId).Find(&items).Error
+
 	return
 }
 
@@ -39,8 +39,8 @@ func GetVarietyEdbData(varietyEdbId int, startDate, endDate string) (items []*Ba
 	}
 
 	sql += ` ORDER BY data_time ASC `
-	o := orm.NewOrmUsingDB("data")
-	_, err = o.Raw(sql, varietyEdbId, pars).QueryRows(&items)
+	err = global.DmSQL["data"].Raw(sql, varietyEdbId, pars).Scan(&items).Error
+
 	return
 }
 
@@ -51,7 +51,7 @@ func GetVarietyEdbDataListByIdList(varietyEdbIdList []int) (items []*BaseFromSto
 	}
 	sql := `SELECT * FROM  (SELECT * FROM base_from_stock_plant_data AS a WHERE a.variety_edb_id in (` + utils.GetOrmInReplace(num) + `) GROUP BY data_time) d ORDER BY data_time DESC `
 	//sql := ` SELECT  DT FROM edbdata WHERE TRADE_CODE IN(` + tradeCode + `)  GROUP BY DT ORDER BY DT DESC `
-	o := orm.NewOrmUsingDB("data")
-	_, err = o.Raw(sql, varietyEdbIdList).QueryRows(&items)
+	err = global.DmSQL["data"].Raw(sql, varietyEdbIdList).Find(&items).Error
+
 	return
 }

+ 25 - 25
models/data_manage/supply_analysis/variety.go

@@ -1,6 +1,9 @@
 package supply_analysis
 
 import (
+	"eta_gn/eta_api/global"
+	"eta_gn/eta_api/utils"
+	"fmt"
 	"github.com/beego/beego/v2/client/orm"
 	"time"
 )
@@ -20,33 +23,30 @@ type Variety struct {
 
 // GetVarietyById 根据品种id获取品种详情
 func GetVarietyById(id int) (item *Variety, err error) {
-	o := orm.NewOrmUsingDB("data")
 	sql := `SELECT * FROM variety WHERE variety_id = ?`
-	err = o.Raw(sql, id).QueryRow(&item)
+	err = global.DmSQL["data"].Raw(sql, id).First(&item).Error
+
 	return
 }
 
 // GetVarietyByName 根据品种名称获取品种详情
 func GetVarietyByName(name string) (item *Variety, err error) {
-	o := orm.NewOrmUsingDB("data")
 	sql := `SELECT * FROM variety WHERE variety_name = ?`
-	err = o.Raw(sql, name).QueryRow(&item)
+	err = global.DmSQL["data"].Raw(sql, name).First(&item).Error
+
 	return
 }
 
 // AddVariety 添加品种
 func AddVariety(item *Variety) (lastId int64, err error) {
-	o := orm.NewOrmUsingDB("data")
-	lastId, err = o.Insert(item)
+	err = global.DmSQL["data"].Create(item).Error
+
 	return
 }
 
 // CreateVariety 添加品种
 func CreateVariety(item *Variety, adminIdList []int) (err error) {
-	to, err := orm.NewOrmUsingDB("data").Begin()
-	if err != nil {
-		return
-	}
+	to := global.DmSQL["data"].Begin()
 
 	defer func() {
 		if err != nil {
@@ -55,13 +55,11 @@ func CreateVariety(item *Variety, adminIdList []int) (err error) {
 			_ = to.Commit()
 		}
 	}()
-	lastId, err := to.Insert(item)
+	err = to.Create(item).Error
 	if err != nil {
 		return
 	}
 
-	item.VarietyId = int(lastId)
-
 	varietyAdminPermissionList := make([]*VarietyAdminPermission, 0)
 	for _, adminId := range adminIdList {
 		varietyAdminPermissionList = append(varietyAdminPermissionList, &VarietyAdminPermission{
@@ -72,7 +70,7 @@ func CreateVariety(item *Variety, adminIdList []int) (err error) {
 		})
 	}
 	if len(varietyAdminPermissionList) > 0 {
-		_, err = to.InsertMulti(len(varietyAdminPermissionList), varietyAdminPermissionList)
+		err = to.CreateInBatches(varietyAdminPermissionList, utils.MultiAddNum).Error
 	}
 
 	return
@@ -80,10 +78,7 @@ func CreateVariety(item *Variety, adminIdList []int) (err error) {
 
 // EditVariety 编辑品种
 func EditVariety(item *Variety, adminIdList []int) (err error) {
-	to, err := orm.NewOrmUsingDB("data").Begin()
-	if err != nil {
-		return
-	}
+	to := global.DmSQL["data"].Begin()
 
 	defer func() {
 		if err != nil {
@@ -93,14 +88,14 @@ func EditVariety(item *Variety, adminIdList []int) (err error) {
 		}
 	}()
 
-	_, err = to.Update(item, "VarietyName", "LastUpdateSysUserId", "LastUpdateSysUserRealName", "ModifyTime")
+	err = to.Select("VarietyName", "LastUpdateSysUserId", "LastUpdateSysUserRealName", "ModifyTime").Updates(item).Error
 	if err != nil {
 		return
 	}
 
 	// 删除历史的权限配置
 	sql := `DELETE FROM variety_admin_permission where variety_id = ? `
-	_, err = to.Raw(sql, item.VarietyId).Exec()
+	err = to.Exec(sql, item.VarietyId).Error
 	if err != nil {
 		return
 	}
@@ -116,7 +111,7 @@ func EditVariety(item *Variety, adminIdList []int) (err error) {
 		})
 	}
 	if len(varietyAdminPermissionList) > 0 {
-		_, err = to.InsertMulti(len(varietyAdminPermissionList), varietyAdminPermissionList)
+		err = to.CreateInBatches(varietyAdminPermissionList, utils.MultiAddNum).Error
 	}
 
 	return
@@ -143,7 +138,6 @@ type VarietyButton struct {
 
 // GetListBySuperAdminPage 不区分是否有分析权限的获取分页数据
 func (item Variety) GetListBySuperAdminPage(condition string, pars []interface{}, startSize, pageSize int) (total int, items []*VarietyItem, err error) {
-	o := orm.NewOrmUsingDB("data")
 	baseSql := ` FROM ( SELECT a.*, GROUP_CONCAT(DISTINCT b.sys_user_id ORDER BY b.sys_user_id ASC SEPARATOR ',') AS permission_user_id FROM variety a 
 				LEFT JOIN variety_admin_permission b on a.variety_id=b.variety_id 
 				LEFT JOIN variety_edb_info c on a.variety_id=c.variety_id WHERE 1=1 `
@@ -153,14 +147,20 @@ func (item Variety) GetListBySuperAdminPage(condition string, pars []interface{}
 	baseSql += ` GROUP BY a.variety_id ) d `
 	// 数据总数
 	totalSql := `SELECT COUNT(1) total ` + baseSql
-	err = o.Raw(totalSql, pars).QueryRow(&total)
+
+	var adminMenusCount int64
+
+	err = global.DmSQL["data"].Select("total").Raw(totalSql, pars).Row().Scan(&adminMenusCount)
 	if err != nil {
+		fmt.Println("Count Err:", err)
 		return
 	}
-
+	fmt.Println(adminMenusCount)
+	total = int(adminMenusCount)
 	// 列表页数据
 	listSql := `SELECT * ` + baseSql + ` ORDER BY modify_time DESC,variety_id DESC LIMIT ?,?`
-	_, err = o.Raw(listSql, pars, startSize, pageSize).QueryRows(&items)
+	err = global.DmSQL["data"].Raw(listSql, pars, startSize, pageSize).Scan(&items).Error
+
 	return
 }
 

+ 4 - 8
models/data_manage/trade_analysis/trade_analysis.go

@@ -1,7 +1,7 @@
 package trade_analysis
 
 import (
-	"github.com/beego/beego/v2/client/orm"
+	"eta_gn/eta_api/global"
 	"time"
 )
 
@@ -96,9 +96,7 @@ func GetExchangeClassify(exchange string) (list []TradeClassifyName, err error)
 	}
 	sql := "SELECT classify_name, classify_type FROM " + tableName + " WHERE `rank` <=20 and `rank` > 0 GROUP BY classify_name, classify_type  "
 	sql += ` ORDER BY ` + orderStr
-
-	o := orm.NewOrmUsingDB("data")
-	_, err = o.Raw(sql).QueryRows(&list)
+	err = global.DmSQL["data"].Raw(sql).Scan(&list).Error
 
 	return
 }
@@ -111,8 +109,7 @@ type LastTimeItem struct {
 func GetExchangeLastTime(exchange string) (item LastTimeItem, err error) {
 	tableName := "base_from_trade_" + exchange + "_index"
 	sql := `SELECT create_time FROM ` + tableName + ` ORDER BY create_time desc`
-	o := orm.NewOrmUsingDB("data")
-	err = o.Raw(sql).QueryRow(&item)
+	err = global.DmSQL["data"].Raw(sql).Scan(&item).Error
 
 	return
 }
@@ -154,8 +151,7 @@ func GetTradePositionTop(exchange string, classifyName, classifyType, dataTime s
 	tableName := "trade_position_" + exchange + "_top"
 	sql := `SELECT * FROM ` + tableName + " WHERE classify_name=? and classify_type=? and data_time=? and `rank` <=20 and `rank` > 0 ORDER BY deal_value desc"
 
-	o := orm.NewOrmUsingDB("data")
-	_, err = o.Raw(sql, classifyName, classifyType, dataTime).QueryRows(&list)
+	err = global.DmSQL["data"].Raw(sql, classifyName, classifyType, dataTime).Find(&list).Error
 
 	return
 }

+ 4 - 8
models/data_manage/trade_analysis/trade_classify.go

@@ -1,7 +1,7 @@
 package trade_analysis
 
 import (
-	"github.com/beego/beego/v2/client/orm"
+	"eta_gn/eta_api/global"
 	"time"
 )
 
@@ -19,9 +19,7 @@ type BaseFromTradeClassify struct {
 // GetAllBaseFromTradeClassify 获取所有的交易所分类列表
 func GetAllBaseFromTradeClassify() (list []*BaseFromTradeClassify, err error) {
 	sql := `SELECT * FROM base_from_trade_classify   `
-
-	o := orm.NewOrmUsingDB("data")
-	_, err = o.Raw(sql).QueryRows(&list)
+	err = global.DmSQL["data"].Raw(sql).Find(&list).Error
 
 	return
 }
@@ -39,8 +37,7 @@ func GetTradeTopLastDataTime(exchange string, classifyName, classifyType string)
 		pars = append(pars, classifyType)
 	}
 	sql += ` ORDER BY latest_date desc`
-	o := orm.NewOrmUsingDB("data")
-	err = o.Raw(sql, pars...).QueryRow(&item)
+	err = global.DmSQL["data"].Raw(sql, pars...).First(&item).Error
 
 	return
 }
@@ -48,8 +45,7 @@ func GetTradeTopLastDataTime(exchange string, classifyName, classifyType string)
 // GetClassifyTypeByClassifyName 根据分类名称获取分类类型
 func GetClassifyTypeByClassifyName(exchange, classifyName string) (item *TradeClassifyName, err error) {
 	sql := `SELECT classify_name, classify_type FROM base_from_trade_classify WHERE exchange = ? AND classify_name=? `
-	o := orm.NewOrmUsingDB("data")
-	err = o.Raw(sql, exchange, classifyName).QueryRow(&item)
+	err = global.DmSQL["data"].Raw(sql, exchange, classifyName).First(&item).Error
 
 	return
 }

+ 1 - 1
models/db.go

@@ -2,7 +2,7 @@ package models
 
 import "eta_gn/eta_api/models/data_manage"
 
-// afterInitTable
+// AfterInitTable
 // @Description: 初始化表结构的的后置操作
 // @author: Roc
 // @datetime 2024-07-01 13:31:09

+ 6 - 4
models/system/admin_operate_record.go

@@ -1,12 +1,13 @@
 package system
 
 import (
-	"github.com/beego/beego/v2/client/orm"
+	"eta_gn/eta_api/global"
+	"fmt"
 	"time"
 )
 
 type AdminOperateRecord struct {
-	AdminOperateRecordId int       `orm:"column(admin_operate_record_id);pk" description:"id"`
+	AdminOperateRecordId int       `gorm:"primaryKey" orm:"column(admin_operate_record_id);pk" description:"id"`
 	AdminId              int       `description:"系统客户id"`
 	RealName             string    `description:"配置编码"`
 	Uuid                 string    `description:"配置值"`
@@ -19,7 +20,8 @@ type AdminOperateRecord struct {
 }
 
 func (item *AdminOperateRecord) Insert() (err error) {
-	o := orm.NewOrm()
-	_, err = o.Insert(item)
+	err = global.DEFAULT_DmSQL.Create(item).Error
+	fmt.Println("AdminOperateRecord ERR: ", err)
+
 	return
 }

+ 5 - 3
models/system/sys_menu.go

@@ -1,6 +1,7 @@
 package system
 
 import (
+	"eta_gn/eta_api/global"
 	"fmt"
 	"github.com/beego/beego/v2/client/orm"
 	"strings"
@@ -22,7 +23,7 @@ type SysMenu struct {
 	Sort       string    `description:"排序"`
 	Path       string    `description:"路由地址"`
 	IconPath   string    `description:"菜单图标地址"`
-	Component  int       `description:"组件路径"`
+	Component  string    `description:"组件路径"`
 	Hidden     int       `description:"是否隐藏:1-隐藏 0-显示"`
 	IsLevel    int       `description:"是否为多级菜单:1,只有一级;2,有多级"`
 	LevelPath  string    `description:"兼容以前menu表的字段"`
@@ -165,7 +166,7 @@ type SysMenuItem struct {
 	Sort       string         `description:"排序"`
 	Path       string         `description:"路由地址"`
 	IconPath   string         `description:"菜单图标地址"`
-	Component  int            `description:"组件路径"`
+	Component  string         `description:"组件路径"`
 	Hidden     int            `description:"是否隐藏:1-隐藏 0-显示"`
 	MenuType   int            `description:"菜单类型: 0-菜单; 1-按钮; 2-字段(需要特殊处理)"`
 	ButtonCode string         `description:"按钮/菜单唯一标识"`
@@ -190,7 +191,8 @@ func GetMenuButtonApisByRoleId(roleId int) (items []*SysMenu, err error) {
 			ORDER BY
 				r.sort ASC,
 				r.create_time DESC`
-	_, err = orm.NewOrm().Raw(sql, roleId).QueryRows(&items)
+	err = global.DEFAULT_DmSQL.Raw(sql, roleId).Scan(&items).Error
+
 	return
 }
 

+ 4 - 4
models/system/sys_session.go

@@ -39,16 +39,16 @@ func GetSysSessionBySysUserId(sysUserId int) (item *SysSession, err error) {
 // GetSysSessionByToken 根据token获取session
 func GetSysSessionByToken(token string) (item *SysSession, err error) {
 	sql := `SELECT * FROM sys_session WHERE access_token=? AND expired_time> NOW() ORDER BY expired_time DESC LIMIT 1 `
-	o := orm.NewOrm()
-	err = o.Raw(sql, token).QueryRow(&item)
+	err = global.DEFAULT_DmSQL.Raw(sql, token).First(&item).Error
+
 	return
 }
 
 // ExpiredSysSessionByAdminId 过期掉用户token
 func ExpiredSysSessionByAdminId(adminId int) (err error) {
 	sql := `update sys_session set expired_time = NOW()  WHERE sys_user_id=? `
-	o := orm.NewOrm()
-	_, err = o.Raw(sql, adminId).Exec()
+	err = global.DEFAULT_DmSQL.Exec(sql, adminId).Error
+
 	return
 }
 

+ 11 - 11
models/system/sys_user.go

@@ -69,14 +69,14 @@ type Admin struct {
 
 // Update 更新用户基础信息
 func (item *Admin) Update(cols []string) (err error) {
-	o := global.DEFAULT_DmSQL
-	err = o.Select(cols).Updates(*item).Error
+	//o := global.DEFAULT_DmSQL
+	err = global.DEFAULT_DmSQL.Select(cols).Updates(*item).Error
 
 	return
 }
 
 func CheckSysUser(userName, password string) (item *Admin, err error) {
-	sql := ` SELECT a.*,b.role_type_code FROM admin AS a
+	sql := ` SELECT a.*,b.role_type_code FROM "admin" AS a
 			 INNER JOIN sys_role AS b ON a.role_id=b.role_id WHERE a.admin_name=? AND a.password=? LIMIT 1`
 	o := orm.NewOrm()
 	err = o.Raw(sql, userName, password).QueryRow(&item)
@@ -84,14 +84,14 @@ func CheckSysUser(userName, password string) (item *Admin, err error) {
 }
 
 func GetSysUserById(sysUserId int) (item *Admin, err error) {
-	sql := `SELECT * FROM admin WHERE admin_id=? LIMIT 1`
-	o := orm.NewOrm()
-	err = o.Raw(sql, sysUserId).QueryRow(&item)
+	sql := `SELECT * FROM "admin" WHERE admin_id=? LIMIT 1`
+	err = global.DEFAULT_DmSQL.Raw(sql, sysUserId).First(&item).Error
+
 	return
 }
 
 func ModifyPwd(sysUserId int, newPwd string) (err error) {
-	sql := `UPDATE admin SET password=?,last_updated_time=NOW() WHERE admin_id=? `
+	sql := `UPDATE "admin" SET password=?,last_updated_time=NOW() WHERE admin_id=? `
 	o := orm.NewOrm()
 	_, err = o.Raw(sql, newPwd, sysUserId).Exec()
 	return
@@ -116,7 +116,7 @@ func GetAdminListByIdList(idList []int) (items []*Admin, err error) {
 	if lenNum <= 0 {
 		return
 	}
-	sql := `SELECT * FROM admin WHERE admin_id in (` + utils.GetOrmInReplace(lenNum) + `) and enabled=1 `
+	sql := `SELECT * FROM "admin" WHERE admin_id in (` + utils.GetOrmInReplace(lenNum) + `) and enabled=1 `
 	o := orm.NewOrm()
 	_, err = o.Raw(sql, idList).QueryRows(&items)
 	return
@@ -128,7 +128,7 @@ func GetAdminListByIdListWithoutEnable(idList []int) (items []*Admin, err error)
 	if lenNum <= 0 {
 		return
 	}
-	sql := `SELECT * FROM admin WHERE admin_id in (` + utils.GetOrmInReplace(lenNum) + `) `
+	sql := `SELECT * FROM "admin" WHERE admin_id in (` + utils.GetOrmInReplace(lenNum) + `) `
 	o := orm.NewOrm()
 	_, err = o.Raw(sql, idList).QueryRows(&items)
 	return
@@ -136,7 +136,7 @@ func GetAdminListByIdListWithoutEnable(idList []int) (items []*Admin, err error)
 
 func (item *Admin) GetCountByCondition(condition string, pars []interface{}) (count int, err error) {
 	o := orm.NewOrm()
-	sql := fmt.Sprintf(`SELECT COUNT(1) FROM admin WHERE 1=1 %s`, condition)
+	sql := fmt.Sprintf(`SELECT COUNT(1) FROM "admin" WHERE 1=1 %s`, condition)
 	err = o.Raw(sql, pars).QueryRow(&count)
 	return
 }
@@ -151,7 +151,7 @@ func (item *Admin) GetItemsByCondition(condition string, pars []interface{}, fie
 	if orderRule != "" {
 		order = ` ORDER BY ` + orderRule
 	}
-	sql := fmt.Sprintf(`SELECT %s FROM admin WHERE 1=1 %s %s`, fields, condition, order)
+	sql := fmt.Sprintf(`SELECT %s FROM "admin" WHERE 1=1 %s %s`, fields, condition, order)
 	_, err = o.Raw(sql, pars).QueryRows(&items)
 	return
 }