@@ -20,12 +20,29 @@ var defaultMapCodec = NewMapCodec()
20
20
21
21
// MapCodec is the Codec used for map values.
22
22
type MapCodec struct {
23
- DecodeZerosMap bool
24
- EncodeNilAsEmpty bool
23
+ DecodeZerosMap bool
24
+ EncodeNilAsEmpty bool
25
+ EncodeKeysWithStringer bool
25
26
}
26
27
27
28
var _ ValueCodec = & MapCodec {}
28
29
30
+ // KeyMarshaler is the interface implemented by an object that can marshal itself into a string key.
31
+ // This applies to types used as map keys and is similar to encoding.TextMarshaler.
32
+ type KeyMarshaler interface {
33
+ MarshalKey () (key string , err error )
34
+ }
35
+
36
+ // KeyUnmarshaler is the interface implemented by an object that can unmarshal a string representation
37
+ // of itself. This applies to types used as map keys and is similar to encoding.TextUnmarshaler.
38
+ //
39
+ // UnmarshalKey must be able to decode the form generated by MarshalKey.
40
+ // UnmarshalKey must copy the text if it wishes to retain the text
41
+ // after returning.
42
+ type KeyUnmarshaler interface {
43
+ UnmarshalKey (key string ) error
44
+ }
45
+
29
46
// NewMapCodec returns a MapCodec with options opts.
30
47
func NewMapCodec (opts ... * bsonoptions.MapCodecOptions ) * MapCodec {
31
48
mapOpt := bsonoptions .MergeMapCodecOptions (opts ... )
@@ -37,6 +54,9 @@ func NewMapCodec(opts ...*bsonoptions.MapCodecOptions) *MapCodec {
37
54
if mapOpt .EncodeNilAsEmpty != nil {
38
55
codec .EncodeNilAsEmpty = * mapOpt .EncodeNilAsEmpty
39
56
}
57
+ if mapOpt .EncodeKeysWithStringer != nil {
58
+ codec .EncodeKeysWithStringer = * mapOpt .EncodeKeysWithStringer
59
+ }
40
60
return & codec
41
61
}
42
62
@@ -79,7 +99,11 @@ func (mc *MapCodec) mapEncodeValue(ec EncodeContext, dw bsonrw.DocumentWriter, v
79
99
80
100
keys := val .MapKeys ()
81
101
for _ , key := range keys {
82
- keyStr := fmt .Sprint (key )
102
+ keyStr , err := mc .encodeKey (key )
103
+ if err != nil {
104
+ return err
105
+ }
106
+
83
107
if collisionFn != nil && collisionFn (keyStr ) {
84
108
return fmt .Errorf ("Key %s of inlined map conflicts with a struct field name" , key )
85
109
}
@@ -160,7 +184,6 @@ func (mc *MapCodec) DecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val ref
160
184
}
161
185
162
186
keyType := val .Type ().Key ()
163
- keyKind := keyType .Kind ()
164
187
165
188
for {
166
189
key , vr , err := dr .ReadElement ()
@@ -171,23 +194,9 @@ func (mc *MapCodec) DecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val ref
171
194
return err
172
195
}
173
196
174
- k := reflect .ValueOf (key )
175
- if keyType != tString {
176
- switch keyKind {
177
- case reflect .Int , reflect .Int8 , reflect .Int16 , reflect .Int32 , reflect .Int64 ,
178
- reflect .Uint , reflect .Uint8 , reflect .Uint16 , reflect .Uint32 , reflect .Uint64 ,
179
- reflect .Float32 , reflect .Float64 :
180
- parsed , err := strconv .ParseFloat (k .String (), 64 )
181
- if err != nil {
182
- return fmt .Errorf ("Map key is defined to be a decimal type (%v) but got error %v" , keyKind , err )
183
- }
184
- k = reflect .ValueOf (parsed )
185
- case reflect .String : // if keyType wraps string
186
- default :
187
- return fmt .Errorf ("BSON map must have string or decimal keys. Got:%v" , val .Type ())
188
- }
189
-
190
- k = k .Convert (keyType )
197
+ k , err := mc .decodeKey (key , keyType )
198
+ if err != nil {
199
+ return err
191
200
}
192
201
193
202
elem := reflect .New (eType ).Elem ()
@@ -207,3 +216,82 @@ func clearMap(m reflect.Value) {
207
216
m .SetMapIndex (k , none )
208
217
}
209
218
}
219
+
220
+ func (mc * MapCodec ) encodeKey (val reflect.Value ) (string , error ) {
221
+ if mc .EncodeKeysWithStringer {
222
+ return fmt .Sprint (val ), nil
223
+ }
224
+
225
+ // keys of any string type are used directly
226
+ if val .Kind () == reflect .String {
227
+ return val .String (), nil
228
+ }
229
+ // KeyMarshalers are marshaled
230
+ if km , ok := val .Interface ().(KeyMarshaler ); ok {
231
+ if val .Kind () == reflect .Ptr && val .IsNil () {
232
+ return "" , nil
233
+ }
234
+ buf , err := km .MarshalKey ()
235
+ if err == nil {
236
+ return buf , nil
237
+ }
238
+ return "" , err
239
+ }
240
+
241
+ switch val .Kind () {
242
+ case reflect .Int , reflect .Int8 , reflect .Int16 , reflect .Int32 , reflect .Int64 :
243
+ return strconv .FormatInt (val .Int (), 10 ), nil
244
+ case reflect .Uint , reflect .Uint8 , reflect .Uint16 , reflect .Uint32 , reflect .Uint64 , reflect .Uintptr :
245
+ return strconv .FormatUint (val .Uint (), 10 ), nil
246
+ }
247
+ return "" , fmt .Errorf ("unsupported key type: %v" , val .Type ())
248
+ }
249
+
250
+ var keyUnmarshalerType = reflect .TypeOf ((* KeyUnmarshaler )(nil )).Elem ()
251
+
252
+ func (mc * MapCodec ) decodeKey (key string , keyType reflect.Type ) (reflect.Value , error ) {
253
+ keyVal := reflect .ValueOf (key )
254
+ var err error
255
+ switch {
256
+ // First, if EncodeKeysWithStringer is not enabled, try to decode withKeyUnmarshaler
257
+ case ! mc .EncodeKeysWithStringer && reflect .PtrTo (keyType ).Implements (keyUnmarshalerType ):
258
+ keyVal = reflect .New (keyType )
259
+ v := keyVal .Interface ().(KeyUnmarshaler )
260
+ err = v .UnmarshalKey (key )
261
+ keyVal = keyVal .Elem ()
262
+ // Otherwise, go to type specific behavior
263
+ default :
264
+ switch keyType .Kind () {
265
+ case reflect .String :
266
+ keyVal = reflect .ValueOf (key ).Convert (keyType )
267
+ case reflect .Int , reflect .Int8 , reflect .Int16 , reflect .Int32 , reflect .Int64 :
268
+ s := string (key )
269
+ n , parseErr := strconv .ParseInt (s , 10 , 64 )
270
+ if parseErr != nil || reflect .Zero (keyType ).OverflowInt (n ) {
271
+ err = fmt .Errorf ("failed to unmarshal number key %v" , s )
272
+ }
273
+ keyVal = reflect .ValueOf (n ).Convert (keyType )
274
+ case reflect .Uint , reflect .Uint8 , reflect .Uint16 , reflect .Uint32 , reflect .Uint64 , reflect .Uintptr :
275
+ s := string (key )
276
+ n , parseErr := strconv .ParseUint (s , 10 , 64 )
277
+ if parseErr != nil || reflect .Zero (keyType ).OverflowUint (n ) {
278
+ err = fmt .Errorf ("failed to unmarshal number key %v" , s )
279
+ break
280
+ }
281
+ keyVal = reflect .ValueOf (n ).Convert (keyType )
282
+ case reflect .Float32 , reflect .Float64 :
283
+ if mc .EncodeKeysWithStringer {
284
+ parsed , err := strconv .ParseFloat (key , 64 )
285
+ if err != nil {
286
+ return keyVal , fmt .Errorf ("Map key is defined to be a decimal type (%v) but got error %v" , keyType .Kind (), err )
287
+ }
288
+ keyVal = reflect .ValueOf (parsed )
289
+ break
290
+ }
291
+ fallthrough
292
+ default :
293
+ return keyVal , fmt .Errorf ("unsupported key type: %v" , keyType )
294
+ }
295
+ }
296
+ return keyVal , err
297
+ }
0 commit comments