sql.go 2.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576
  1. package character
import (
	"fmt"
	"sort"
	"strings"
)
  6. // BuildBatchUpdateSql 生成批量更新sql
  7. // tableName := "articles"
  8. // where := make(map[string][]int)
  9. // where["id"] = []int{180, 181, 182, 183}
  10. // where["user_id"] = []int{5, 15, 11, 1}
  11. // needUpdateFields := make(map[string][]int)
  12. // needUpdateFields["view_count"] = []int{11, 22, 33, 44}
  13. // needUpdateFields["updated_at"] = []int{1653147405, 1653147405, 1653147405, 1653147405}
  14. func BuildBatchUpdateSql(tableName string, where, needUpdateFields map[string][]interface{}) string {
  15. if len(where) == 0 || len(needUpdateFields) == 0 {
  16. return ""
  17. }
  18. // 所有的条件字段数组
  19. var whereKeys []string
  20. for k := range where {
  21. whereKeys = append(whereKeys, k)
  22. }
  23. // 第一个 where 条件所有的值
  24. firstWhere := where[whereKeys[0]]
  25. // 所有需要更新的字段数组
  26. var needUpdateFieldsKeys []string
  27. for k := range needUpdateFields {
  28. needUpdateFieldsKeys = append(needUpdateFieldsKeys, k)
  29. }
  30. if len(firstWhere) != len(needUpdateFields[needUpdateFieldsKeys[0]]) {
  31. // 更新的条件与更新的字段值数量不相等
  32. return ""
  33. }
  34. var s1 []string
  35. for k := range firstWhere {
  36. for _, vv := range whereKeys {
  37. s1 = append(s1, fmt.Sprintf("%s = %v AND ", vv, where[vv][k]))
  38. }
  39. }
  40. // 按照 where 条件字段数量做切割
  41. whereSize := len(whereKeys)
  42. batches := make([][]string, 0, (len(s1)+whereSize-1)/whereSize)
  43. for whereSize < len(s1) {
  44. s1, batches = s1[whereSize:], append(batches, s1[0:whereSize:whereSize])
  45. }
  46. batches = append(batches, s1)
  47. var whereArr []string
  48. for _, v := range batches {
  49. whereArr = append(whereArr, strings.TrimSuffix(strings.Join(v, " "), "AND "))
  50. }
  51. // 拼接 sql 语句
  52. sqlStr := ""
  53. for _, v := range needUpdateFieldsKeys {
  54. str := ""
  55. for kk, vv := range whereArr {
  56. str += fmt.Sprintf(" WHEN %v THEN %v ", vv, needUpdateFields[v][kk])
  57. }
  58. sqlStr += fmt.Sprintf("%s = CASE %s ELSE %s END, ", v, str, v)
  59. }
  60. // 去除掉最后面的逗号及空格
  61. sqlStr = strings.TrimSuffix(sqlStr, ", ")
  62. caseWhenSql := fmt.Sprintf("UPDATE %s SET %s", tableName, sqlStr)
  63. return caseWhenSql
  64. }