// sql.go

package utils
import (
	"errors"
	"fmt"
	"regexp"
	"strings"
)
  7. type Driver string
  8. const (
  9. DM Driver = "dm"
  10. MySql Driver = "mysql"
  11. )
  12. var supportDriverMap = map[string]Driver{
  13. "mysql": MySql,
  14. "dm": DM,
  15. }
  16. func GroupUnitFunc(driver string, column, delimiter, tableAlia string) (sqlStr string) {
  17. dbDriver, _ := getDriverInstance(driver)
  18. if delimiter == "" {
  19. delimiter = ","
  20. }
  21. if column == "" {
  22. column = "[UNKNOWN COLUMN]"
  23. }
  24. if tableAlia != "" {
  25. column = fmt.Sprintf("%s.%s", tableAlia, column)
  26. }
  27. switch dbDriver {
  28. case MySql:
  29. sqlStr = fmt.Sprintf("GROUP_CONCAT(%s SEPARATOR '%s')", column, delimiter)
  30. case DM:
  31. sqlStr = fmt.Sprintf("LISTAGG(%s, '%s') WITHIN GROUP (ORDER BY %s)", column, delimiter, column)
  32. default:
  33. sqlStr = fmt.Sprintf("GROUP_CONCAT(%s SEPARATOR '%s')", column, delimiter) // 默认使用 MySQL 的语法
  34. }
  35. return sqlStr
  36. }
  37. func getDriverInstance(driver string) (dbDriver Driver, err error) {
  38. if driver == "" {
  39. dbDriver = supportDriverMap[DbDriverName]
  40. }
  41. if currentDriver, ok := supportDriverMap[driver]; !ok {
  42. err = errors.New("不支持的数据库驱动类型")
  43. return
  44. } else {
  45. dbDriver = currentDriver
  46. }
  47. return
  48. }
  49. func NeedDateOrTimeFormat(driver string) bool {
  50. var dbDriver Driver
  51. if driver == "" {
  52. dbDriver = supportDriverMap[DbDriverName]
  53. } else {
  54. dbDriver, _ = getDriverInstance(driver)
  55. }
  56. if dbDriver == DM {
  57. return true
  58. }
  59. return false
  60. }
  61. func ReplaceDriverKeywords(driver string, sql string) string {
  62. dbDriver, _ := getDriverInstance(driver)
  63. rules := map[Driver]map[string]string{
  64. DM: {
  65. "admin": `"admin"`,
  66. "value": `"value"`,
  67. "exchange": `"exchange"`,
  68. },
  69. }
  70. replacements, ok := rules[dbDriver]
  71. if !ok {
  72. return sql
  73. }
  74. for keyword, replace := range replacements {
  75. // 仅替换单独的单词,复合单词含关键词不管
  76. pattern := fmt.Sprintf(`\b%s\b`, regexp.QuoteMeta(keyword))
  77. re := regexp.MustCompile(pattern)
  78. sql = re.ReplaceAllString(sql, replace)
  79. }
  80. return sql
  81. }