deepcopy.go 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242
// Package deepcopy makes deep copies of things. A standard copy will copy the
// pointers: deep copy copies the values pointed to. Unexported field
// values are not copied.
//
// Copyright (c)2014-2016, Joel Scoble (github.com/mohae), all rights reserved.
// License: MIT, for more details check the included LICENSE file.
  7. package deepcopy
  8. import (
  9. "errors"
  10. "reflect"
  11. "time"
  12. )
  13. // copyContext 用于追踪已复制的对象,避免循环引用
  14. type copyContext struct {
  15. seen map[uintptr]reflect.Value
  16. deep int
  17. }
  18. // Interface for delegating copy process to type
  19. type DeepCopyer interface {
  20. DeepCopy() interface{}
  21. }
  22. // Copy creates a deep copy of whatever is passed to it and returns the copy
  23. // in an interface{}. The returned value will need to be asserted to the
  24. // correct type.
  25. // 支持以下类型的复制:
  26. // - 简单类型(int, string 等)
  27. // - 时间类型 (time.Time)
  28. // - 复合类型(map, slice, array)
  29. // - 指针类型
  30. // - 结构体(包含匿名字段)
  31. //
  32. // 特性:
  33. // - 支持循环引用检测
  34. // - 未导出字段不会被复制
  35. // - 实现了 DeepCopyer 接口的类型可以自定义复制行为
  36. // - 不支持的类型(Channel、Function等)会被跳过,不影响其他数据的复制
  37. //
  38. // 限制:
  39. // - 递归深度限制为20层
  40. // - 不支持的类型(Channel、Function等)会被跳过不处理
  41. func Copy(src interface{}) (interface{}, error) {
  42. if src == nil {
  43. return nil, nil
  44. }
  45. // Make the interface a reflect.Value
  46. original := reflect.ValueOf(src)
  47. // Make a copy of the same type as the original.
  48. cpy := reflect.New(original.Type()).Elem()
  49. // Create a new context for tracking copied objects
  50. ctx := &copyContext{
  51. seen: make(map[uintptr]reflect.Value),
  52. deep: 0,
  53. }
  54. // Recursively copy the original.
  55. if err := copyRecursive(ctx, original, cpy); err != nil {
  56. return nil, err
  57. }
  58. // Return the copy as an interface.
  59. return cpy.Interface(), nil
  60. }
  61. func MustCopy(src interface{}) interface{} {
  62. val, err := Copy(src)
  63. if err != nil {
  64. panic(err)
  65. }
  66. return val
  67. }
  68. // isUnsupportedType checks if the given type is unsupported for deep copy
  69. func isUnsupportedType(k reflect.Kind) bool {
  70. switch k {
  71. case reflect.Chan, reflect.Func, reflect.UnsafePointer, reflect.Invalid:
  72. return true
  73. }
  74. return false
  75. }
  76. // copyRecursive does the actual copying of the interface.
  77. // 支持的类型:
  78. // - 简单类型
  79. // - 时间类型 (time.Time)
  80. // - map, slice, array
  81. // - 指针
  82. // - 结构体及其匿名字段
  83. //
  84. // 特性:
  85. // - 使用 copyContext 处理循环引用
  86. // - 未导出字段会被跳过
  87. // - 支持 noCopy 标记的结构体
  88. // - 不支持的类型会被跳过不处理
  89. //
  90. // 参数:
  91. // - ctx: 复制上下文,用于追踪已复制的对象
  92. // - original: 原始值
  93. // - cpy: 目标值(指针的解引用)
  94. func copyRecursive(ctx *copyContext, original, cpy reflect.Value) error {
  95. ctx.deep++
  96. if ctx.deep > 20 {
  97. return errors.New("max recursion depth exceeded")
  98. }
  99. defer func() { ctx.deep-- }()
  100. // Handle nil pointer
  101. if !original.IsValid() {
  102. return nil
  103. }
  104. // 对不支持的类型直接跳过
  105. if isUnsupportedType(original.Kind()) {
  106. return nil
  107. }
  108. // Check for circular references
  109. if original.Kind() == reflect.Ptr || original.Kind() == reflect.Interface {
  110. if original.Kind() == reflect.Ptr {
  111. ptr := original.Pointer()
  112. if copied, ok := ctx.seen[ptr]; ok {
  113. cpy.Set(copied)
  114. return nil
  115. }
  116. if !original.IsNil() {
  117. ctx.seen[ptr] = cpy
  118. }
  119. }
  120. }
  121. // check for implement deepcopy.Interface
  122. if original.CanInterface() {
  123. if copier, ok := original.Interface().(DeepCopyer); ok {
  124. cpy.Set(reflect.ValueOf(copier.DeepCopy()))
  125. return nil
  126. }
  127. }
  128. var err error
  129. switch original.Kind() {
  130. case reflect.Ptr:
  131. originalValue := original.Elem()
  132. if !originalValue.IsValid() {
  133. return nil
  134. }
  135. cpy.Set(reflect.New(originalValue.Type()))
  136. err = copyRecursive(ctx, originalValue, cpy.Elem())
  137. case reflect.Interface:
  138. if original.IsNil() {
  139. return nil
  140. }
  141. originalValue := original.Elem()
  142. copyValue := reflect.New(originalValue.Type()).Elem()
  143. err = copyRecursive(ctx, originalValue, copyValue)
  144. if err == nil {
  145. cpy.Set(copyValue)
  146. }
  147. case reflect.Struct:
  148. t, ok := original.Interface().(time.Time)
  149. if ok {
  150. cpy.Set(reflect.ValueOf(t))
  151. return nil
  152. }
  153. oriType := original.Type()
  154. if _, ok := oriType.FieldByName("noCopy"); ok {
  155. return nil
  156. }
  157. for i := 0; i < original.NumField(); i++ {
  158. if !oriType.Field(i).IsExported() {
  159. continue
  160. }
  161. err = copyRecursive(ctx, original.Field(i), cpy.Field(i))
  162. if err != nil {
  163. // 忽略不支持类型的错误
  164. continue
  165. }
  166. }
  167. case reflect.Slice:
  168. if original.IsNil() {
  169. return nil
  170. }
  171. cpy.Set(reflect.MakeSlice(original.Type(), original.Len(), original.Cap()))
  172. for i := 0; i < original.Len(); i++ {
  173. err = copyRecursive(ctx, original.Index(i), cpy.Index(i))
  174. if err != nil {
  175. // 忽略不支持类型的错误
  176. continue
  177. }
  178. }
  179. case reflect.Map:
  180. if original.IsNil() {
  181. return nil
  182. }
  183. cpy.Set(reflect.MakeMap(original.Type()))
  184. for _, key := range original.MapKeys() {
  185. originalValue := original.MapIndex(key)
  186. copyValue := reflect.New(originalValue.Type()).Elem()
  187. err = copyRecursive(ctx, originalValue, copyValue)
  188. if err != nil {
  189. // 忽略不支持类型的错误
  190. continue
  191. }
  192. copyKey := reflect.New(key.Type()).Elem()
  193. err = copyRecursive(ctx, key, copyKey)
  194. if err != nil {
  195. // 忽略不支持类型的错误
  196. continue
  197. }
  198. cpy.SetMapIndex(copyKey, copyValue)
  199. }
  200. case reflect.Array:
  201. for i := 0; i < original.Len(); i++ {
  202. err = copyRecursive(ctx, original.Index(i), cpy.Index(i))
  203. if err != nil {
  204. // 忽略不支持类型的错误
  205. continue
  206. }
  207. }
  208. default:
  209. cpy.Set(original)
  210. }
  211. // 只返回非不支持类型的错误
  212. if err != nil && err.Error() != "max recursion depth exceeded" {
  213. return nil
  214. }
  215. return err
  216. }