Skip to content

Commit 8e61246

Browse files
GODRIVER-2311 Ensure unmarshaled BSON values always use distinct unde… (#892)
Co-authored-by: Matt Dale <[email protected]>
1 parent 4a387b8 commit 8e61246

File tree

2 files changed

+263
-1
lines changed

2 files changed

+263
-1
lines changed

bson/bsonrw/value_reader.go

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -384,9 +384,13 @@ func (vr *valueReader) ReadBinary() (b []byte, btype byte, err error) {
384384
if err != nil {
385385
return nil, 0, err
386386
}
387+
// Make a copy of the returned byte slice because it's just a subslice from the valueReader's
388+
// buffer and is not safe to return in the unmarshaled value.
389+
cp := make([]byte, len(b))
390+
copy(cp, b)
387391

388392
vr.pop()
389-
return b, btype, nil
393+
return cp, btype, nil
390394
}
391395

392396
func (vr *valueReader) ReadBoolean() (bool, error) {
@@ -737,6 +741,9 @@ func (vr *valueReader) ReadValue() (ValueReader, error) {
737741
return vr, nil
738742
}
739743

744+
// readBytes reads length bytes from the valueReader starting at the current offset. Note that the
745+
// returned byte slice is a subslice from the valueReader buffer and must be converted or copied
746+
// before returning in an unmarshaled value.
740747
func (vr *valueReader) readBytes(length int32) ([]byte, error) {
741748
if length < 0 {
742749
return nil, fmt.Errorf("invalid length: %d", length)
@@ -748,6 +755,7 @@ func (vr *valueReader) readBytes(length int32) ([]byte, error) {
748755

749756
start := vr.offset
750757
vr.offset += int64(length)
758+
751759
return vr.d[start : start+int64(length)], nil
752760
}
753761

bson/unmarshal_test.go

Lines changed: 254 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,15 @@
77
package bson
88

99
import (
10+
"crypto/rand"
1011
"reflect"
1112
"testing"
13+
"unsafe"
1214

1315
"github.com/google/go-cmp/cmp"
1416
"go.mongodb.org/mongo-driver/bson/bsoncodec"
1517
"go.mongodb.org/mongo-driver/bson/bsonrw"
18+
"go.mongodb.org/mongo-driver/bson/primitive"
1619
"go.mongodb.org/mongo-driver/internal/testutil/assert"
1720
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
1821
)
@@ -522,3 +525,254 @@ func TestUnmarshalBSONWithUndefinedField(t *testing.T) {
522525
})
523526
}
524527
}
528+
529+
// GODRIVER-2311
530+
// Assert that unmarshaled values containing byte slices do not reference the same underlying byte
531+
// array as the BSON input data byte slice.
532+
func TestUnmarshalByteSlicesUseDistinctArrays(t *testing.T) {
533+
type fooBytes struct {
534+
Foo []byte
535+
}
536+
537+
type myBytes []byte
538+
type fooMyBytes struct {
539+
Foo myBytes
540+
}
541+
542+
type fooBinary struct {
543+
Foo primitive.Binary
544+
}
545+
546+
type fooObjectID struct {
547+
Foo primitive.ObjectID
548+
}
549+
550+
type fooDBPointer struct {
551+
Foo primitive.DBPointer
552+
}
553+
554+
testCases := []struct {
555+
description string
556+
data []byte
557+
sType reflect.Type
558+
want interface{}
559+
560+
// getByteSlice returns the byte slice from the unmarshaled value, allowing the test to
561+
// inspect the addresses of the underlying byte array.
562+
getByteSlice func(interface{}) []byte
563+
}{
564+
{
565+
description: "struct with byte slice",
566+
data: docToBytes(fooBytes{
567+
Foo: []byte{0, 1, 2, 3, 4, 5},
568+
}),
569+
sType: reflect.TypeOf(fooBytes{}),
570+
want: &fooBytes{
571+
Foo: []byte{0, 1, 2, 3, 4, 5},
572+
},
573+
getByteSlice: func(val interface{}) []byte {
574+
return (*(val.(*fooBytes))).Foo
575+
},
576+
},
577+
{
578+
description: "bson.D with byte slice",
579+
data: docToBytes(D{
580+
{"foo", []byte{0, 1, 2, 3, 4, 5}},
581+
}),
582+
sType: reflect.TypeOf(D{}),
583+
want: &D{
584+
{"foo", primitive.Binary{Subtype: 0, Data: []byte{0, 1, 2, 3, 4, 5}}},
585+
},
586+
getByteSlice: func(val interface{}) []byte {
587+
return (*(val.(*D)))[0].Value.(primitive.Binary).Data
588+
},
589+
},
590+
{
591+
description: "struct with custom byte slice type",
592+
data: docToBytes(fooMyBytes{
593+
Foo: myBytes{0, 1, 2, 3, 4, 5},
594+
}),
595+
sType: reflect.TypeOf(fooMyBytes{}),
596+
want: &fooMyBytes{
597+
Foo: myBytes{0, 1, 2, 3, 4, 5},
598+
},
599+
getByteSlice: func(val interface{}) []byte {
600+
return (*(val.(*fooMyBytes))).Foo
601+
},
602+
},
603+
{
604+
description: "bson.D with custom byte slice type",
605+
data: docToBytes(D{
606+
{"foo", myBytes{0, 1, 2, 3, 4, 5}},
607+
}),
608+
sType: reflect.TypeOf(D{}),
609+
want: &D{
610+
{"foo", primitive.Binary{Subtype: 0, Data: myBytes{0, 1, 2, 3, 4, 5}}},
611+
},
612+
getByteSlice: func(val interface{}) []byte {
613+
return (*(val.(*D)))[0].Value.(primitive.Binary).Data
614+
},
615+
},
616+
{
617+
description: "struct with primitive.Binary",
618+
data: docToBytes(fooBinary{
619+
Foo: primitive.Binary{Subtype: 0, Data: []byte{0, 1, 2, 3, 4, 5}},
620+
}),
621+
sType: reflect.TypeOf(fooBinary{}),
622+
want: &fooBinary{
623+
Foo: primitive.Binary{Subtype: 0, Data: []byte{0, 1, 2, 3, 4, 5}},
624+
},
625+
getByteSlice: func(val interface{}) []byte {
626+
return (*(val.(*fooBinary))).Foo.Data
627+
},
628+
},
629+
{
630+
description: "bson.D with primitive.Binary",
631+
data: docToBytes(D{
632+
{"foo", primitive.Binary{Subtype: 0, Data: []byte{0, 1, 2, 3, 4, 5}}},
633+
}),
634+
sType: reflect.TypeOf(D{}),
635+
want: &D{
636+
{"foo", primitive.Binary{Subtype: 0, Data: []byte{0, 1, 2, 3, 4, 5}}},
637+
},
638+
getByteSlice: func(val interface{}) []byte {
639+
return (*(val.(*D)))[0].Value.(primitive.Binary).Data
640+
},
641+
},
642+
{
643+
description: "struct with primitive.ObjectID",
644+
data: docToBytes(fooObjectID{
645+
Foo: primitive.ObjectID{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11},
646+
}),
647+
sType: reflect.TypeOf(fooObjectID{}),
648+
want: &fooObjectID{
649+
Foo: primitive.ObjectID{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11},
650+
},
651+
getByteSlice: func(val interface{}) []byte {
652+
return (*(val.(*fooObjectID))).Foo[:]
653+
},
654+
},
655+
{
656+
description: "bson.D with primitive.ObjectID",
657+
data: docToBytes(D{
658+
{"foo", primitive.ObjectID{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}},
659+
}),
660+
sType: reflect.TypeOf(D{}),
661+
want: &D{
662+
{"foo", primitive.ObjectID{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}},
663+
},
664+
getByteSlice: func(val interface{}) []byte {
665+
oid := (*(val.(*D)))[0].Value.(primitive.ObjectID)
666+
return oid[:]
667+
},
668+
},
669+
{
670+
description: "struct with primitive.DBPointer",
671+
data: docToBytes(fooDBPointer{
672+
Foo: primitive.DBPointer{
673+
DB: "test",
674+
Pointer: primitive.ObjectID{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11},
675+
},
676+
}),
677+
sType: reflect.TypeOf(fooDBPointer{}),
678+
want: &fooDBPointer{
679+
Foo: primitive.DBPointer{
680+
DB: "test",
681+
Pointer: primitive.ObjectID{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11},
682+
},
683+
},
684+
getByteSlice: func(val interface{}) []byte {
685+
return (*(val.(*fooDBPointer))).Foo.Pointer[:]
686+
},
687+
},
688+
{
689+
description: "bson.D with primitive.DBPointer",
690+
data: docToBytes(D{
691+
{"foo", primitive.DBPointer{
692+
DB: "test",
693+
Pointer: primitive.ObjectID{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11},
694+
}},
695+
}),
696+
sType: reflect.TypeOf(D{}),
697+
want: &D{
698+
{"foo", primitive.DBPointer{
699+
DB: "test",
700+
Pointer: primitive.ObjectID{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11},
701+
}},
702+
},
703+
getByteSlice: func(val interface{}) []byte {
704+
oid := (*(val.(*D)))[0].Value.(primitive.DBPointer).Pointer
705+
return oid[:]
706+
},
707+
},
708+
}
709+
710+
for _, tc := range testCases {
711+
tc := tc // Capture range variable.
712+
t.Run(tc.description, func(t *testing.T) {
713+
t.Parallel()
714+
715+
// Make a copy of the test data so we can modify it later.
716+
data := make([]byte, len(tc.data))
717+
copy(data, tc.data)
718+
719+
// Assert that unmarshaling the input data results in the expected value.
720+
got := reflect.New(tc.sType).Interface()
721+
err := Unmarshal(data, got)
722+
noerr(t, err)
723+
assert.Equal(t, tc.want, got, "unmarshaled value does not match the expected value")
724+
725+
// Fill the input data slice with random bytes and then assert that the result still
726+
// matches the expected value.
727+
_, err = rand.Read(data)
728+
noerr(t, err)
729+
assert.Equal(t, tc.want, got, "unmarshaled value does not match expected after modifying the input bytes")
730+
731+
// Assert that the byte slice in the unmarshaled value does not share any memory
732+
// addresses with the input byte slice.
733+
assertDifferentArrays(t, data, tc.getByteSlice(got))
734+
})
735+
}
736+
}
737+
738+
// assertDifferentArrays asserts that two byte slices reference distinct memory ranges, meaning
739+
// they reference different underlying byte arrays.
740+
func assertDifferentArrays(t *testing.T, a, b []byte) {
741+
// Find the start and end memory addresses for the underlying byte array for each input byte
742+
// slice.
743+
sliceAddrRange := func(b []byte) (uintptr, uintptr) {
744+
sh := (*reflect.SliceHeader)(unsafe.Pointer(&b))
745+
return sh.Data, sh.Data + uintptr(sh.Cap-1)
746+
}
747+
aStart, aEnd := sliceAddrRange(a)
748+
bStart, bEnd := sliceAddrRange(b)
749+
750+
// If "b" starts after "a" ends or "a" starts after "b" ends, there is no overlap.
751+
if bStart > aEnd || aStart > bEnd {
752+
return
753+
}
754+
755+
// Otherwise, calculate the overlap start and end and print the memory overlap error message.
756+
min := func(a, b uintptr) uintptr {
757+
if a < b {
758+
return a
759+
}
760+
return b
761+
}
762+
max := func(a, b uintptr) uintptr {
763+
if a > b {
764+
return a
765+
}
766+
return b
767+
}
768+
overlapLow := max(aStart, bStart)
769+
overlapHigh := min(aEnd, bEnd)
770+
771+
t.Errorf("Byte slices point to the same the same underlying byte array:\n"+
772+
"\ta addresses:\t%d ... %d\n"+
773+
"\tb addresses:\t%d ... %d\n"+
774+
"\toverlap:\t%d ... %d",
775+
aStart, aEnd,
776+
bStart, bEnd,
777+
overlapLow, overlapHigh)
778+
}

0 commit comments

Comments
 (0)