Skip to content

Commit c918d8c

Browse files
authored
GODRIVER-2252 Don't call UnmarshalBSON for pointer values if the BSON field value is empty. (#833)
1 parent c999a05 commit c918d8c

File tree

4 files changed

+192
-114
lines changed

4 files changed

+192
-114
lines changed

bson/bsoncodec/default_value_decoders.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1504,6 +1504,18 @@ func (dvd DefaultValueDecoders) UnmarshalerDecodeValue(dc DecodeContext, vr bson
15041504
return err
15051505
}
15061506

1507+
// If the target Go value is a pointer and the BSON field value is empty, set the value to the
1508+
// zero value of the pointer (nil) and don't call UnmarshalBSON. UnmarshalBSON has no way to
1509+
// change the pointer value from within the function (only the value at the pointer address),
1510+
// so it can't set the pointer to "nil" itself. Since the most common Go value for an empty BSON
1511+
// field value is "nil", we set "nil" here and don't call UnmarshalBSON. This behavior matches
1512+
// the behavior of the Go "encoding/json" unmarshaler when the target Go value is a pointer and
1513+
// the JSON field value is "null".
1514+
if val.Kind() == reflect.Ptr && len(src) == 0 {
1515+
val.Set(reflect.Zero(val.Type()))
1516+
return nil
1517+
}
1518+
15071519
fn := val.Convert(tUnmarshaler).MethodByName("UnmarshalBSON")
15081520
errVal := fn.Call([]reflect.Value{reflect.ValueOf(src)})[0]
15091521
if !errVal.IsNil() {

bson/decoder_test.go

Lines changed: 8 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
"testing"
1414

1515
"github.com/google/go-cmp/cmp"
16+
"github.com/stretchr/testify/assert"
1617
"go.mongodb.org/mongo-driver/bson/bsoncodec"
1718
"go.mongodb.org/mongo-driver/bson/bsonrw"
1819
"go.mongodb.org/mongo-driver/bson/bsonrw/bsonrwtest"
@@ -21,7 +22,7 @@ import (
2122
)
2223

2324
func TestBasicDecode(t *testing.T) {
24-
for _, tc := range unmarshalingTestCases {
25+
for _, tc := range unmarshalingTestCases() {
2526
t.Run(tc.name, func(t *testing.T) {
2627
got := reflect.New(tc.sType).Elem()
2728
vr := bsonrw.NewBSONDocumentReader(tc.data)
@@ -30,34 +31,22 @@ func TestBasicDecode(t *testing.T) {
3031
noerr(t, err)
3132
err = decoder.DecodeValue(bsoncodec.DecodeContext{Registry: reg}, vr, got)
3233
noerr(t, err)
33-
34-
if !reflect.DeepEqual(got.Addr().Interface(), tc.want) {
35-
t.Errorf("Results do not match. got %+v; want %+v", got, tc.want)
36-
}
34+
assert.Equal(t, tc.want, got.Addr().Interface(), "Results do not match.")
3735
})
3836
}
3937
}
4038

4139
func TestDecoderv2(t *testing.T) {
4240
t.Run("Decode", func(t *testing.T) {
43-
for _, tc := range unmarshalingTestCases {
41+
for _, tc := range unmarshalingTestCases() {
4442
t.Run(tc.name, func(t *testing.T) {
4543
got := reflect.New(tc.sType).Interface()
4644
vr := bsonrw.NewBSONDocumentReader(tc.data)
47-
var reg *bsoncodec.Registry
48-
if tc.reg != nil {
49-
reg = tc.reg
50-
} else {
51-
reg = DefaultRegistry
52-
}
53-
dec, err := NewDecoderWithContext(bsoncodec.DecodeContext{Registry: reg}, vr)
45+
dec, err := NewDecoderWithContext(bsoncodec.DecodeContext{Registry: DefaultRegistry}, vr)
5446
noerr(t, err)
5547
err = dec.Decode(got)
5648
noerr(t, err)
57-
58-
if !reflect.DeepEqual(got, tc.want) {
59-
t.Errorf("Results do not match. got %+v; want %+v", got, tc.want)
60-
}
49+
assert.Equal(t, tc.want, got, "Results do not match.")
6150
})
6251
}
6352
t.Run("lookup error", func(t *testing.T) {
@@ -70,9 +59,7 @@ func TestDecoderv2(t *testing.T) {
7059
noerr(t, err)
7160
want := bsoncodec.ErrNoDecoder{Type: reflect.TypeOf(cdeih)}
7261
got := dec.Decode(&cdeih)
73-
if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) {
74-
t.Errorf("Received unexpected error. got %v; want %v", got, want)
75-
}
62+
assert.Equal(t, want, got, "Received unexpected error.")
7663
})
7764
t.Run("Unmarshaler", func(t *testing.T) {
7865
testCases := []struct {
@@ -191,9 +178,7 @@ func TestDecoderv2(t *testing.T) {
191178
err = dec.Decode(&got)
192179
noerr(t, err)
193180
want := foo{Item: "canvas", Qty: 4, Bonus: 2}
194-
if !reflect.DeepEqual(got, want) {
195-
t.Errorf("Results do not match. got %+v; want %+v", got, want)
196-
}
181+
assert.Equal(t, want, got, "Results do not match.")
197182
})
198183
t.Run("Reset", func(t *testing.T) {
199184
vr1, vr2 := bsonrw.NewBSONDocumentReader([]byte{}), bsonrw.NewBSONDocumentReader([]byte{})

bson/unmarshal_test.go

Lines changed: 11 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -10,64 +10,42 @@ import (
1010
"reflect"
1111
"testing"
1212

13-
"github.com/google/go-cmp/cmp"
13+
"github.com/stretchr/testify/assert"
1414
"go.mongodb.org/mongo-driver/bson/bsoncodec"
1515
"go.mongodb.org/mongo-driver/bson/bsonrw"
16-
"go.mongodb.org/mongo-driver/internal/testutil/assert"
1716
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
1817
)
1918

2019
func TestUnmarshal(t *testing.T) {
21-
for _, tc := range unmarshalingTestCases {
20+
for _, tc := range unmarshalingTestCases() {
2221
t.Run(tc.name, func(t *testing.T) {
23-
if tc.reg != nil {
24-
t.Skip() // test requires custom registry
25-
}
2622
got := reflect.New(tc.sType).Interface()
2723
err := Unmarshal(tc.data, got)
2824
noerr(t, err)
29-
if !cmp.Equal(got, tc.want) {
30-
t.Errorf("Did not unmarshal as expected. got %v; want %v", got, tc.want)
31-
}
25+
assert.Equal(t, tc.want, got, "Did not unmarshal as expected.")
3226
})
3327
}
3428
}
3529

3630
func TestUnmarshalWithRegistry(t *testing.T) {
37-
for _, tc := range unmarshalingTestCases {
31+
for _, tc := range unmarshalingTestCases() {
3832
t.Run(tc.name, func(t *testing.T) {
39-
var reg *bsoncodec.Registry
40-
if tc.reg != nil {
41-
reg = tc.reg
42-
} else {
43-
reg = DefaultRegistry
44-
}
4533
got := reflect.New(tc.sType).Interface()
46-
err := UnmarshalWithRegistry(reg, tc.data, got)
34+
err := UnmarshalWithRegistry(DefaultRegistry, tc.data, got)
4735
noerr(t, err)
48-
if !cmp.Equal(got, tc.want) {
49-
t.Errorf("Did not unmarshal as expected. got %v; want %v", got, tc.want)
50-
}
36+
assert.Equal(t, tc.want, got, "Did not unmarshal as expected.")
5137
})
5238
}
5339
}
5440

5541
func TestUnmarshalWithContext(t *testing.T) {
56-
for _, tc := range unmarshalingTestCases {
42+
for _, tc := range unmarshalingTestCases() {
5743
t.Run(tc.name, func(t *testing.T) {
58-
var reg *bsoncodec.Registry
59-
if tc.reg != nil {
60-
reg = tc.reg
61-
} else {
62-
reg = DefaultRegistry
63-
}
64-
dc := bsoncodec.DecodeContext{Registry: reg}
44+
dc := bsoncodec.DecodeContext{Registry: DefaultRegistry}
6545
got := reflect.New(tc.sType).Interface()
6646
err := UnmarshalWithContext(dc, tc.data, got)
6747
noerr(t, err)
68-
if !cmp.Equal(got, tc.want) {
69-
t.Errorf("Did not unmarshal as expected. got %v; want %v", got, tc.want)
70-
}
48+
assert.Equal(t, tc.want, got, "Did not unmarshal as expected.")
7149
})
7250
}
7351
}
@@ -80,9 +58,7 @@ func TestUnmarshalExtJSONWithRegistry(t *testing.T) {
8058
err := UnmarshalExtJSONWithRegistry(DefaultRegistry, data, true, &got)
8159
noerr(t, err)
8260
want := teststruct{1}
83-
if !cmp.Equal(got, want) {
84-
t.Errorf("Did not unmarshal as expected. got %v; want %v", got, want)
85-
}
61+
assert.Equal(t, want, got, "Did not unmarshal as expected.")
8662
})
8763

8864
t.Run("UnmarshalExtJSONInvalidInput", func(t *testing.T) {
@@ -165,9 +141,7 @@ func TestUnmarshalExtJSONWithContext(t *testing.T) {
165141
dc := bsoncodec.DecodeContext{Registry: DefaultRegistry}
166142
err := UnmarshalExtJSONWithContext(dc, tc.data, true, got)
167143
noerr(t, err)
168-
if !cmp.Equal(got, tc.want) {
169-
t.Errorf("Did not unmarshal as expected. got %+v; want %+v", got, tc.want)
170-
}
144+
assert.Equal(t, tc.want, got, "Did not unmarshal as expected.")
171145
})
172146
}
173147
}

0 commit comments

Comments
 (0)