Skip to content

Commit 0daa2cc

Browse files
author
Divjot Arora
committed
GODRIVER-1682 - Use typeDecoder in recursive decoders
1 parent 46d5d76 commit 0daa2cc

File tree

3 files changed

+25
-20
lines changed

3 files changed

+25
-20
lines changed

bson/bsoncodec/bsoncodec.go

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -184,12 +184,28 @@ var _ typeDecoder = decodeAdapter{}
184184
// decodeTypeOrValue calls decoder.decodeType is decoder is a typeDecoder. Otherwise, it allocates a new element of type
185185
// t and calls decoder.DecodeValue on it.
186186
func decodeTypeOrValue(decoder ValueDecoder, dc DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) {
187-
if typeDecoder, ok := decoder.(typeDecoder); ok {
188-
return typeDecoder.decodeType(dc, vr, t)
187+
td, _ := decoder.(typeDecoder)
188+
return decodeTypeOrValueWithInfo(decoder, td, dc, vr, t)
189+
}
190+
191+
func decodeTypeOrValueWithInfo(vd ValueDecoder, td typeDecoder, dc DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) {
192+
if td != nil {
193+
val, err := td.decodeType(dc, vr, t)
194+
if err == nil && val.Type() != t {
195+
// This conversion step is necessary for slices and maps. If a user declares variables like:
196+
//
197+
// type myBool bool
198+
// var m map[string]myBool
199+
//
200+
// and tries to decode BSON bytes into the map, the decoding will fail if this conversion is not present
201+
// because we'll try to assign a value of type bool to one of type myBool.
202+
val = val.Convert(t)
203+
}
204+
return val, err
189205
}
190206

191207
val := reflect.New(t).Elem()
192-
err := decoder.DecodeValue(dc, vr, val)
208+
err := vd.DecodeValue(dc, vr, val)
193209
return val, err
194210
}
195211

bson/bsoncodec/default_value_decoders.go

Lines changed: 4 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ func (dvd DefaultValueDecoders) DDecodeValue(dc DecodeContext, vr bsonrw.ValueRe
130130
if err != nil {
131131
return err
132132
}
133-
typeDecoder, isTypeDecoder := decoder.(typeDecoder)
133+
tEmptyTypeDecoder, _ := decoder.(typeDecoder)
134134

135135
// Use the elements in the provided value if it's non nil. Otherwise, allocate a new D instance.
136136
var elems primitive.D
@@ -149,17 +149,7 @@ func (dvd DefaultValueDecoders) DDecodeValue(dc DecodeContext, vr bsonrw.ValueRe
149149
return err
150150
}
151151

152-
// Delegate out to the typeDecoder for interface{} if it exists. If not, create a new interface{} value and
153-
// delegate out to the ValueDecoder. This could be accomplished by calling decodeTypeOrValue, but this would
154-
// require casting decoder to typeDecoder for every element. Because decoder isn't changing, we can optimize and
155-
// only cast once.
156-
var elem reflect.Value
157-
if isTypeDecoder {
158-
elem, err = typeDecoder.decodeType(dc, elemVr, tEmpty)
159-
} else {
160-
elem = reflect.New(tEmpty).Elem()
161-
err = decoder.DecodeValue(dc, elemVr, elem)
162-
}
152+
elem, err := decodeTypeOrValueWithInfo(decoder, tEmptyTypeDecoder, dc, elemVr, tEmpty)
163153
if err != nil {
164154
return err
165155
}
@@ -1273,6 +1263,7 @@ func (dvd DefaultValueDecoders) decodeDefault(dc DecodeContext, vr bsonrw.ValueR
12731263
if err != nil {
12741264
return nil, err
12751265
}
1266+
eTypeDecoder, _ := decoder.(typeDecoder)
12761267

12771268
idx := 0
12781269
for {
@@ -1284,9 +1275,7 @@ func (dvd DefaultValueDecoders) decodeDefault(dc DecodeContext, vr bsonrw.ValueR
12841275
return nil, err
12851276
}
12861277

1287-
elem := reflect.New(eType).Elem()
1288-
1289-
err = decoder.DecodeValue(dc, vr, elem)
1278+
elem, err := decodeTypeOrValueWithInfo(decoder, eTypeDecoder, dc, vr, eType)
12901279
if err != nil {
12911280
return nil, newDecodeError(strconv.Itoa(idx), err)
12921281
}

bson/bsoncodec/map_codec.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,7 @@ func (mc *MapCodec) DecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val ref
178178
if err != nil {
179179
return err
180180
}
181+
eTypeDecoder, _ := decoder.(typeDecoder)
181182

182183
if eType == tEmpty {
183184
dc.Ancestor = val.Type()
@@ -199,8 +200,7 @@ func (mc *MapCodec) DecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val ref
199200
return err
200201
}
201202

202-
elem := reflect.New(eType).Elem()
203-
err = decoder.DecodeValue(dc, vr, elem)
203+
elem, err := decodeTypeOrValueWithInfo(decoder, eTypeDecoder, dc, vr, eType)
204204
if err != nil {
205205
return newDecodeError(key, err)
206206
}

0 commit comments

Comments
 (0)