|
@@ -0,0 +1,80 @@
|
|
|
+package database
|
|
|
+
|
|
|
+import (
|
|
|
+ "database/sql"
|
|
|
+ "gorm.io/gorm"
|
|
|
+)
|
|
|
+
|
|
|
+type MysqlDataBase struct {
|
|
|
+ db *gorm.DB
|
|
|
+}
|
|
|
+
|
|
|
+func NewGORMAdapter(db *gorm.DB) *MysqlDataBase {
|
|
|
+ return &MysqlDataBase{db: db}
|
|
|
+}
|
|
|
+
|
|
|
+func (a *MysqlDataBase) Begin() (Tx, error) {
|
|
|
+ return &GORMTxAdapter{tx: a.db.Begin()}, nil
|
|
|
+}
|
|
|
+
|
|
|
+func (a *MysqlDataBase) Commit() error {
|
|
|
+ return a.db.Commit().Error
|
|
|
+}
|
|
|
+
|
|
|
+func (a *MysqlDataBase) Rollback() error {
|
|
|
+ return a.db.Rollback().Error
|
|
|
+}
|
|
|
+
|
|
|
+func (a *MysqlDataBase) Exec(query string, args ...interface{}) (sql.Result, error) {
|
|
|
+ res := a.db.Exec(query, args...)
|
|
|
+ return &gormResultAdapter{res}, res.Error
|
|
|
+}
|
|
|
+
|
|
|
+func (a *MysqlDataBase) Query(query string, args ...interface{}) (*sql.Rows, error) {
|
|
|
+ rows, err := a.db.Raw(query, args...).Rows()
|
|
|
+ if err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+ return rows, nil
|
|
|
+}
|
|
|
+
|
|
|
+type GORMTxAdapter struct {
|
|
|
+ tx *gorm.DB
|
|
|
+}
|
|
|
+
|
|
|
+func (t *GORMTxAdapter) Commit() error {
|
|
|
+ return t.tx.Commit().Error
|
|
|
+}
|
|
|
+
|
|
|
+func (t *GORMTxAdapter) Rollback() error {
|
|
|
+ return t.tx.Rollback().Error
|
|
|
+}
|
|
|
+
|
|
|
+func (t *GORMTxAdapter) Exec(query string, args ...interface{}) (sql.Result, error) {
|
|
|
+ res := t.tx.Exec(query, args...)
|
|
|
+ return &gormResultAdapter{res}, res.Error
|
|
|
+}
|
|
|
+
|
|
|
+func (t *GORMTxAdapter) Query(query string, args ...interface{}) (*sql.Rows, error) {
|
|
|
+ rows, err := t.tx.Raw(query, args...).Rows()
|
|
|
+ if err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+ return rows, nil
|
|
|
+}
|
|
|
+
|
|
|
+type gormResultAdapter struct {
|
|
|
+ res *gorm.DB
|
|
|
+}
|
|
|
+
|
|
|
+func (r *gormResultAdapter) LastInsertId() (int64, error) {
|
|
|
+ return r.res.RowsAffected, nil
|
|
|
+}
|
|
|
+
|
|
|
+func (r *gormResultAdapter) RowsAffected() (int64, error) {
|
|
|
+ return r.res.RowsAffected, nil
|
|
|
+}
|
|
|
+
|
|
|
+func init() {
|
|
|
+ Register("mysql", MY)
|
|
|
+}
|