sql.go 1.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. package utils
  import (
  	"errors"
  	"fmt"
  	"strings"
  )
  // Driver identifies a SQL dialect that the helpers in this file can target.
  type Driver string

  // Supported drivers. Each constant's value is also the name string that
  // callers pass in to select it. (DM is presumably the Dameng database —
  // confirm.)
  const (
  	MySql Driver = "mysql"
  	DM    Driver = "dm"
  )

  // supportDriverMap maps a driver name to its Driver constant. A name is
  // treated as supported exactly when it is a key of this map.
  var supportDriverMap = map[string]Driver{
  	string(DM):    DM,
  	string(MySql): MySql,
  }
  15. func GroupUnitFunc(driver string, column, delimiter, tableAlia string) (sqlStr string) {
  16. dbDriver, _ := getDriverInstance(driver)
  17. if delimiter == "" {
  18. delimiter = ","
  19. }
  20. if column == "" {
  21. column = "[UNKNOWN COLUMN]"
  22. }
  23. if tableAlia != "" {
  24. column = fmt.Sprintf("%s.%s", tableAlia, column)
  25. }
  26. switch dbDriver {
  27. case MySql:
  28. sqlStr = fmt.Sprintf("GROUP_CONCAT(%s SEPARATOR '%s')", column, delimiter)
  29. case DM:
  30. sqlStr = fmt.Sprintf("LISTAGG(%s, '%s') WITHIN GROUP (ORDER BY %s)", column, delimiter, column)
  31. default:
  32. sqlStr = fmt.Sprintf("GROUP_CONCAT(%s SEPARATOR '%s')", column, delimiter) // 默认使用 MySQL 的语法
  33. }
  34. return sqlStr
  35. }
  36. func getDriverInstance(driver string) (dbDriver Driver, err error) {
  37. if driver == "" {
  38. dbDriver = supportDriverMap[DbDriverName]
  39. }
  40. if currentDriver, ok := supportDriverMap[driver]; !ok {
  41. err = errors.New("不支持的数据库驱动类型")
  42. return
  43. } else {
  44. dbDriver = currentDriver
  45. }
  46. return
  47. }
  48. func NeedDateOrTimeFormat(driver string) bool {
  49. var dbDriver Driver
  50. if driver == "" {
  51. dbDriver = supportDriverMap[DbDriverName]
  52. } else {
  53. dbDriver, _ = getDriverInstance(driver)
  54. }
  55. if dbDriver == DM {
  56. return true
  57. }
  58. return false
  59. }