Skip to content

Commit 488c521

Browse files
author
Divjot Arora
authored
GODRIVER-1682 - Use typeDecoder in recursive decoders (#465)
1 parent ca59ef9 commit 488c521

File tree

4 files changed

+76
-20
lines changed

4 files changed

+76
-20
lines changed

bson/bsoncodec/bsoncodec.go

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -184,12 +184,28 @@ var _ typeDecoder = decodeAdapter{}
184184
// decodeTypeOrValue calls decoder.decodeType is decoder is a typeDecoder. Otherwise, it allocates a new element of type
185185
// t and calls decoder.DecodeValue on it.
186186
func decodeTypeOrValue(decoder ValueDecoder, dc DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) {
187-
if typeDecoder, ok := decoder.(typeDecoder); ok {
188-
return typeDecoder.decodeType(dc, vr, t)
187+
td, _ := decoder.(typeDecoder)
188+
return decodeTypeOrValueWithInfo(decoder, td, dc, vr, t, true)
189+
}
190+
191+
func decodeTypeOrValueWithInfo(vd ValueDecoder, td typeDecoder, dc DecodeContext, vr bsonrw.ValueReader, t reflect.Type, convert bool) (reflect.Value, error) {
192+
if td != nil {
193+
val, err := td.decodeType(dc, vr, t)
194+
if err == nil && convert && val.Type() != t {
195+
// This conversion step is necessary for slices and maps. If a user declares variables like:
196+
//
197+
// type myBool bool
198+
// var m map[string]myBool
199+
//
200+
// and tries to decode BSON bytes into the map, the decoding will fail if this conversion is not present
201+
// because we'll try to assign a value of type bool to one of type myBool.
202+
val = val.Convert(t)
203+
}
204+
return val, err
189205
}
190206

191207
val := reflect.New(t).Elem()
192-
err := decoder.DecodeValue(dc, vr, val)
208+
err := vd.DecodeValue(dc, vr, val)
193209
return val, err
194210
}
195211

bson/bsoncodec/default_value_decoders.go

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ func (dvd DefaultValueDecoders) DDecodeValue(dc DecodeContext, vr bsonrw.ValueRe
145145
if err != nil {
146146
return err
147147
}
148-
typeDecoder, isTypeDecoder := decoder.(typeDecoder)
148+
tEmptyTypeDecoder, _ := decoder.(typeDecoder)
149149

150150
// Use the elements in the provided value if it's non nil. Otherwise, allocate a new D instance.
151151
var elems primitive.D
@@ -164,17 +164,8 @@ func (dvd DefaultValueDecoders) DDecodeValue(dc DecodeContext, vr bsonrw.ValueRe
164164
return err
165165
}
166166

167-
// Delegate out to the typeDecoder for interface{} if it exists. If not, create a new interface{} value and
168-
// delegate out to the ValueDecoder. This could be accomplished by calling decodeTypeOrValue, but this would
169-
// require casting decoder to typeDecoder for every element. Because decoder isn't changing, we can optimize and
170-
// only cast once.
171-
var elem reflect.Value
172-
if isTypeDecoder {
173-
elem, err = typeDecoder.decodeType(dc, elemVr, tEmpty)
174-
} else {
175-
elem = reflect.New(tEmpty).Elem()
176-
err = decoder.DecodeValue(dc, elemVr, elem)
177-
}
167+
// Pass false for convert because we don't need to call reflect.Value.Convert for tEmpty.
168+
elem, err := decodeTypeOrValueWithInfo(decoder, tEmptyTypeDecoder, dc, elemVr, tEmpty, false)
178169
if err != nil {
179170
return err
180171
}
@@ -1577,6 +1568,7 @@ func (dvd DefaultValueDecoders) decodeDefault(dc DecodeContext, vr bsonrw.ValueR
15771568
if err != nil {
15781569
return nil, err
15791570
}
1571+
eTypeDecoder, _ := decoder.(typeDecoder)
15801572

15811573
idx := 0
15821574
for {
@@ -1588,9 +1580,7 @@ func (dvd DefaultValueDecoders) decodeDefault(dc DecodeContext, vr bsonrw.ValueR
15881580
return nil, err
15891581
}
15901582

1591-
elem := reflect.New(eType).Elem()
1592-
1593-
err = decoder.DecodeValue(dc, vr, elem)
1583+
elem, err := decodeTypeOrValueWithInfo(decoder, eTypeDecoder, dc, vr, eType, true)
15941584
if err != nil {
15951585
return nil, newDecodeError(strconv.Itoa(idx), err)
15961586
}

bson/bsoncodec/default_value_decoders_test.go

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3558,6 +3558,56 @@ func TestDefaultValueDecoders(t *testing.T) {
35583558
"expected error %v to contain key pattern %s", decodeErr, keyPath)
35593559
})
35603560
})
3561+
3562+
t.Run("values are converted", func(t *testing.T) {
3563+
// When decoding into a D or M, values must be converted if they are not being decoded to the default type.
3564+
3565+
t.Run("D", func(t *testing.T) {
3566+
trueValue := bsoncore.Value{
3567+
Type: bsontype.Boolean,
3568+
Data: bsoncore.AppendBoolean(nil, true),
3569+
}
3570+
docBytes := bsoncore.BuildDocumentFromElements(nil,
3571+
bsoncore.AppendBooleanElement(nil, "bool", true),
3572+
bsoncore.BuildArrayElement(nil, "boolArray", trueValue),
3573+
)
3574+
3575+
rb := NewRegistryBuilder()
3576+
defaultValueDecoders.RegisterDefaultDecoders(rb)
3577+
reg := rb.RegisterTypeMapEntry(bsontype.Boolean, reflect.TypeOf(mybool(true))).Build()
3578+
3579+
dc := DecodeContext{Registry: reg}
3580+
vr := bsonrw.NewBSONDocumentReader(docBytes)
3581+
val := reflect.New(tD).Elem()
3582+
err := defaultValueDecoders.DDecodeValue(dc, vr, val)
3583+
assert.Nil(t, err, "DDecodeValue error: %v", err)
3584+
3585+
want := primitive.D{
3586+
{"bool", mybool(true)},
3587+
{"boolArray", primitive.A{mybool(true)}},
3588+
}
3589+
got := val.Interface().(primitive.D)
3590+
assert.Equal(t, want, got, "want document %v, got %v", want, got)
3591+
})
3592+
t.Run("M", func(t *testing.T) {
3593+
docBytes := bsoncore.BuildDocumentFromElements(nil,
3594+
bsoncore.AppendBooleanElement(nil, "bool", true),
3595+
)
3596+
3597+
type myMap map[string]mybool
3598+
dc := DecodeContext{Registry: buildDefaultRegistry()}
3599+
vr := bsonrw.NewBSONDocumentReader(docBytes)
3600+
val := reflect.New(reflect.TypeOf(myMap{})).Elem()
3601+
err := defaultMapCodec.DecodeValue(dc, vr, val)
3602+
assert.Nil(t, err, "DecodeValue error: %v", err)
3603+
3604+
want := myMap{
3605+
"bool": mybool(true),
3606+
}
3607+
got := val.Interface().(myMap)
3608+
assert.Equal(t, want, got, "expected map %v, got %v", want, got)
3609+
})
3610+
})
35613611
}
35623612

35633613
type testValueUnmarshaler struct {

bson/bsoncodec/map_codec.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,7 @@ func (mc *MapCodec) DecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val ref
178178
if err != nil {
179179
return err
180180
}
181+
eTypeDecoder, _ := decoder.(typeDecoder)
181182

182183
if eType == tEmpty {
183184
dc.Ancestor = val.Type()
@@ -199,8 +200,7 @@ func (mc *MapCodec) DecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val ref
199200
return err
200201
}
201202

202-
elem := reflect.New(eType).Elem()
203-
err = decoder.DecodeValue(dc, vr, elem)
203+
elem, err := decodeTypeOrValueWithInfo(decoder, eTypeDecoder, dc, vr, eType, true)
204204
if err != nil {
205205
return newDecodeError(key, err)
206206
}

0 commit comments

Comments
 (0)