12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576 |
- package character
import (
	"fmt"
	"sort"
	"strings"
)
// BuildBatchUpdateSql builds one UPDATE statement that writes a different value
// to each matched row, using "col = CASE WHEN <row condition> THEN <value> ...
// ELSE col END" for every updated column.
//
// where maps each condition column to its per-row values and needUpdateFields
// maps each column to update to its per-row new values; index i of every slice
// describes the same row. Row i is matched by ANDing "col = where[col][i]" over
// all condition columns and receives needUpdateFields[col][i]. Example:
//
//	where := map[string][]interface{}{
//		"id":      {180, 181, 182, 183},
//		"user_id": {5, 15, 11, 1},
//	}
//	needUpdateFields := map[string][]interface{}{
//		"view_count": {11, 22, 33, 44},
//		"updated_at": {1653147405, 1653147405, 1653147405, 1653147405},
//	}
//	sql := BuildBatchUpdateSql("articles", where, needUpdateFields)
//
// Column names are emitted in sorted order, so the result is deterministic.
//
// It returns "" if either map is empty, if there are no rows, or if the value
// slices do not all have the same length.
//
// Values are rendered with fmt's %v verb and are neither quoted nor escaped:
// string values must already be valid SQL literals, and untrusted input must
// not be passed here (SQL injection). The statement has no WHERE clause, so it
// targets every row of tableName; rows that match no condition keep their
// current value via the ELSE branch.
func BuildBatchUpdateSql(tableName string, where, needUpdateFields map[string][]interface{}) string {
	if len(where) == 0 || len(needUpdateFields) == 0 {
		return ""
	}

	// sortedKeys returns m's keys in ascending order. Map iteration order is
	// random in Go; sorting keeps the generated SQL reproducible.
	sortedKeys := func(m map[string][]interface{}) []string {
		keys := make([]string, 0, len(m))
		for k := range m {
			keys = append(keys, k)
		}
		sort.Strings(keys)
		return keys
	}
	// allHaveLen reports whether every value slice in m has exactly n elements.
	allHaveLen := func(m map[string][]interface{}, n int) bool {
		for _, values := range m {
			if len(values) != n {
				return false
			}
		}
		return true
	}

	whereKeys := sortedKeys(where)
	updateKeys := sortedKeys(needUpdateFields)

	// Every slice must describe the same rows; checking all of them (not just
	// the first of each map) prevents an index-out-of-range panic below.
	rows := len(where[whereKeys[0]])
	if rows == 0 || !allHaveLen(where, rows) || !allHaveLen(needUpdateFields, rows) {
		return ""
	}

	// conditions[i] selects row i, e.g. "id = 180 AND user_id = 5".
	conditions := make([]string, rows)
	parts := make([]string, len(whereKeys))
	for i := range conditions {
		for j, key := range whereKeys {
			parts[j] = fmt.Sprintf("%s = %v", key, where[key][i])
		}
		conditions[i] = strings.Join(parts, " AND ")
	}

	var b strings.Builder
	b.WriteString("UPDATE ")
	b.WriteString(tableName)
	b.WriteString(" SET ")
	for n, field := range updateKeys {
		if n > 0 {
			b.WriteString(", ")
		}
		b.WriteString(field)
		b.WriteString(" = CASE")
		for i, cond := range conditions {
			fmt.Fprintf(&b, " WHEN %s THEN %v", cond, needUpdateFields[field][i])
		}
		// ELSE keeps the current value for rows matched by no WHEN branch.
		b.WriteString(" ELSE ")
		b.WriteString(field)
		b.WriteString(" END")
	}
	return b.String()
}
|