From 86cdca063f41835f7acf068df463f14eecd11e17 Mon Sep 17 00:00:00 2001 From: Remco Date: Thu, 29 Jan 2015 16:29:59 +0100 Subject: [PATCH 1/2] fixed issue with not overriding of existing value of map, and added expect test --- merge.go | 6 ++++-- mergo_test.go | 15 +++++++++++++-- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/merge.go b/merge.go index 5d328b1..45f4803 100644 --- a/merge.go +++ b/merge.go @@ -53,8 +53,10 @@ func deepMerge(dst, src reflect.Value, visited map[uintptr]*visit, depth int) (e if err = deepMerge(dstElement, srcElement, visited, depth+1); err != nil { return } - } - if !dstElement.IsValid() { + if !dstElement.IsValid() { + dst.SetMapIndex(key, srcElement) + } + default: dst.SetMapIndex(key, srcElement) } } diff --git a/mergo_test.go b/mergo_test.go index 072bddb..9bca96d 100644 --- a/mergo_test.go +++ b/mergo_test.go @@ -136,17 +136,28 @@ func TestMaps(t *testing.T) { m := map[string]simpleTest{ "a": simpleTest{}, "b": simpleTest{42}, + "d": simpleTest{61}, } n := map[string]simpleTest{ "a": simpleTest{16}, "b": simpleTest{}, "c": simpleTest{12}, + "e": simpleTest{14}, } + expect := map[string]simpleTest{ + "a": simpleTest{0}, + "b": simpleTest{42}, + "c": simpleTest{12}, + "d": simpleTest{61}, + "e": simpleTest{14}, + } + if err := Merge(&m, n); err != nil { t.Fatalf(err.Error()) } - if len(m) != 3 { - t.Fatalf(`n not merged in m properly, m must have 3 elements instead of %d`, len(m)) + + if !reflect.DeepEqual(m, expect) { + t.Fatalf("Test failed:\ngot :\n%#v\n\nwant :\n%#v\n\n", m, expect) } if m["a"].Value != 0 { t.Fatalf(`n merged in m because I solved non-addressable map values TODO: m["a"].Value(%d) != n["a"].Value(%d)`, m["a"].Value, n["a"].Value) From b6ee4c714c761fe9bfb2080560e516fb81c012ef Mon Sep 17 00:00:00 2001 From: Remco Date: Mon, 2 Feb 2015 11:33:23 +0100 Subject: [PATCH 2/2] added MergeWithOverwrite and MapWithOverwrite --- map.go | 20 ++++++++++++++------ merge.go | 26 ++++++++++++++++---------- mergo_test.go | 51 ++++++++++++++++++++++++++++++++++++++++++++++++--- 3 files changed, 78 insertions(+), 19 deletions(-) diff --git a/map.go b/map.go index 44361e8..1ed3d71 100644 --- a/map.go +++ b/map.go @@ -31,7 +31,7 @@ func isExported(field reflect.StructField) bool { // Traverses recursively both values, assigning src's fields values to dst. // The map argument tracks comparisons that have already been seen, which allows // short circuiting on recursive types. -func deepMap(dst, src reflect.Value, visited map[uintptr]*visit, depth int) (err error) { +func deepMap(dst, src reflect.Value, visited map[uintptr]*visit, depth int, overwrite bool) (err error) { if dst.CanAddr() { addr := dst.UnsafeAddr() h := 17 * addr @@ -57,7 +57,7 @@ func deepMap(dst, src reflect.Value, visited map[uintptr]*visit, depth int) (err } fieldName := field.Name fieldName = changeInitialCase(fieldName, unicode.ToLower) - if v, ok := dstMap[fieldName]; !ok || isEmptyValue(reflect.ValueOf(v)) { + if v, ok := dstMap[fieldName]; !ok || (isEmptyValue(reflect.ValueOf(v)) || overwrite) { dstMap[fieldName] = src.Field(i).Interface() } } @@ -89,12 +89,12 @@ func deepMap(dst, src reflect.Value, visited map[uintptr]*visit, depth int) (err continue } if srcKind == dstKind { - if err = deepMerge(dstElement, srcElement, visited, depth+1); err != nil { + if err = deepMerge(dstElement, srcElement, visited, depth+1, overwrite); err != nil { return } } else { if srcKind == reflect.Map { - if err = deepMap(dstElement, srcElement, visited, depth+1); err != nil { + if err = deepMap(dstElement, srcElement, visited, depth+1, overwrite); err != nil { return } } else { @@ -118,6 +118,14 @@ func deepMap(dst, src reflect.Value, visited map[uintptr]*visit, depth int) (err // This is separated method from Merge because it is cleaner and it keeps sane // semantics: merging equal types, mapping different (restricted) types. func Map(dst, src interface{}) error { + return _map(dst, src, false) +} + +func MapWithOverwrite(dst, src interface{}) error { + return _map(dst, src, true) +} + +func _map(dst, src interface{}, overwrite bool) error { var ( vDst, vSrc reflect.Value err error @@ -128,7 +136,7 @@ func Map(dst, src interface{}) error { // To be friction-less, we redirect equal-type arguments // to deepMerge. Only because arguments can be anything. if vSrc.Kind() == vDst.Kind() { - return deepMerge(vDst, vSrc, make(map[uintptr]*visit), 0) + return deepMerge(vDst, vSrc, make(map[uintptr]*visit), 0, overwrite) } switch vSrc.Kind() { case reflect.Struct: @@ -142,5 +150,5 @@ func Map(dst, src interface{}) error { default: return ErrNotSupported } - return deepMap(vDst, vSrc, make(map[uintptr]*visit), 0) + return deepMap(vDst, vSrc, make(map[uintptr]*visit), 0, overwrite) } diff --git a/merge.go b/merge.go index 45f4803..f565137 100644 --- a/merge.go +++ b/merge.go @@ -15,7 +15,7 @@ import ( // Traverses recursively both values, assigning src's fields values to dst. // The map argument tracks comparisons that have already been seen, which allows // short circuiting on recursive types. -func deepMerge(dst, src reflect.Value, visited map[uintptr]*visit, depth int) (err error) { +func deepMerge(dst, src reflect.Value, visited map[uintptr]*visit, depth int, overwrite bool) (err error) { if !src.IsValid() { return } @@ -35,7 +35,7 @@ func deepMerge(dst, src reflect.Value, visited map[uintptr]*visit, depth int) (e switch dst.Kind() { case reflect.Struct: for i, n := 0, dst.NumField(); i < n; i++ { - if err = deepMerge(dst.Field(i), src.Field(i), visited, depth+1); err != nil { + if err = deepMerge(dst.Field(i), src.Field(i), visited, depth+1, overwrite); err != nil { return } } @@ -50,13 +50,11 @@ func deepMerge(dst, src reflect.Value, visited map[uintptr]*visit, depth int) (e case reflect.Struct: fallthrough case reflect.Map: - if err = deepMerge(dstElement, srcElement, visited, depth+1); err != nil { + if err = deepMerge(dstElement, srcElement, visited, depth+1, overwrite); err != nil { return } - if !dstElement.IsValid() { - dst.SetMapIndex(key, srcElement) - } - default: + } + if !isEmptyValue(srcElement) && (overwrite || !dstElement.IsValid()) { dst.SetMapIndex(key, srcElement) } } @@ -66,10 +64,10 @@ func deepMerge(dst, src reflect.Value, visited map[uintptr]*visit, depth int) (e if src.IsNil() { break } else if dst.IsNil() { - if dst.CanSet() && isEmptyValue(dst) { + if dst.CanSet() && (isEmptyValue(dst) || overwrite) { dst.Set(src) } - } else if err = deepMerge(dst.Elem(), src.Elem(), visited, depth+1); err != nil { + } else if err = deepMerge(dst.Elem(), src.Elem(), visited, depth+1, overwrite); err != nil { return } default: @@ -87,6 +85,14 @@ func deepMerge(dst, src reflect.Value, visited map[uintptr]*visit, depth int) (e // It won't merge unexported (private) fields and will do recursively // any exported field. func Merge(dst, src interface{}) error { + return merge(dst, src, false) +} + +func MergeWithOverwrite(dst, src interface{}) error { + return merge(dst, src, true) +} + +func merge(dst, src interface{}, overwrite bool) error { var ( vDst, vSrc reflect.Value err error @@ -97,5 +103,5 @@ func Merge(dst, src interface{}) error { if vDst.Type() != vSrc.Type() { return ErrDifferentArgumentsTypes } - return deepMerge(vDst, vSrc, make(map[uintptr]*visit), 0) + return deepMerge(vDst, vSrc, make(map[uintptr]*visit), 0, overwrite) } diff --git a/mergo_test.go b/mergo_test.go index 9bca96d..2924a47 100644 --- a/mergo_test.go +++ b/mergo_test.go @@ -82,6 +82,20 @@ func TestComplexStruct(t *testing.T) { } } +func TestComplexStructWithOverwrite(t *testing.T) { + a := complexTest{simpleTest{1}, 1, "do-not-overwrite-with-empty-value"} + b := complexTest{simpleTest{42}, 2, ""} + + expect := complexTest{simpleTest{42}, 1, "do-not-overwrite-with-empty-value"} + if err := MergeWithOverwrite(&a, b); err != nil { + t.FailNow() + } + + if !reflect.DeepEqual(a, expect) { + t.Fatalf("Test failed:\ngot :\n%#v\n\nwant :\n%#v\n\n", a, expect) + } +} + func TestPointerStruct(t *testing.T) { s1 := simpleTest{} s2 := simpleTest{19} @@ -132,10 +146,41 @@ func TestSliceStruct(t *testing.T) { } } +func TestMapsWithOverwrite(t *testing.T) { + m := map[string]simpleTest{ + "a": simpleTest{}, // overwritten by 16 + "b": simpleTest{42}, // not overwritten by empty value + "c": simpleTest{13}, // overwritten by 12 + "d": simpleTest{61}, + } + n := map[string]simpleTest{ + "a": simpleTest{16}, + "b": simpleTest{}, + "c": simpleTest{12}, + "e": simpleTest{14}, + } + expect := map[string]simpleTest{ + "a": simpleTest{16}, + "b": simpleTest{}, + "c": simpleTest{12}, + "d": simpleTest{61}, + "e": simpleTest{14}, + } + + if err := MergeWithOverwrite(&m, n); err != nil { + t.Fatalf(err.Error()) + } + + if !reflect.DeepEqual(m, expect) { + t.Fatalf("Test failed:\ngot :\n%#v\n\nwant :\n%#v\n\n", m, expect) + } +} + func TestMaps(t *testing.T) { m := map[string]simpleTest{ "a": simpleTest{}, "b": simpleTest{42}, + "c": simpleTest{13}, "d": simpleTest{61}, } n := map[string]simpleTest{ @@ -147,7 +192,7 @@ func TestMaps(t *testing.T) { expect := map[string]simpleTest{ "a": simpleTest{0}, "b": simpleTest{42}, - "c": simpleTest{12}, + "c": simpleTest{13}, "d": simpleTest{61}, "e": simpleTest{14}, } @@ -165,8 +210,8 @@ func TestMaps(t *testing.T) { if m["b"].Value != 42 { t.Fatalf(`n wrongly merged in m: m["b"].Value(%d) != n["b"].Value(%d)`, m["b"].Value, n["b"].Value) } - if m["c"].Value != 12 { - t.Fatalf(`n not merged in m: m["c"].Value(%d) != n["c"].Value(%d)`, m["c"].Value, n["c"].Value) + if m["c"].Value != 13 { + t.Fatalf(`n overwritten in m: m["c"].Value(%d) != n["c"].Value(%d)`, m["c"].Value, n["c"].Value) } }