Skip to content

Commit 202644f

Browse files
author
iwysiu
authored
GODRIVER-1494 add option for encoding/json style map encoding (#345)
1 parent 6a555e4 commit 202644f

File tree

5 files changed

+203
-24
lines changed

5 files changed

+203
-24
lines changed

bson/bson_test.go

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,17 @@ package bson
88

99
import (
1010
"bytes"
11+
"fmt"
12+
"reflect"
13+
"strings"
1114
"testing"
1215
"time"
1316

1417
"github.com/google/go-cmp/cmp"
1518
"github.com/stretchr/testify/require"
19+
"go.mongodb.org/mongo-driver/bson/bsoncodec"
20+
"go.mongodb.org/mongo-driver/bson/bsonoptions"
21+
"go.mongodb.org/mongo-driver/internal/testutil/assert"
1622
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
1723
)
1824

@@ -112,6 +118,71 @@ func TestD(t *testing.T) {
112118
})
113119
}
114120

121+
type stringerString string
122+
123+
func (ss stringerString) String() string {
124+
return "bar"
125+
}
126+
127+
type keyBool bool
128+
129+
func (kb keyBool) MarshalKey() (string, error) {
130+
return fmt.Sprintf("%v", kb), nil
131+
}
132+
133+
func (kb keyBool) UnmarshalKey(key string) error {
134+
switch key {
135+
case "true":
136+
kb = true
137+
case "false":
138+
kb = false
139+
default:
140+
return fmt.Errorf("invalid bool value %v", key)
141+
}
142+
return nil
143+
}
144+
145+
func TestMapCodec(t *testing.T) {
146+
t.Run("EncodeKeysWithStringer", func(t *testing.T) {
147+
strstr := stringerString("foo")
148+
mapObj := map[stringerString]int{strstr: 1}
149+
testCases := []struct {
150+
name string
151+
opts *bsonoptions.MapCodecOptions
152+
key string
153+
}{
154+
{"default", bsonoptions.MapCodec(), "foo"},
155+
{"true", bsonoptions.MapCodec().SetEncodeKeysWithStringer(true), "bar"},
156+
{"false", bsonoptions.MapCodec().SetEncodeKeysWithStringer(false), "foo"},
157+
}
158+
for _, tc := range testCases {
159+
t.Run(tc.name, func(t *testing.T) {
160+
mapCodec := bsoncodec.NewMapCodec(tc.opts)
161+
mapRegistry := NewRegistryBuilder().RegisterDefaultEncoder(reflect.Map, mapCodec).Build()
162+
val, err := MarshalWithRegistry(mapRegistry, mapObj)
163+
assert.Nil(t, err, "Marshal error: %v", err)
164+
assert.True(t, strings.Contains(string(val), tc.key), "expected result to contain %v, got: %v", tc.key, string(val))
165+
})
166+
}
167+
})
168+
t.Run("keys implements keyMarshaler and keyUnmarshaler", func(t *testing.T) {
169+
mapObj := map[keyBool]int{keyBool(false): 1}
170+
171+
doc, err := Marshal(mapObj)
172+
assert.Nil(t, err, "Marshal error: %v", err)
173+
idx, want := bsoncore.AppendDocumentStart(nil)
174+
want = bsoncore.AppendInt32Element(want, "false", 1)
175+
want, _ = bsoncore.AppendDocumentEnd(want, idx)
176+
assert.Equal(t, want, doc, "expected result %v, got %v", string(want), string(doc))
177+
178+
var got map[keyBool]int
179+
err = Unmarshal(doc, &got)
180+
assert.Nil(t, err, "Unmarshal error: %v", err)
181+
assert.Equal(t, mapObj, got, "expected result %v, got %v", mapObj, got)
182+
183+
})
184+
}
185+
115186
func TestExtJSONEscapeKey(t *testing.T) {
116187
doc := D{{Key: "\\usb#", Value: int32(1)}}
117188
b, err := MarshalExtJSON(&doc, false, false)

bson/bsoncodec/default_value_decoders_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -808,7 +808,7 @@ func TestDefaultValueDecoders(t *testing.T) {
808808
&DecodeContext{Registry: buildDefaultRegistry()},
809809
&bsonrwtest.ValueReaderWriter{},
810810
bsonrwtest.ReadElement,
811-
fmt.Errorf("BSON map must have string or decimal keys. Got:%v", reflect.ValueOf(map[bool]interface{}{}).Type()),
811+
fmt.Errorf("unsupported key type: %T", false),
812812
},
813813
{
814814
"ReadDocument Error",

bson/bsoncodec/map_codec.go

Lines changed: 109 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,29 @@ var defaultMapCodec = NewMapCodec()
2020

2121
// MapCodec is the Codec used for map values.
2222
type MapCodec struct {
23-
DecodeZerosMap bool
24-
EncodeNilAsEmpty bool
23+
DecodeZerosMap bool
24+
EncodeNilAsEmpty bool
25+
EncodeKeysWithStringer bool
2526
}
2627

2728
var _ ValueCodec = &MapCodec{}
2829

30+
// KeyMarshaler is the interface implemented by an object that can marshal itself into a string key.
31+
// This applies to types used as map keys and is similar to encoding.TextMarshaler.
32+
type KeyMarshaler interface {
33+
MarshalKey() (key string, err error)
34+
}
35+
36+
// KeyUnmarshaler is the interface implemented by an object that can unmarshal a string representation
37+
// of itself. This applies to types used as map keys and is similar to encoding.TextUnmarshaler.
38+
//
39+
// UnmarshalKey must be able to decode the form generated by MarshalKey.
40+
// UnmarshalKey must copy the text if it wishes to retain the text
41+
// after returning.
42+
type KeyUnmarshaler interface {
43+
UnmarshalKey(key string) error
44+
}
45+
2946
// NewMapCodec returns a MapCodec with options opts.
3047
func NewMapCodec(opts ...*bsonoptions.MapCodecOptions) *MapCodec {
3148
mapOpt := bsonoptions.MergeMapCodecOptions(opts...)
@@ -37,6 +54,9 @@ func NewMapCodec(opts ...*bsonoptions.MapCodecOptions) *MapCodec {
3754
if mapOpt.EncodeNilAsEmpty != nil {
3855
codec.EncodeNilAsEmpty = *mapOpt.EncodeNilAsEmpty
3956
}
57+
if mapOpt.EncodeKeysWithStringer != nil {
58+
codec.EncodeKeysWithStringer = *mapOpt.EncodeKeysWithStringer
59+
}
4060
return &codec
4161
}
4262

@@ -79,7 +99,11 @@ func (mc *MapCodec) mapEncodeValue(ec EncodeContext, dw bsonrw.DocumentWriter, v
7999

80100
keys := val.MapKeys()
81101
for _, key := range keys {
82-
keyStr := fmt.Sprint(key)
102+
keyStr, err := mc.encodeKey(key)
103+
if err != nil {
104+
return err
105+
}
106+
83107
if collisionFn != nil && collisionFn(keyStr) {
84108
return fmt.Errorf("Key %s of inlined map conflicts with a struct field name", key)
85109
}
@@ -160,7 +184,6 @@ func (mc *MapCodec) DecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val ref
160184
}
161185

162186
keyType := val.Type().Key()
163-
keyKind := keyType.Kind()
164187

165188
for {
166189
key, vr, err := dr.ReadElement()
@@ -171,23 +194,9 @@ func (mc *MapCodec) DecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val ref
171194
return err
172195
}
173196

174-
k := reflect.ValueOf(key)
175-
if keyType != tString {
176-
switch keyKind {
177-
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
178-
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
179-
reflect.Float32, reflect.Float64:
180-
parsed, err := strconv.ParseFloat(k.String(), 64)
181-
if err != nil {
182-
return fmt.Errorf("Map key is defined to be a decimal type (%v) but got error %v", keyKind, err)
183-
}
184-
k = reflect.ValueOf(parsed)
185-
case reflect.String: // if keyType wraps string
186-
default:
187-
return fmt.Errorf("BSON map must have string or decimal keys. Got:%v", val.Type())
188-
}
189-
190-
k = k.Convert(keyType)
197+
k, err := mc.decodeKey(key, keyType)
198+
if err != nil {
199+
return err
191200
}
192201

193202
elem := reflect.New(eType).Elem()
@@ -207,3 +216,82 @@ func clearMap(m reflect.Value) {
207216
m.SetMapIndex(k, none)
208217
}
209218
}
219+
220+
func (mc *MapCodec) encodeKey(val reflect.Value) (string, error) {
221+
if mc.EncodeKeysWithStringer {
222+
return fmt.Sprint(val), nil
223+
}
224+
225+
// keys of any string type are used directly
226+
if val.Kind() == reflect.String {
227+
return val.String(), nil
228+
}
229+
// KeyMarshalers are marshaled
230+
if km, ok := val.Interface().(KeyMarshaler); ok {
231+
if val.Kind() == reflect.Ptr && val.IsNil() {
232+
return "", nil
233+
}
234+
buf, err := km.MarshalKey()
235+
if err == nil {
236+
return buf, nil
237+
}
238+
return "", err
239+
}
240+
241+
switch val.Kind() {
242+
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
243+
return strconv.FormatInt(val.Int(), 10), nil
244+
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
245+
return strconv.FormatUint(val.Uint(), 10), nil
246+
}
247+
return "", fmt.Errorf("unsupported key type: %v", val.Type())
248+
}
249+
250+
var keyUnmarshalerType = reflect.TypeOf((*KeyUnmarshaler)(nil)).Elem()
251+
252+
func (mc *MapCodec) decodeKey(key string, keyType reflect.Type) (reflect.Value, error) {
253+
keyVal := reflect.ValueOf(key)
254+
var err error
255+
switch {
256+
// First, if EncodeKeysWithStringer is not enabled, try to decode withKeyUnmarshaler
257+
case !mc.EncodeKeysWithStringer && reflect.PtrTo(keyType).Implements(keyUnmarshalerType):
258+
keyVal = reflect.New(keyType)
259+
v := keyVal.Interface().(KeyUnmarshaler)
260+
err = v.UnmarshalKey(key)
261+
keyVal = keyVal.Elem()
262+
// Otherwise, go to type specific behavior
263+
default:
264+
switch keyType.Kind() {
265+
case reflect.String:
266+
keyVal = reflect.ValueOf(key).Convert(keyType)
267+
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
268+
s := string(key)
269+
n, parseErr := strconv.ParseInt(s, 10, 64)
270+
if parseErr != nil || reflect.Zero(keyType).OverflowInt(n) {
271+
err = fmt.Errorf("failed to unmarshal number key %v", s)
272+
}
273+
keyVal = reflect.ValueOf(n).Convert(keyType)
274+
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
275+
s := string(key)
276+
n, parseErr := strconv.ParseUint(s, 10, 64)
277+
if parseErr != nil || reflect.Zero(keyType).OverflowUint(n) {
278+
err = fmt.Errorf("failed to unmarshal number key %v", s)
279+
break
280+
}
281+
keyVal = reflect.ValueOf(n).Convert(keyType)
282+
case reflect.Float32, reflect.Float64:
283+
if mc.EncodeKeysWithStringer {
284+
parsed, err := strconv.ParseFloat(key, 64)
285+
if err != nil {
286+
return keyVal, fmt.Errorf("Map key is defined to be a decimal type (%v) but got error %v", keyType.Kind(), err)
287+
}
288+
keyVal = reflect.ValueOf(parsed)
289+
break
290+
}
291+
fallthrough
292+
default:
293+
return keyVal, fmt.Errorf("unsupported key type: %v", keyType)
294+
}
295+
}
296+
return keyVal, err
297+
}

bson/bsonoptions/map_codec_options.go

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,12 @@ package bsonoptions
1010
type MapCodecOptions struct {
1111
DecodeZerosMap *bool // Specifies if the map should be zeroed before decoding into it. Defaults to false.
1212
EncodeNilAsEmpty *bool // Specifies if a nil map should encode as an empty document instead of null. Defaults to false.
13+
// Specifies how keys should be handled. If false, the behavior matches encoding/json, where the encoding key type must
14+
// either be a string, an integer type, or implement bsoncodec.KeyMarshaler and the decoding key type must either be a
15+
// string, an integer type, or implement bsoncodec.KeyUnmarshaler. If true, keys are encoded with fmt.Sprint() and the
16+
// encoding key type must be a string, an integer type, or a float. If true, the use of Stringer will override
17+
// TextMarshaler/TextUnmarshaler. Defaults to false.
18+
EncodeKeysWithStringer *bool
1319
}
1420

1521
// MapCodec creates a new *MapCodecOptions
@@ -23,12 +29,22 @@ func (t *MapCodecOptions) SetDecodeZerosMap(b bool) *MapCodecOptions {
2329
return t
2430
}
2531

26-
// SetEncodeNilAsEmpty specifies if a nil map should encode as an empty document instead of null. Defaults to false.
32+
// SetEncodeNilAsEmpty specifies if a nil map should encode as an empty document instead of null. Defaults to false.
2733
func (t *MapCodecOptions) SetEncodeNilAsEmpty(b bool) *MapCodecOptions {
2834
t.EncodeNilAsEmpty = &b
2935
return t
3036
}
3137

38+
// SetEncodeKeysWithStringer specifies how keys should be handled. If false, the behavior matches encoding/json, where the
39+
// encoding key type must either be a string, an integer type, or implement bsoncodec.KeyMarshaler and the decoding key
40+
// type must either be a string, an integer type, or implement bsoncodec.KeyUnmarshaler. If true, keys are encoded with
41+
// fmt.Sprint() and the encoding key type must be a string, an integer type, or a float. If true, the use of Stringer
42+
// will override TextMarshaler/TextUnmarshaler. Defaults to false.
43+
func (t *MapCodecOptions) SetEncodeKeysWithStringer(b bool) *MapCodecOptions {
44+
t.EncodeKeysWithStringer = &b
45+
return t
46+
}
47+
3248
// MergeMapCodecOptions combines the given *MapCodecOptions into a single *MapCodecOptions in a last one wins fashion.
3349
func MergeMapCodecOptions(opts ...*MapCodecOptions) *MapCodecOptions {
3450
s := MapCodec()
@@ -42,6 +58,9 @@ func MergeMapCodecOptions(opts ...*MapCodecOptions) *MapCodecOptions {
4258
if opt.EncodeNilAsEmpty != nil {
4359
s.EncodeNilAsEmpty = opt.EncodeNilAsEmpty
4460
}
61+
if opt.EncodeKeysWithStringer != nil {
62+
s.EncodeKeysWithStringer = opt.EncodeKeysWithStringer
63+
}
4564
}
4665

4766
return s

bson/mgocompat/registry.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,8 @@ func NewRegistryBuilder() *bsoncodec.RegistryBuilder {
5858
mapCodec := bsoncodec.NewMapCodec(
5959
bsonoptions.MapCodec().
6060
SetDecodeZerosMap(true).
61-
SetEncodeNilAsEmpty(true))
61+
SetEncodeNilAsEmpty(true).
62+
SetEncodeKeysWithStringer(true))
6263
uintcodec := bsoncodec.NewUIntCodec(bsonoptions.UIntCodec().SetEncodeToMinSize(true))
6364

6465
rb.RegisterTypeDecoder(tEmpty, emptyInterCodec).

0 commit comments

Comments
 (0)