Skip to content

GODRIVER-1682 - Use typeDecoder in recursive decoders #465

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 19 additions & 3 deletions bson/bsoncodec/bsoncodec.go
Original file line number Diff line number Diff line change
Expand Up @@ -184,12 +184,28 @@ var _ typeDecoder = decodeAdapter{}
// decodeTypeOrValue calls decoder.decodeType is decoder is a typeDecoder. Otherwise, it allocates a new element of type
// t and calls decoder.DecodeValue on it.
func decodeTypeOrValue(decoder ValueDecoder, dc DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) {
if typeDecoder, ok := decoder.(typeDecoder); ok {
return typeDecoder.decodeType(dc, vr, t)
td, _ := decoder.(typeDecoder)
return decodeTypeOrValueWithInfo(decoder, td, dc, vr, t, true)
}

func decodeTypeOrValueWithInfo(vd ValueDecoder, td typeDecoder, dc DecodeContext, vr bsonrw.ValueReader, t reflect.Type, convert bool) (reflect.Value, error) {
if td != nil {
val, err := td.decodeType(dc, vr, t)
if err == nil && convert && val.Type() != t {
// This conversion step is necessary for slices and maps. If a user declares variables like:
//
// type myBool bool
// var m map[string]myBool
//
// and tries to decode BSON bytes into the map, the decoding will fail if this conversion is not present
// because we'll try to assign a value of type bool to one of type myBool.
val = val.Convert(t)
}
return val, err
}

val := reflect.New(t).Elem()
err := decoder.DecodeValue(dc, vr, val)
err := vd.DecodeValue(dc, vr, val)
return val, err
}

Expand Down
20 changes: 5 additions & 15 deletions bson/bsoncodec/default_value_decoders.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ func (dvd DefaultValueDecoders) DDecodeValue(dc DecodeContext, vr bsonrw.ValueRe
if err != nil {
return err
}
typeDecoder, isTypeDecoder := decoder.(typeDecoder)
tEmptyTypeDecoder, _ := decoder.(typeDecoder)

// Use the elements in the provided value if it's non nil. Otherwise, allocate a new D instance.
var elems primitive.D
Expand All @@ -164,17 +164,8 @@ func (dvd DefaultValueDecoders) DDecodeValue(dc DecodeContext, vr bsonrw.ValueRe
return err
}

// Delegate out to the typeDecoder for interface{} if it exists. If not, create a new interface{} value and
// delegate out to the ValueDecoder. This could be accomplished by calling decodeTypeOrValue, but this would
// require casting decoder to typeDecoder for every element. Because decoder isn't changing, we can optimize and
// only cast once.
var elem reflect.Value
if isTypeDecoder {
elem, err = typeDecoder.decodeType(dc, elemVr, tEmpty)
} else {
elem = reflect.New(tEmpty).Elem()
err = decoder.DecodeValue(dc, elemVr, elem)
}
// Pass false for convert because we don't need to call reflect.Value.Convert for tEmpty.
elem, err := decodeTypeOrValueWithInfo(decoder, tEmptyTypeDecoder, dc, elemVr, tEmpty, false)
if err != nil {
return err
}
Expand Down Expand Up @@ -1577,6 +1568,7 @@ func (dvd DefaultValueDecoders) decodeDefault(dc DecodeContext, vr bsonrw.ValueR
if err != nil {
return nil, err
}
eTypeDecoder, _ := decoder.(typeDecoder)

idx := 0
for {
Expand All @@ -1588,9 +1580,7 @@ func (dvd DefaultValueDecoders) decodeDefault(dc DecodeContext, vr bsonrw.ValueR
return nil, err
}

elem := reflect.New(eType).Elem()

err = decoder.DecodeValue(dc, vr, elem)
elem, err := decodeTypeOrValueWithInfo(decoder, eTypeDecoder, dc, vr, eType, true)
if err != nil {
return nil, newDecodeError(strconv.Itoa(idx), err)
}
Expand Down
50 changes: 50 additions & 0 deletions bson/bsoncodec/default_value_decoders_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3558,6 +3558,56 @@ func TestDefaultValueDecoders(t *testing.T) {
"expected error %v to contain key pattern %s", decodeErr, keyPath)
})
})

t.Run("values are converted", func(t *testing.T) {
// When decoding into a D or M, values must be converted if they are not being decoded to the default type.

t.Run("D", func(t *testing.T) {
trueValue := bsoncore.Value{
Type: bsontype.Boolean,
Data: bsoncore.AppendBoolean(nil, true),
}
docBytes := bsoncore.BuildDocumentFromElements(nil,
bsoncore.AppendBooleanElement(nil, "bool", true),
bsoncore.BuildArrayElement(nil, "boolArray", trueValue),
)

rb := NewRegistryBuilder()
defaultValueDecoders.RegisterDefaultDecoders(rb)
reg := rb.RegisterTypeMapEntry(bsontype.Boolean, reflect.TypeOf(mybool(true))).Build()

dc := DecodeContext{Registry: reg}
vr := bsonrw.NewBSONDocumentReader(docBytes)
val := reflect.New(tD).Elem()
err := defaultValueDecoders.DDecodeValue(dc, vr, val)
assert.Nil(t, err, "DDecodeValue error: %v", err)

want := primitive.D{
{"bool", mybool(true)},
{"boolArray", primitive.A{mybool(true)}},
}
got := val.Interface().(primitive.D)
assert.Equal(t, want, got, "want document %v, got %v", want, got)
})
t.Run("M", func(t *testing.T) {
docBytes := bsoncore.BuildDocumentFromElements(nil,
bsoncore.AppendBooleanElement(nil, "bool", true),
)

type myMap map[string]mybool
dc := DecodeContext{Registry: buildDefaultRegistry()}
vr := bsonrw.NewBSONDocumentReader(docBytes)
val := reflect.New(reflect.TypeOf(myMap{})).Elem()
err := defaultMapCodec.DecodeValue(dc, vr, val)
assert.Nil(t, err, "DecodeValue error: %v", err)

want := myMap{
"bool": mybool(true),
}
got := val.Interface().(myMap)
assert.Equal(t, want, got, "expected map %v, got %v", want, got)
})
})
}

type testValueUnmarshaler struct {
Expand Down
4 changes: 2 additions & 2 deletions bson/bsoncodec/map_codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ func (mc *MapCodec) DecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val ref
if err != nil {
return err
}
eTypeDecoder, _ := decoder.(typeDecoder)

if eType == tEmpty {
dc.Ancestor = val.Type()
Expand All @@ -199,8 +200,7 @@ func (mc *MapCodec) DecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val ref
return err
}

elem := reflect.New(eType).Elem()
err = decoder.DecodeValue(dc, vr, elem)
elem, err := decodeTypeOrValueWithInfo(decoder, eTypeDecoder, dc, vr, eType, true)
if err != nil {
return newDecodeError(key, err)
}
Expand Down