Skip to content

GODRIVER-2311 Ensure unmarshaled BSON values always use distinct unde… #892

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 4, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion bson/bsonrw/value_reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -384,9 +384,13 @@ func (vr *valueReader) ReadBinary() (b []byte, btype byte, err error) {
if err != nil {
return nil, 0, err
}
// Make a copy of the returned byte slice because it's just a subslice from the valueReader's
// buffer and is not safe to return in the unmarshaled value.
cp := make([]byte, len(b))
copy(cp, b)

vr.pop()
return b, btype, nil
return cp, btype, nil
}

func (vr *valueReader) ReadBoolean() (bool, error) {
Expand Down Expand Up @@ -737,6 +741,9 @@ func (vr *valueReader) ReadValue() (ValueReader, error) {
return vr, nil
}

// readBytes reads length bytes from the valueReader starting at the current offset. Note that the
// returned byte slice is a subslice from the valueReader buffer and must be converted or copied
// before returning in an unmarshaled value.
func (vr *valueReader) readBytes(length int32) ([]byte, error) {
if length < 0 {
return nil, fmt.Errorf("invalid length: %d", length)
Expand All @@ -748,6 +755,7 @@ func (vr *valueReader) readBytes(length int32) ([]byte, error) {

start := vr.offset
vr.offset += int64(length)

return vr.d[start : start+int64(length)], nil
}

Expand Down
254 changes: 254 additions & 0 deletions bson/unmarshal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,15 @@
package bson

import (
"crypto/rand"
"reflect"
"testing"
"unsafe"

"github.com/google/go-cmp/cmp"
"go.mongodb.org/mongo-driver/bson/bsoncodec"
"go.mongodb.org/mongo-driver/bson/bsonrw"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/internal/testutil/assert"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
)
Expand Down Expand Up @@ -522,3 +525,254 @@ func TestUnmarshalBSONWithUndefinedField(t *testing.T) {
})
}
}

// GODRIVER-2311
// TestUnmarshalByteSlicesUseDistinctArrays checks that byte-backed values produced by Unmarshal
// never alias the byte slice that held the BSON input. Each case unmarshals a document, scrambles
// the input bytes, confirms the result is unchanged, and finally compares memory ranges.
func TestUnmarshalByteSlicesUseDistinctArrays(t *testing.T) {
	type fooBytes struct {
		Foo []byte
	}

	type myBytes []byte
	type fooMyBytes struct {
		Foo myBytes
	}

	type fooBinary struct {
		Foo primitive.Binary
	}

	type fooObjectID struct {
		Foo primitive.ObjectID
	}

	type fooDBPointer struct {
		Foo primitive.DBPointer
	}

	// Fixture constructors. Slice-valued fixtures are built fresh on every call so no two uses
	// share a backing array.
	rawBytes := func() []byte { return []byte{0, 1, 2, 3, 4, 5} }
	customBytes := func() myBytes { return myBytes{0, 1, 2, 3, 4, 5} }
	binary := func() primitive.Binary {
		return primitive.Binary{Subtype: 0, Data: rawBytes()}
	}
	objectID := func() primitive.ObjectID {
		return primitive.ObjectID{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}
	}
	dbPointer := func() primitive.DBPointer {
		return primitive.DBPointer{DB: "test", Pointer: objectID()}
	}

	// firstBinaryData pulls the binary payload out of the first element of an unmarshaled bson.D.
	firstBinaryData := func(val interface{}) []byte {
		doc := *val.(*D)
		return doc[0].Value.(primitive.Binary).Data
	}

	testCases := []struct {
		description string
		data        []byte
		sType       reflect.Type
		want        interface{}

		// getByteSlice extracts the byte slice of interest from the unmarshaled value so that
		// the test can examine the addresses of its underlying array.
		getByteSlice func(interface{}) []byte
	}{
		{
			description: "struct with byte slice",
			data:        docToBytes(fooBytes{Foo: rawBytes()}),
			sType:       reflect.TypeOf(fooBytes{}),
			want:        &fooBytes{Foo: rawBytes()},
			getByteSlice: func(val interface{}) []byte {
				return val.(*fooBytes).Foo
			},
		},
		{
			description:  "bson.D with byte slice",
			data:         docToBytes(D{{"foo", rawBytes()}}),
			sType:        reflect.TypeOf(D{}),
			want:         &D{{"foo", binary()}},
			getByteSlice: firstBinaryData,
		},
		{
			description: "struct with custom byte slice type",
			data:        docToBytes(fooMyBytes{Foo: customBytes()}),
			sType:       reflect.TypeOf(fooMyBytes{}),
			want:        &fooMyBytes{Foo: customBytes()},
			getByteSlice: func(val interface{}) []byte {
				return val.(*fooMyBytes).Foo
			},
		},
		{
			description: "bson.D with custom byte slice type",
			data:        docToBytes(D{{"foo", customBytes()}}),
			sType:       reflect.TypeOf(D{}),
			want: &D{
				{"foo", primitive.Binary{Subtype: 0, Data: customBytes()}},
			},
			getByteSlice: firstBinaryData,
		},
		{
			description: "struct with primitive.Binary",
			data:        docToBytes(fooBinary{Foo: binary()}),
			sType:       reflect.TypeOf(fooBinary{}),
			want:        &fooBinary{Foo: binary()},
			getByteSlice: func(val interface{}) []byte {
				return val.(*fooBinary).Foo.Data
			},
		},
		{
			description:  "bson.D with primitive.Binary",
			data:         docToBytes(D{{"foo", binary()}}),
			sType:        reflect.TypeOf(D{}),
			want:         &D{{"foo", binary()}},
			getByteSlice: firstBinaryData,
		},
		{
			description: "struct with primitive.ObjectID",
			data:        docToBytes(fooObjectID{Foo: objectID()}),
			sType:       reflect.TypeOf(fooObjectID{}),
			want:        &fooObjectID{Foo: objectID()},
			getByteSlice: func(val interface{}) []byte {
				return val.(*fooObjectID).Foo[:]
			},
		},
		{
			description: "bson.D with primitive.ObjectID",
			data:        docToBytes(D{{"foo", objectID()}}),
			sType:       reflect.TypeOf(D{}),
			want:        &D{{"foo", objectID()}},
			getByteSlice: func(val interface{}) []byte {
				// The type assertion yields a copy of the ObjectID; slice that copy.
				id := (*val.(*D))[0].Value.(primitive.ObjectID)
				return id[:]
			},
		},
		{
			description: "struct with primitive.DBPointer",
			data:        docToBytes(fooDBPointer{Foo: dbPointer()}),
			sType:       reflect.TypeOf(fooDBPointer{}),
			want:        &fooDBPointer{Foo: dbPointer()},
			getByteSlice: func(val interface{}) []byte {
				return val.(*fooDBPointer).Foo.Pointer[:]
			},
		},
		{
			description: "bson.D with primitive.DBPointer",
			data:        docToBytes(D{{"foo", dbPointer()}}),
			sType:       reflect.TypeOf(D{}),
			want:        &D{{"foo", dbPointer()}},
			getByteSlice: func(val interface{}) []byte {
				// The type assertion yields a copy of the DBPointer; slice the copy's ObjectID.
				id := (*val.(*D))[0].Value.(primitive.DBPointer).Pointer
				return id[:]
			},
		},
	}

	for _, tc := range testCases {
		tc := tc // Capture range variable.
		t.Run(tc.description, func(t *testing.T) {
			t.Parallel()

			// Work on a private copy of the input so it can be overwritten below.
			input := make([]byte, len(tc.data))
			copy(input, tc.data)

			// Unmarshal into a freshly allocated value of the case's type and check the result.
			result := reflect.New(tc.sType).Interface()
			noerr(t, Unmarshal(input, result))
			assert.Equal(t, tc.want, result, "unmarshaled value does not match the expected value")

			// Overwrite the input with random bytes; the unmarshaled value must be unaffected.
			_, err := rand.Read(input)
			noerr(t, err)
			assert.Equal(t, tc.want, result, "unmarshaled value does not match expected after modifying the input bytes")

			// Finally, check that the unmarshaled byte slice and the input slice occupy
			// disjoint memory ranges.
			assertDifferentArrays(t, input, tc.getByteSlice(result))
		})
	}
}

// assertDifferentArrays asserts that two byte slices reference distinct memory ranges, meaning
// they reference different underlying byte arrays. The comparison covers each slice's full
// capacity, not just its length. A slice with zero capacity (including a nil slice) references no
// memory and therefore never overlaps anything. On overlap, the failure is reported with
// t.Errorf, so the calling test continues running.
func assertDifferentArrays(t *testing.T, a, b []byte) {
	t.Helper()

	// A zero-capacity slice has no backing bytes. Computing "start + cap - 1" for it would wrap
	// around (for a nil slice the range would become the entire address space) and produce a
	// false overlap report, so handle it up front.
	if cap(a) == 0 || cap(b) == 0 {
		return
	}

	// sliceAddrRange returns the addresses of the first and last bytes of the array region
	// reachable through s. It must only be called with cap(s) > 0.
	sliceAddrRange := func(s []byte) (uintptr, uintptr) {
		full := s[:cap(s)]
		start := uintptr(unsafe.Pointer(&full[0]))
		return start, start + uintptr(len(full)-1)
	}
	aStart, aEnd := sliceAddrRange(a)
	bStart, bEnd := sliceAddrRange(b)

	// If "b" starts after "a" ends or "a" starts after "b" ends, there is no overlap.
	if bStart > aEnd || aStart > bEnd {
		return
	}

	// Otherwise, the overlap runs from the larger start address to the smaller end address.
	overlapLow, overlapHigh := aStart, aEnd
	if bStart > overlapLow {
		overlapLow = bStart
	}
	if bEnd < overlapHigh {
		overlapHigh = bEnd
	}

	t.Errorf("Byte slices point to the same underlying byte array:\n"+
		"\ta addresses:\t%d ... %d\n"+
		"\tb addresses:\t%d ... %d\n"+
		"\toverlap:\t%d ... %d",
		aStart, aEnd,
		bStart, bEnd,
		overlapLow, overlapHigh)
}