|
7 | 7 | package bson
|
8 | 8 |
|
9 | 9 | import (
|
| 10 | + "crypto/rand" |
10 | 11 | "reflect"
|
11 | 12 | "testing"
|
| 13 | + "unsafe" |
12 | 14 |
|
13 | 15 | "github.com/google/go-cmp/cmp"
|
14 | 16 | "go.mongodb.org/mongo-driver/bson/bsoncodec"
|
15 | 17 | "go.mongodb.org/mongo-driver/bson/bsonrw"
|
| 18 | + "go.mongodb.org/mongo-driver/bson/primitive" |
16 | 19 | "go.mongodb.org/mongo-driver/internal/testutil/assert"
|
17 | 20 | "go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
|
18 | 21 | )
|
@@ -522,3 +525,254 @@ func TestUnmarshalBSONWithUndefinedField(t *testing.T) {
|
522 | 525 | })
|
523 | 526 | }
|
524 | 527 | }
|
| 528 | + |
| 529 | +// GODRIVER-2311 |
| 530 | +// Assert that unmarshaled values containing byte slices do not reference the same underlying byte |
| 531 | +// array as the BSON input data byte slice. |
| 532 | +func TestUnmarshalByteSlicesUseDistinctArrays(t *testing.T) { |
| 533 | + type fooBytes struct { |
| 534 | + Foo []byte |
| 535 | + } |
| 536 | + |
| 537 | + type myBytes []byte |
| 538 | + type fooMyBytes struct { |
| 539 | + Foo myBytes |
| 540 | + } |
| 541 | + |
| 542 | + type fooBinary struct { |
| 543 | + Foo primitive.Binary |
| 544 | + } |
| 545 | + |
| 546 | + type fooObjectID struct { |
| 547 | + Foo primitive.ObjectID |
| 548 | + } |
| 549 | + |
| 550 | + type fooDBPointer struct { |
| 551 | + Foo primitive.DBPointer |
| 552 | + } |
| 553 | + |
| 554 | + testCases := []struct { |
| 555 | + description string |
| 556 | + data []byte |
| 557 | + sType reflect.Type |
| 558 | + want interface{} |
| 559 | + |
| 560 | + // getByteSlice returns the byte slice from the unmarshaled value, allowing the test to |
| 561 | + // inspect the addresses of the underlying byte array. |
| 562 | + getByteSlice func(interface{}) []byte |
| 563 | + }{ |
| 564 | + { |
| 565 | + description: "struct with byte slice", |
| 566 | + data: docToBytes(fooBytes{ |
| 567 | + Foo: []byte{0, 1, 2, 3, 4, 5}, |
| 568 | + }), |
| 569 | + sType: reflect.TypeOf(fooBytes{}), |
| 570 | + want: &fooBytes{ |
| 571 | + Foo: []byte{0, 1, 2, 3, 4, 5}, |
| 572 | + }, |
| 573 | + getByteSlice: func(val interface{}) []byte { |
| 574 | + return (*(val.(*fooBytes))).Foo |
| 575 | + }, |
| 576 | + }, |
| 577 | + { |
| 578 | + description: "bson.D with byte slice", |
| 579 | + data: docToBytes(D{ |
| 580 | + {"foo", []byte{0, 1, 2, 3, 4, 5}}, |
| 581 | + }), |
| 582 | + sType: reflect.TypeOf(D{}), |
| 583 | + want: &D{ |
| 584 | + {"foo", primitive.Binary{Subtype: 0, Data: []byte{0, 1, 2, 3, 4, 5}}}, |
| 585 | + }, |
| 586 | + getByteSlice: func(val interface{}) []byte { |
| 587 | + return (*(val.(*D)))[0].Value.(primitive.Binary).Data |
| 588 | + }, |
| 589 | + }, |
| 590 | + { |
| 591 | + description: "struct with custom byte slice type", |
| 592 | + data: docToBytes(fooMyBytes{ |
| 593 | + Foo: myBytes{0, 1, 2, 3, 4, 5}, |
| 594 | + }), |
| 595 | + sType: reflect.TypeOf(fooMyBytes{}), |
| 596 | + want: &fooMyBytes{ |
| 597 | + Foo: myBytes{0, 1, 2, 3, 4, 5}, |
| 598 | + }, |
| 599 | + getByteSlice: func(val interface{}) []byte { |
| 600 | + return (*(val.(*fooMyBytes))).Foo |
| 601 | + }, |
| 602 | + }, |
| 603 | + { |
| 604 | + description: "bson.D with custom byte slice type", |
| 605 | + data: docToBytes(D{ |
| 606 | + {"foo", myBytes{0, 1, 2, 3, 4, 5}}, |
| 607 | + }), |
| 608 | + sType: reflect.TypeOf(D{}), |
| 609 | + want: &D{ |
| 610 | + {"foo", primitive.Binary{Subtype: 0, Data: myBytes{0, 1, 2, 3, 4, 5}}}, |
| 611 | + }, |
| 612 | + getByteSlice: func(val interface{}) []byte { |
| 613 | + return (*(val.(*D)))[0].Value.(primitive.Binary).Data |
| 614 | + }, |
| 615 | + }, |
| 616 | + { |
| 617 | + description: "struct with primitive.Binary", |
| 618 | + data: docToBytes(fooBinary{ |
| 619 | + Foo: primitive.Binary{Subtype: 0, Data: []byte{0, 1, 2, 3, 4, 5}}, |
| 620 | + }), |
| 621 | + sType: reflect.TypeOf(fooBinary{}), |
| 622 | + want: &fooBinary{ |
| 623 | + Foo: primitive.Binary{Subtype: 0, Data: []byte{0, 1, 2, 3, 4, 5}}, |
| 624 | + }, |
| 625 | + getByteSlice: func(val interface{}) []byte { |
| 626 | + return (*(val.(*fooBinary))).Foo.Data |
| 627 | + }, |
| 628 | + }, |
| 629 | + { |
| 630 | + description: "bson.D with primitive.Binary", |
| 631 | + data: docToBytes(D{ |
| 632 | + {"foo", primitive.Binary{Subtype: 0, Data: []byte{0, 1, 2, 3, 4, 5}}}, |
| 633 | + }), |
| 634 | + sType: reflect.TypeOf(D{}), |
| 635 | + want: &D{ |
| 636 | + {"foo", primitive.Binary{Subtype: 0, Data: []byte{0, 1, 2, 3, 4, 5}}}, |
| 637 | + }, |
| 638 | + getByteSlice: func(val interface{}) []byte { |
| 639 | + return (*(val.(*D)))[0].Value.(primitive.Binary).Data |
| 640 | + }, |
| 641 | + }, |
| 642 | + { |
| 643 | + description: "struct with primitive.ObjectID", |
| 644 | + data: docToBytes(fooObjectID{ |
| 645 | + Foo: primitive.ObjectID{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, |
| 646 | + }), |
| 647 | + sType: reflect.TypeOf(fooObjectID{}), |
| 648 | + want: &fooObjectID{ |
| 649 | + Foo: primitive.ObjectID{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, |
| 650 | + }, |
| 651 | + getByteSlice: func(val interface{}) []byte { |
| 652 | + return (*(val.(*fooObjectID))).Foo[:] |
| 653 | + }, |
| 654 | + }, |
| 655 | + { |
| 656 | + description: "bson.D with primitive.ObjectID", |
| 657 | + data: docToBytes(D{ |
| 658 | + {"foo", primitive.ObjectID{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}}, |
| 659 | + }), |
| 660 | + sType: reflect.TypeOf(D{}), |
| 661 | + want: &D{ |
| 662 | + {"foo", primitive.ObjectID{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}}, |
| 663 | + }, |
| 664 | + getByteSlice: func(val interface{}) []byte { |
| 665 | + oid := (*(val.(*D)))[0].Value.(primitive.ObjectID) |
| 666 | + return oid[:] |
| 667 | + }, |
| 668 | + }, |
| 669 | + { |
| 670 | + description: "struct with primitive.DBPointer", |
| 671 | + data: docToBytes(fooDBPointer{ |
| 672 | + Foo: primitive.DBPointer{ |
| 673 | + DB: "test", |
| 674 | + Pointer: primitive.ObjectID{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, |
| 675 | + }, |
| 676 | + }), |
| 677 | + sType: reflect.TypeOf(fooDBPointer{}), |
| 678 | + want: &fooDBPointer{ |
| 679 | + Foo: primitive.DBPointer{ |
| 680 | + DB: "test", |
| 681 | + Pointer: primitive.ObjectID{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, |
| 682 | + }, |
| 683 | + }, |
| 684 | + getByteSlice: func(val interface{}) []byte { |
| 685 | + return (*(val.(*fooDBPointer))).Foo.Pointer[:] |
| 686 | + }, |
| 687 | + }, |
| 688 | + { |
| 689 | + description: "bson.D with primitive.DBPointer", |
| 690 | + data: docToBytes(D{ |
| 691 | + {"foo", primitive.DBPointer{ |
| 692 | + DB: "test", |
| 693 | + Pointer: primitive.ObjectID{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, |
| 694 | + }}, |
| 695 | + }), |
| 696 | + sType: reflect.TypeOf(D{}), |
| 697 | + want: &D{ |
| 698 | + {"foo", primitive.DBPointer{ |
| 699 | + DB: "test", |
| 700 | + Pointer: primitive.ObjectID{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, |
| 701 | + }}, |
| 702 | + }, |
| 703 | + getByteSlice: func(val interface{}) []byte { |
| 704 | + oid := (*(val.(*D)))[0].Value.(primitive.DBPointer).Pointer |
| 705 | + return oid[:] |
| 706 | + }, |
| 707 | + }, |
| 708 | + } |
| 709 | + |
| 710 | + for _, tc := range testCases { |
| 711 | + tc := tc // Capture range variable. |
| 712 | + t.Run(tc.description, func(t *testing.T) { |
| 713 | + t.Parallel() |
| 714 | + |
| 715 | + // Make a copy of the test data so we can modify it later. |
| 716 | + data := make([]byte, len(tc.data)) |
| 717 | + copy(data, tc.data) |
| 718 | + |
| 719 | + // Assert that unmarshaling the input data results in the expected value. |
| 720 | + got := reflect.New(tc.sType).Interface() |
| 721 | + err := Unmarshal(data, got) |
| 722 | + noerr(t, err) |
| 723 | + assert.Equal(t, tc.want, got, "unmarshaled value does not match the expected value") |
| 724 | + |
| 725 | + // Fill the input data slice with random bytes and then assert that the result still |
| 726 | + // matches the expected value. |
| 727 | + _, err = rand.Read(data) |
| 728 | + noerr(t, err) |
| 729 | + assert.Equal(t, tc.want, got, "unmarshaled value does not match expected after modifying the input bytes") |
| 730 | + |
| 731 | + // Assert that the byte slice in the unmarshaled value does not share any memory |
| 732 | + // addresses with the input byte slice. |
| 733 | + assertDifferentArrays(t, data, tc.getByteSlice(got)) |
| 734 | + }) |
| 735 | + } |
| 736 | +} |
| 737 | + |
| 738 | +// assertDifferentArrays asserts that two byte slices reference distinct memory ranges, meaning |
| 739 | +// they reference different underlying byte arrays. |
| 740 | +func assertDifferentArrays(t *testing.T, a, b []byte) { |
| 741 | + // Find the start and end memory addresses for the underlying byte array for each input byte |
| 742 | + // slice. |
| 743 | + sliceAddrRange := func(b []byte) (uintptr, uintptr) { |
| 744 | + sh := (*reflect.SliceHeader)(unsafe.Pointer(&b)) |
| 745 | + return sh.Data, sh.Data + uintptr(sh.Cap-1) |
| 746 | + } |
| 747 | + aStart, aEnd := sliceAddrRange(a) |
| 748 | + bStart, bEnd := sliceAddrRange(b) |
| 749 | + |
| 750 | + // If "b" starts after "a" ends or "a" starts after "b" ends, there is no overlap. |
| 751 | + if bStart > aEnd || aStart > bEnd { |
| 752 | + return |
| 753 | + } |
| 754 | + |
| 755 | + // Otherwise, calculate the overlap start and end and print the memory overlap error message. |
| 756 | + min := func(a, b uintptr) uintptr { |
| 757 | + if a < b { |
| 758 | + return a |
| 759 | + } |
| 760 | + return b |
| 761 | + } |
| 762 | + max := func(a, b uintptr) uintptr { |
| 763 | + if a > b { |
| 764 | + return a |
| 765 | + } |
| 766 | + return b |
| 767 | + } |
| 768 | + overlapLow := max(aStart, bStart) |
| 769 | + overlapHigh := min(aEnd, bEnd) |
| 770 | + |
| 771 | + t.Errorf("Byte slices point to the same the same underlying byte array:\n"+ |
| 772 | + "\ta addresses:\t%d ... %d\n"+ |
| 773 | + "\tb addresses:\t%d ... %d\n"+ |
| 774 | + "\toverlap:\t%d ... %d", |
| 775 | + aStart, aEnd, |
| 776 | + bStart, bEnd, |
| 777 | + overlapLow, overlapHigh) |
| 778 | +} |
0 commit comments