// sql.go

  1. package utils
import (
	"errors"
	"fmt"
	"regexp"
	"strings"
)
  // Driver identifies the SQL dialect that a helper in this file should target.
type Driver string

// Known Driver values. The string form of each is also the name accepted by the
// helpers' driver parameter (see supportDriverMap).
const (
	// MySql targets MySQL.
	MySql Driver = "mysql"
	// DM targets the "dm" driver (presumably the Dameng database — confirm).
	DM Driver = "dm"
)
  12. var supportDriverMap = map[string]Driver{
  13. "mysql": MySql,
  14. "dm": DM,
  15. }
  // GroupUnitFunc returns a dialect-specific SQL aggregate expression that
// concatenates the values of column within a group, separated by delimiter.
//
// driver selects the dialect ("" means the package default). Unsupported
// drivers fall back to MySQL syntax. An empty delimiter defaults to ",", and an
// empty column is replaced by the placeholder "[UNKNOWN COLUMN]".
//
// column is embedded verbatim, so it must be a trusted identifier or
// expression. delimiter is escaped for use inside a single-quoted literal.
func GroupUnitFunc(driver string, column, delimiter string) (sqlStr string) {
	return groupConcatExpr(driver, column, delimiter, false)
}

// GroupUnitDistinctFunc is like GroupUnitFunc but aggregates only distinct
// values of column.
func GroupUnitDistinctFunc(driver string, column, delimiter string) (sqlStr string) {
	return groupConcatExpr(driver, column, delimiter, true)
}

// groupConcatExpr builds the GROUP_CONCAT (MySQL) or LISTAGG (DM) expression
// shared by GroupUnitFunc and GroupUnitDistinctFunc.
func groupConcatExpr(driver, column, delimiter string, distinct bool) string {
	// Best effort: an unsupported driver resolves to the zero Driver and is
	// rendered with the MySQL syntax of the default branch below.
	dbDriver, _ := getDriverInstance(driver)
	if delimiter == "" {
		delimiter = ","
	}
	if column == "" {
		column = "[UNKNOWN COLUMN]"
	}
	// Double embedded single quotes (the standard SQL escape) so the delimiter
	// cannot terminate the surrounding '...' literal early.
	// NOTE(review): MySQL still interprets backslash escapes inside the literal
	// (unless NO_BACKSLASH_ESCAPES is set), so delimiter must not come from
	// untrusted input.
	delimiter = strings.ReplaceAll(delimiter, "'", "''")

	modifier := ""
	if distinct {
		modifier = "DISTINCT "
	}

	switch dbDriver {
	case DM:
		return fmt.Sprintf("LISTAGG(%s%s, '%s') WITHIN GROUP (ORDER BY %s)", modifier, column, delimiter, column)
	default:
		// MySql, and any unrecognized driver: MySQL syntax is the default.
		return fmt.Sprintf("GROUP_CONCAT(%s%s SEPARATOR '%s')", modifier, column, delimiter)
	}
}
  // getDriverInstance resolves a driver name to its Driver value.
//
// An empty name selects the package default, DbDriverName. The returned error
// is non-nil only when the resolved name is not a key of supportDriverMap; in
// that case the returned Driver is the zero value "".
func getDriverInstance(driver string) (Driver, error) {
	if driver == "" {
		// Resolve the default before the lookup so that an empty name is not
		// itself rejected as unsupported.
		driver = DbDriverName
	}
	dbDriver, ok := supportDriverMap[driver]
	if !ok {
		// Message: "unsupported database driver type".
		return "", errors.New("不支持的数据库驱动类型")
	}
	return dbDriver, nil
}
  70. func NeedDateOrTimeFormat(driver string) bool {
  71. var dbDriver Driver
  72. if driver == "" {
  73. dbDriver = supportDriverMap[DbDriverName]
  74. } else {
  75. dbDriver, _ = getDriverInstance(driver)
  76. }
  77. if dbDriver == DM {
  78. return true
  79. }
  80. return false
  81. }
  // driverKeywordRules maps a Driver to the bare identifiers that
// ReplaceDriverKeywords rewrites for that driver, and the double-quoted form
// each is replaced with.
var driverKeywordRules = map[Driver]map[string]string{
	DM: {
		"admin":    `"admin"`,
		"value":    `"value"`,
		"exchange": `"exchange"`,
	},
}

// driverKeywordPatterns holds one precompiled regexp per driver in
// driverKeywordRules. It is built once at package initialisation rather than on
// every ReplaceDriverKeywords call.
var driverKeywordPatterns = compileDriverKeywordPatterns(driverKeywordRules)

// compileDriverKeywordPatterns builds, for each driver with at least one rule, a
// regexp that matches either a quoted SQL token ('...' string literal or "..."
// identifier, with doubled quotes as escapes) or one of the driver's keywords
// as a whole word. Quoted tokens are matched only so that they can be skipped.
func compileDriverKeywordPatterns(rules map[Driver]map[string]string) map[Driver]*regexp.Regexp {
	patterns := make(map[Driver]*regexp.Regexp, len(rules))
	for d, replacements := range rules {
		if len(replacements) == 0 {
			continue
		}
		words := make([]string, 0, len(replacements))
		for w := range replacements {
			words = append(words, regexp.QuoteMeta(w))
		}
		patterns[d] = regexp.MustCompile(
			`'(?:[^']|'')*'|"(?:[^"]|"")*"|\b(?:` + strings.Join(words, "|") + `)\b`)
	}
	return patterns
}

// ReplaceDriverKeywords double-quotes the identifiers listed in
// driverKeywordRules for the given driver ("" means the package default).
//
// Only standalone words are replaced; compound words that merely contain a
// keyword (e.g. admin_id) are left alone. Text inside single-quoted string
// literals and already double-quoted identifiers is not modified, so data such
// as 'admin' and identifiers such as "value" survive unchanged. Drivers without
// rules, including unsupported drivers, get sql back unchanged.
func ReplaceDriverKeywords(driver string, sql string) string {
	// Unsupported drivers resolve to the zero Driver, which has no rules.
	dbDriver, _ := getDriverInstance(driver)
	re, ok := driverKeywordPatterns[dbDriver]
	if !ok {
		return sql
	}
	replacements := driverKeywordRules[dbDriver]
	return re.ReplaceAllStringFunc(sql, func(match string) string {
		if quoted, ok := replacements[match]; ok {
			return quoted
		}
		// A string literal or quoted identifier: keep it as-is.
		return match
	})
}