From e2db15326ab337267aa752c13b000ade8e3cce0d Mon Sep 17 00:00:00 2001 From: "Artem V. Navrotskiy" Date: Thu, 31 Aug 2017 10:13:14 +0300 Subject: [PATCH] Fix GetBSON() method usage Original issue --- You can't use type with custom GetBSON() method mixed with structure field type and structure field reference type. For example, you can't create custom GetBSON() for Bar type: ``` struct Foo { a Bar b *Bar } ``` Type implementation (`func (t Bar) GetBSON()` ) would crash on `Foo.b = nil` value encoding. Reference implementation (`func (t *Bar) GetBSON()` ) would not call on `Foo.a` value encoding. After this change --- For type implementation `func (t Bar) GetBSON()` would not call on `Foo.b = nil` value encoding. In this case `nil` value would be seariazied as `nil` BSON value. For reference implementation `func (t *Bar) GetBSON()` would call even on `Foo.a` value encoding. --- bson/bson_test.go | 64 +++++++++++++++++++++++++++++++++++++++++++++++ bson/encode.go | 62 ++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 125 insertions(+), 1 deletion(-) diff --git a/bson/bson_test.go b/bson/bson_test.go index 37451f9fd..763967e68 100644 --- a/bson/bson_test.go +++ b/bson/bson_test.go @@ -377,8 +377,54 @@ func (s *S) Test64bitInt(c *C) { // -------------------------------------------------------------------------- // Generic two-way struct marshaling tests. +type prefixPtr string +type prefixVal string + +func (t *prefixPtr) GetBSON() (interface{}, error) { + if t == nil { + return nil, nil + } + return "foo-" + string(*t), nil +} + +func (t *prefixPtr) SetBSON(raw bson.Raw) error { + var s string + if raw.Kind == 0x0A { + return bson.SetZero + } + if err := raw.Unmarshal(&s); err != nil { + return err + } + if !strings.HasPrefix(s, "foo-") { + return errors.New("Prefix not found: " + s) + } + *t = prefixPtr(s[4:]) + return nil +} + +func (t prefixVal) GetBSON() (interface{}, error) { + return "foo-" + string(t), nil +} + +func (t *prefixVal) SetBSON(raw bson.Raw) error { + var s string + if raw.Kind == 0x0A { + return bson.SetZero + } + if err := raw.Unmarshal(&s); err != nil { + return err + } + if !strings.HasPrefix(s, "foo-") { + return errors.New("Prefix not found: " + s) + } + *t = prefixVal(s[4:]) + return nil +} + var bytevar = byte(8) var byteptr = &bytevar +var prefixptr = prefixPtr("bar") +var prefixval = prefixVal("bar") var structItems = []testItemType{ {&struct{ Ptr *byte }{nil}, @@ -415,6 +461,24 @@ var structItems = []testItemType{ // Byte arrays. {&struct{ V [2]byte }{[2]byte{'y', 'o'}}, "\x05v\x00\x02\x00\x00\x00\x00yo"}, + + {&struct{ V prefixPtr }{prefixPtr("buzz")}, + "\x02v\x00\x09\x00\x00\x00foo-buzz\x00"}, + + {&struct{ V *prefixPtr }{&prefixptr}, + "\x02v\x00\x08\x00\x00\x00foo-bar\x00"}, + + {&struct{ V *prefixPtr }{nil}, + "\x0Av\x00"}, + + {&struct{ V prefixVal }{prefixVal("buzz")}, + "\x02v\x00\x09\x00\x00\x00foo-buzz\x00"}, + + {&struct{ V *prefixVal }{&prefixval}, + "\x02v\x00\x08\x00\x00\x00foo-bar\x00"}, + + {&struct{ V *prefixVal }{nil}, + "\x0Av\x00"}, } func (s *S) TestMarshalStructItems(c *C) { diff --git a/bson/encode.go b/bson/encode.go index add39e865..fe4dd9ebd 100644 --- a/bson/encode.go +++ b/bson/encode.go @@ -34,6 +34,7 @@ import ( "net/url" "reflect" "strconv" + "sync" "time" ) @@ -58,13 +59,28 @@ var ( const itoaCacheSize = 32 +const ( + getterUnknown = iota + getterNone + getterTypeVal + getterTypePtr + getterAddr +) + var itoaCache []string +var getterStyles map[reflect.Type]int +var getterIface reflect.Type +var getterMutex sync.RWMutex + func init() { itoaCache = make([]string, itoaCacheSize) for i := 0; i != itoaCacheSize; i++ { itoaCache[i] = strconv.Itoa(i) } + var iface Getter + getterIface = reflect.TypeOf(&iface).Elem() + getterStyles = make(map[reflect.Type]int) } func itoa(i int) string { @@ -74,6 +90,50 @@ func itoa(i int) string { return strconv.Itoa(i) } +func getterStyle(outt reflect.Type) int { + getterMutex.RLock() + style := getterStyles[outt] + getterMutex.RUnlock() + if style == getterUnknown { + getterMutex.Lock() + defer getterMutex.Unlock() + if outt.Implements(getterIface) { + vt := outt + for vt.Kind() == reflect.Ptr { + vt = vt.Elem() + } + if vt.Implements(getterIface) { + getterStyles[outt] = getterTypeVal + } else { + getterStyles[outt] = getterTypePtr + } + } else if reflect.PtrTo(outt).Implements(getterIface) { + getterStyles[outt] = getterAddr + } else { + getterStyles[outt] = getterNone + } + style = getterStyles[outt] + } + return style +} + +func getGetter(outt reflect.Type, out reflect.Value) Getter { + style := getterStyle(outt) + if style == getterNone { + return nil + } + if style == getterAddr { + if !out.CanAddr() { + return nil + } + return out.Addr().Interface().(Getter) + } + if style == getterTypeVal && out.Kind() == reflect.Ptr && out.IsNil() { + return nil + } + return out.Interface().(Getter) +} + // -------------------------------------------------------------------------- // Marshaling of the document value itself. @@ -251,7 +311,7 @@ func (e *encoder) addElem(name string, v reflect.Value, minSize bool) { return } - if getter, ok := v.Interface().(Getter); ok { + if getter := getGetter(v.Type(), v); getter != nil { getv, err := getter.GetBSON() if err != nil { panic(err)