Skip to content

Commit 5aec41b

Browse files
dsnetneild
authored andcommitted
testing/protocmp: add Message.Unwrap
The Unwrap method returns the original concrete message value. In theory this allows users to mutate the original message when the cmp documentation says that all options should be mutation free. If users want to disregard this documented restriction, they can already do so in a number of different ways. Updates #1347 Change-Id: I65225681ab5dbce0763a140fd02666a4ab542a04 Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/340489 Trust: Joe Tsai <[email protected]> Reviewed-by: Damien Neil <[email protected]>
1 parent 05be61f commit 5aec41b

File tree

4 files changed

+49
-31
lines changed

4 files changed

+49
-31
lines changed

testing/protocmp/reflect.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ func (m reflectMessage) Range(f func(fd protoreflect.FieldDescriptor, v protoref
6868
}
6969

7070
// Range over populated extension fields.
71-
for _, xd := range m[messageTypeKey].(messageType).xds {
71+
for _, xd := range m[messageTypeKey].(messageMeta).xds {
7272
if m.Has(xd) && !f(xd, m.Get(xd)) {
7373
return
7474
}
@@ -91,7 +91,7 @@ func (m reflectMessage) Get(fd protoreflect.FieldDescriptor) protoreflect.Value
9191
return protoreflect.ValueOfMap(reflectMap{})
9292
case fd.Message() != nil:
9393
return protoreflect.ValueOfMessage(reflectMessage{
94-
messageTypeKey: messageType{md: m.Descriptor()},
94+
messageTypeKey: messageMeta{md: fd.Message()},
9595
})
9696
default:
9797
return fd.Default()

testing/protocmp/util.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -297,11 +297,11 @@ func (f *nameFilters) filterFieldName(m Message, k string) bool {
297297
return true // treat missing fields as already filtered
298298
}
299299
var fd protoreflect.FieldDescriptor
300-
switch mt := m[messageTypeKey].(messageType); {
300+
switch mm := m[messageTypeKey].(messageMeta); {
301301
case protoreflect.Name(k).IsValid():
302-
fd = mt.md.Fields().ByTextName(k)
302+
fd = mm.md.Fields().ByTextName(k)
303303
default:
304-
fd = mt.xds[k]
304+
fd = mm.xds[k]
305305
}
306306
if fd != nil {
307307
return f.names[fd.FullName()]
@@ -376,11 +376,11 @@ func isDefaultScalar(m Message, k string) bool {
376376
}
377377

378378
var fd protoreflect.FieldDescriptor
379-
switch mt := m[messageTypeKey].(messageType); {
379+
switch mm := m[messageTypeKey].(messageMeta); {
380380
case protoreflect.Name(k).IsValid():
381-
fd = mt.md.Fields().ByTextName(k)
381+
fd = mm.md.Fields().ByTextName(k)
382382
default:
383-
fd = mt.xds[k]
383+
fd = mm.xds[k]
384384
}
385385
if fd == nil || !fd.Default().IsValid() {
386386
return false

testing/protocmp/xform.go

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -68,20 +68,28 @@ func (e Enum) String() string {
6868
}
6969

7070
const (
71-
messageTypeKey = "@type"
71+
// messageTypeKey indicates the protobuf message type.
72+
// The value type is always messageMeta.
73+
// From the public API, it presents itself as only the type, but the
74+
// underlying data structure holds arbitrary metadata about the message.
75+
messageTypeKey = "@type"
76+
77+
// messageInvalidKey indicates that the message is invalid.
78+
// The value is always the boolean "true".
7279
messageInvalidKey = "@invalid"
7380
)
7481

75-
type messageType struct {
82+
type messageMeta struct {
83+
m proto.Message
7684
md protoreflect.MessageDescriptor
7785
xds map[string]protoreflect.ExtensionDescriptor
7886
}
7987

80-
func (t messageType) String() string {
88+
func (t messageMeta) String() string {
8189
return string(t.md.FullName())
8290
}
8391

84-
func (t1 messageType) Equal(t2 messageType) bool {
92+
func (t1 messageMeta) Equal(t2 messageMeta) bool {
8593
return t1.md.FullName() == t2.md.FullName()
8694
}
8795

@@ -109,11 +117,18 @@ func (t1 messageType) Equal(t2 messageType) bool {
109117
// Message values must not be created by or mutated by users.
110118
type Message map[string]interface{}
111119

120+
// Unwrap returns the original message value.
121+
// It returns nil if this Message was not constructed from another message.
122+
func (m Message) Unwrap() proto.Message {
123+
mm, _ := m[messageTypeKey].(messageMeta)
124+
return mm.m
125+
}
126+
112127
// Descriptor return the message descriptor.
113128
// It returns nil for a zero Message value.
114129
func (m Message) Descriptor() protoreflect.MessageDescriptor {
115-
mt, _ := m[messageTypeKey].(messageType)
116-
return mt.md
130+
mm, _ := m[messageTypeKey].(messageMeta)
131+
return mm.md
117132
}
118133

119134
// ProtoReflect returns a reflective view of m.
@@ -201,7 +216,7 @@ func Transform(...option) cmp.Option {
201216
case m == nil:
202217
return nil
203218
case !m.IsValid():
204-
return Message{messageTypeKey: messageType{md: m.Descriptor()}, messageInvalidKey: true}
219+
return Message{messageTypeKey: messageMeta{m: m.Interface(), md: m.Descriptor()}, messageInvalidKey: true}
205220
default:
206221
return transformMessage(m)
207222
}
@@ -218,7 +233,7 @@ func isMessageType(t reflect.Type) bool {
218233

219234
func transformMessage(m protoreflect.Message) Message {
220235
mx := Message{}
221-
mt := messageType{md: m.Descriptor(), xds: make(map[string]protoreflect.FieldDescriptor)}
236+
mt := messageMeta{m: m.Interface(), md: m.Descriptor(), xds: make(map[string]protoreflect.FieldDescriptor)}
222237

223238
// Handle known and extension fields.
224239
m.Range(func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {

testing/protocmp/xform_test.go

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ func TestTransform(t *testing.T) {
4040
OptionalNestedMessage: &testpb.TestAllTypes_NestedMessage{A: proto.Int32(5)},
4141
},
4242
want: Message{
43-
messageTypeKey: messageTypeOf(&testpb.TestAllTypes{}),
43+
messageTypeKey: messageMetaOf(&testpb.TestAllTypes{}),
4444
"optional_bool": bool(false),
4545
"optional_int32": int32(-32),
4646
"optional_int64": int64(-64),
@@ -51,7 +51,7 @@ func TestTransform(t *testing.T) {
5151
"optional_string": string("string"),
5252
"optional_bytes": []byte("bytes"),
5353
"optional_nested_enum": enumOf(testpb.TestAllTypes_NEG),
54-
"optional_nested_message": Message{messageTypeKey: messageTypeOf(&testpb.TestAllTypes_NestedMessage{}), "a": int32(5)},
54+
"optional_nested_message": Message{messageTypeKey: messageMetaOf(&testpb.TestAllTypes_NestedMessage{}), "a": int32(5)},
5555
},
5656
}, {
5757
in: &testpb.TestAllTypes{
@@ -74,7 +74,7 @@ func TestTransform(t *testing.T) {
7474
},
7575
},
7676
want: Message{
77-
messageTypeKey: messageTypeOf(&testpb.TestAllTypes{}),
77+
messageTypeKey: messageMetaOf(&testpb.TestAllTypes{}),
7878
"repeated_bool": []bool{false, true},
7979
"repeated_int32": []int32{32, -32},
8080
"repeated_int64": []int64{64, -64},
@@ -89,8 +89,8 @@ func TestTransform(t *testing.T) {
8989
enumOf(testpb.TestAllTypes_BAR),
9090
},
9191
"repeated_nested_message": []Message{
92-
{messageTypeKey: messageTypeOf(&testpb.TestAllTypes_NestedMessage{}), "a": int32(5)},
93-
{messageTypeKey: messageTypeOf(&testpb.TestAllTypes_NestedMessage{}), "a": int32(-5)},
92+
{messageTypeKey: messageMetaOf(&testpb.TestAllTypes_NestedMessage{}), "a": int32(5)},
93+
{messageTypeKey: messageMetaOf(&testpb.TestAllTypes_NestedMessage{}), "a": int32(-5)},
9494
},
9595
},
9696
}, {
@@ -112,7 +112,7 @@ func TestTransform(t *testing.T) {
112112
},
113113
},
114114
want: Message{
115-
messageTypeKey: messageTypeOf(&testpb.TestAllTypes{}),
115+
messageTypeKey: messageMetaOf(&testpb.TestAllTypes{}),
116116
"map_bool_bool": map[bool]bool{true: false},
117117
"map_int32_int32": map[int32]int32{-32: 32},
118118
"map_int64_int64": map[int64]int64{-64: 64},
@@ -126,7 +126,7 @@ func TestTransform(t *testing.T) {
126126
"k": enumOf(testpb.TestAllTypes_FOO),
127127
},
128128
"map_string_nested_message": map[string]Message{
129-
"k": {messageTypeKey: messageTypeOf(&testpb.TestAllTypes_NestedMessage{}), "a": int32(5)},
129+
"k": {messageTypeKey: messageMetaOf(&testpb.TestAllTypes_NestedMessage{}), "a": int32(5)},
130130
},
131131
},
132132
}, {
@@ -146,7 +146,7 @@ func TestTransform(t *testing.T) {
146146
return m
147147
}(),
148148
want: Message{
149-
messageTypeKey: messageTypeOf(&testpb.TestAllExtensions{}),
149+
messageTypeKey: messageMetaOf(&testpb.TestAllExtensions{}),
150150
"[goproto.proto.test.optional_bool]": bool(false),
151151
"[goproto.proto.test.optional_int32]": int32(-32),
152152
"[goproto.proto.test.optional_int64]": int64(-64),
@@ -157,7 +157,7 @@ func TestTransform(t *testing.T) {
157157
"[goproto.proto.test.optional_string]": string("string"),
158158
"[goproto.proto.test.optional_bytes]": []byte("bytes"),
159159
"[goproto.proto.test.optional_nested_enum]": enumOf(testpb.TestAllTypes_NEG),
160-
"[goproto.proto.test.optional_nested_message]": Message{messageTypeKey: messageTypeOf(&testpb.TestAllExtensions_NestedMessage{}), "a": int32(5)},
160+
"[goproto.proto.test.optional_nested_message]": Message{messageTypeKey: messageMetaOf(&testpb.TestAllExtensions_NestedMessage{}), "a": int32(5)},
161161
},
162162
}, {
163163
in: func() proto.Message {
@@ -182,7 +182,7 @@ func TestTransform(t *testing.T) {
182182
return m
183183
}(),
184184
want: Message{
185-
messageTypeKey: messageTypeOf(&testpb.TestAllExtensions{}),
185+
messageTypeKey: messageMetaOf(&testpb.TestAllExtensions{}),
186186
"[goproto.proto.test.repeated_bool]": []bool{false, true},
187187
"[goproto.proto.test.repeated_int32]": []int32{32, -32},
188188
"[goproto.proto.test.repeated_int64]": []int64{64, -64},
@@ -197,8 +197,8 @@ func TestTransform(t *testing.T) {
197197
enumOf(testpb.TestAllTypes_BAR),
198198
},
199199
"[goproto.proto.test.repeated_nested_message]": []Message{
200-
{messageTypeKey: messageTypeOf(&testpb.TestAllExtensions_NestedMessage{}), "a": int32(5)},
201-
{messageTypeKey: messageTypeOf(&testpb.TestAllExtensions_NestedMessage{}), "a": int32(-5)},
200+
{messageTypeKey: messageMetaOf(&testpb.TestAllExtensions_NestedMessage{}), "a": int32(5)},
201+
{messageTypeKey: messageMetaOf(&testpb.TestAllExtensions_NestedMessage{}), "a": int32(-5)},
202202
},
203203
},
204204
}, {
@@ -229,7 +229,7 @@ func TestTransform(t *testing.T) {
229229
return m
230230
}(),
231231
want: Message{
232-
messageTypeKey: messageTypeOf(&testpb.TestAllTypes{}),
232+
messageTypeKey: messageMetaOf(&testpb.TestAllTypes{}),
233233
"50000": protoreflect.RawFields(protopack.Message{protopack.Tag{Number: 50000, Type: protopack.VarintType}, protopack.Uvarint(100)}.Marshal()),
234234
"50001": protoreflect.RawFields(protopack.Message{protopack.Tag{Number: 50001, Type: protopack.Fixed32Type}, protopack.Uint32(200)}.Marshal()),
235235
"50002": protoreflect.RawFields(protopack.Message{protopack.Tag{Number: 50002, Type: protopack.Fixed64Type}, protopack.Uint64(300)}.Marshal()),
@@ -258,6 +258,9 @@ func TestTransform(t *testing.T) {
258258
if diff := cmp.Diff(tt.want, got); diff != "" {
259259
t.Errorf("Transform() mismatch (-want +got):\n%v", diff)
260260
}
261+
if got.Unwrap() != tt.in {
262+
t.Errorf("got.Unwrap() = %p, want %p", got.Unwrap(), tt.in)
263+
}
261264
})
262265
}
263266
}
@@ -266,6 +269,6 @@ func enumOf(e protoreflect.Enum) Enum {
266269
return Enum{e.Number(), e.Descriptor()}
267270
}
268271

269-
func messageTypeOf(m protoreflect.ProtoMessage) messageType {
270-
return messageType{md: m.ProtoReflect().Descriptor()}
272+
func messageMetaOf(m protoreflect.ProtoMessage) messageMeta {
273+
return messageMeta{m: m, md: m.ProtoReflect().Descriptor()}
271274
}

0 commit comments

Comments
 (0)