Skip to content

Commit eccd8d6

Browse files
committed
Improve safety of decoding into interfaces
1 parent 25c8824 commit eccd8d6

File tree

2 files changed

+103
-63
lines changed

2 files changed

+103
-63
lines changed

decoder.go

Lines changed: 69 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -127,15 +127,16 @@ func (d *decoder) unmarshalBool(size uint, offset uint, result reflect.Value) (u
127127
return 0, err
128128
}
129129
switch result.Kind() {
130-
default:
131-
return newOffset, newUnmarshalTypeError(value, result.Type())
132130
case reflect.Bool:
133131
result.SetBool(value)
134132
return newOffset, nil
135133
case reflect.Interface:
136-
result.Set(reflect.ValueOf(value))
137-
return newOffset, nil
134+
if result.NumMethod() == 0 {
135+
result.Set(reflect.ValueOf(value))
136+
return newOffset, nil
137+
}
138138
}
139+
return newOffset, newUnmarshalTypeError(value, result.Type())
139140
}
140141

141142
// follow pointers and create values as necessary
@@ -161,15 +162,16 @@ func (d *decoder) unmarshalBytes(size uint, offset uint, result reflect.Value) (
161162
return 0, err
162163
}
163164
switch result.Kind() {
164-
default:
165-
return newOffset, newUnmarshalTypeError(value, result.Type())
166165
case reflect.Slice:
167166
result.SetBytes(value)
168167
return newOffset, nil
169168
case reflect.Interface:
170-
result.Set(reflect.ValueOf(value))
171-
return newOffset, nil
169+
if result.NumMethod() == 0 {
170+
result.Set(reflect.ValueOf(value))
171+
return newOffset, nil
172+
}
172173
}
174+
return newOffset, newUnmarshalTypeError(value, result.Type())
173175
}
174176

175177
func (d *decoder) unmarshalFloat32(size uint, offset uint, result reflect.Value) (uint, error) {
@@ -182,15 +184,16 @@ func (d *decoder) unmarshalFloat32(size uint, offset uint, result reflect.Value)
182184
}
183185

184186
switch result.Kind() {
185-
default:
186-
return newOffset, newUnmarshalTypeError(value, result.Type())
187187
case reflect.Float32, reflect.Float64:
188188
result.SetFloat(float64(value))
189189
return newOffset, nil
190190
case reflect.Interface:
191-
result.Set(reflect.ValueOf(value))
192-
return newOffset, nil
191+
if result.NumMethod() == 0 {
192+
result.Set(reflect.ValueOf(value))
193+
return newOffset, nil
194+
}
193195
}
196+
return newOffset, newUnmarshalTypeError(value, result.Type())
194197
}
195198

196199
func (d *decoder) unmarshalFloat64(size uint, offset uint, result reflect.Value) (uint, error) {
@@ -203,18 +206,19 @@ func (d *decoder) unmarshalFloat64(size uint, offset uint, result reflect.Value)
203206
return 0, err
204207
}
205208
switch result.Kind() {
206-
default:
207-
return newOffset, newUnmarshalTypeError(value, result.Type())
208209
case reflect.Float32, reflect.Float64:
209210
if result.OverflowFloat(value) {
210211
return 0, newUnmarshalTypeError(value, result.Type())
211212
}
212213
result.SetFloat(value)
213214
return newOffset, nil
214215
case reflect.Interface:
215-
result.Set(reflect.ValueOf(value))
216-
return newOffset, nil
216+
if result.NumMethod() == 0 {
217+
result.Set(reflect.ValueOf(value))
218+
return newOffset, nil
219+
}
217220
}
221+
return newOffset, newUnmarshalTypeError(value, result.Type())
218222
}
219223

220224
func (d *decoder) unmarshalInt32(size uint, offset uint, result reflect.Value) (uint, error) {
@@ -227,26 +231,25 @@ func (d *decoder) unmarshalInt32(size uint, offset uint, result reflect.Value) (
227231
}
228232

229233
switch result.Kind() {
230-
default:
231-
return newOffset, newUnmarshalTypeError(value, result.Type())
232234
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
233235
n := int64(value)
234-
if result.OverflowInt(n) {
235-
return 0, newUnmarshalTypeError(value, result.Type())
236+
if !result.OverflowInt(n) {
237+
result.SetInt(n)
238+
return newOffset, nil
236239
}
237-
result.SetInt(n)
238-
return newOffset, nil
239240
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
240241
n := uint64(value)
241-
if result.OverflowUint(n) {
242-
return 0, newUnmarshalTypeError(value, result.Type())
242+
if !result.OverflowUint(n) {
243+
result.SetUint(n)
244+
return newOffset, nil
243245
}
244-
result.SetUint(n)
245-
return newOffset, nil
246246
case reflect.Interface:
247-
result.Set(reflect.ValueOf(value))
248-
return newOffset, nil
247+
if result.NumMethod() == 0 {
248+
result.Set(reflect.ValueOf(value))
249+
return newOffset, nil
250+
}
249251
}
252+
return newOffset, newUnmarshalTypeError(value, result.Type())
250253
}
251254

252255
func (d *decoder) unmarshalMap(size uint, offset uint, result reflect.Value) (uint, error) {
@@ -259,10 +262,13 @@ func (d *decoder) unmarshalMap(size uint, offset uint, result reflect.Value) (ui
259262
case reflect.Map:
260263
return d.decodeMap(size, offset, result)
261264
case reflect.Interface:
262-
rv := reflect.ValueOf(make(map[string]interface{}, size))
263-
newOffset, err := d.decodeMap(size, offset, rv)
264-
result.Set(rv)
265-
return newOffset, err
265+
if result.NumMethod() == 0 {
266+
rv := reflect.ValueOf(make(map[string]interface{}, size))
267+
newOffset, err := d.decodeMap(size, offset, rv)
268+
result.Set(rv)
269+
return newOffset, err
270+
}
271+
return 0, newUnmarshalTypeError("map", result.Type())
266272
}
267273
}
268274

@@ -273,19 +279,19 @@ func (d *decoder) unmarshalPointer(size uint, offset uint, result reflect.Value)
273279
}
274280

275281
func (d *decoder) unmarshalSlice(size uint, offset uint, result reflect.Value) (uint, error) {
276-
277282
switch result.Kind() {
278-
default:
279-
return 0, newUnmarshalTypeError("array", result.Type())
280283
case reflect.Slice:
281284
return d.decodeSlice(size, offset, result)
282285
case reflect.Interface:
283-
a := []interface{}{}
284-
rv := reflect.ValueOf(&a).Elem()
285-
newOffset, err := d.decodeSlice(size, offset, rv)
286-
result.Set(rv)
287-
return newOffset, err
286+
if result.NumMethod() == 0 {
287+
a := []interface{}{}
288+
rv := reflect.ValueOf(&a).Elem()
289+
newOffset, err := d.decodeSlice(size, offset, rv)
290+
result.Set(rv)
291+
return newOffset, err
292+
}
288293
}
294+
return 0, newUnmarshalTypeError("array", result.Type())
289295
}
290296

291297
func (d *decoder) unmarshalString(size uint, offset uint, result reflect.Value) (uint, error) {
@@ -295,15 +301,17 @@ func (d *decoder) unmarshalString(size uint, offset uint, result reflect.Value)
295301
return 0, err
296302
}
297303
switch result.Kind() {
298-
default:
299-
return newOffset, newUnmarshalTypeError(value, result.Type())
300304
case reflect.String:
301305
result.SetString(value)
302306
return newOffset, nil
303307
case reflect.Interface:
304-
result.Set(reflect.ValueOf(value))
305-
return newOffset, nil
308+
if result.NumMethod() == 0 {
309+
result.Set(reflect.ValueOf(value))
310+
return newOffset, nil
311+
}
306312
}
313+
return newOffset, newUnmarshalTypeError(value, result.Type())
314+
307315
}
308316

309317
func (d *decoder) unmarshalUint(size uint, offset uint, result reflect.Value, uintType uint) (uint, error) {
@@ -317,25 +325,24 @@ func (d *decoder) unmarshalUint(size uint, offset uint, result reflect.Value, ui
317325
}
318326

319327
switch result.Kind() {
320-
default:
321-
return newOffset, newUnmarshalTypeError(value, result.Type())
322328
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
323329
n := int64(value)
324-
if result.OverflowInt(n) {
325-
return 0, newUnmarshalTypeError(value, result.Type())
330+
if !result.OverflowInt(n) {
331+
result.SetInt(n)
332+
return newOffset, nil
326333
}
327-
result.SetInt(n)
328-
return newOffset, nil
329334
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
330-
if result.OverflowUint(value) {
331-
return 0, newUnmarshalTypeError(value, result.Type())
335+
if !result.OverflowUint(value) {
336+
result.SetUint(value)
337+
return newOffset, nil
332338
}
333-
result.SetUint(value)
334-
return newOffset, nil
335339
case reflect.Interface:
336-
result.Set(reflect.ValueOf(value))
337-
return newOffset, nil
340+
if result.NumMethod() == 0 {
341+
result.Set(reflect.ValueOf(value))
342+
return newOffset, nil
343+
}
338344
}
345+
return newOffset, newUnmarshalTypeError(value, result.Type())
339346
}
340347

341348
func (d *decoder) unmarshalUint128(size uint, offset uint, result reflect.Value) (uint, error) {
@@ -347,18 +354,17 @@ func (d *decoder) unmarshalUint128(size uint, offset uint, result reflect.Value)
347354
return 0, err
348355
}
349356

350-
// XXX - this should allow *big.Int rather than just bigInt
351-
// Currently this is reported as invalid
352357
switch result.Kind() {
353-
default:
354-
return newOffset, newUnmarshalTypeError(value, result.Type())
355358
case reflect.Struct:
356359
result.Set(reflect.ValueOf(*value))
357360
return newOffset, nil
358-
case reflect.Interface, reflect.Ptr:
359-
result.Set(reflect.ValueOf(value))
360-
return newOffset, nil
361+
case reflect.Interface:
362+
if result.NumMethod() == 0 {
363+
result.Set(reflect.ValueOf(value))
364+
return newOffset, nil
365+
}
361366
}
367+
return newOffset, newUnmarshalTypeError(value, result.Type())
362368
}
363369

364370
func (d *decoder) decodeBool(size uint, offset uint) (bool, uint, error) {

reader_test.go

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,41 @@ func (s *MySuite) TestStructInterface(c *C) {
179179
c.Assert(reader.Lookup(net.ParseIP("::1.1.1.0"), &result), IsNil)
180180

181181
c.Assert(result.method(), Equals, true)
182+
}
183+
184+
func (s *MySuite) TestNonEmptyNilInterface(c *C) {
185+
var result TestInterface
186+
187+
reader, err := Open("test-data/test-data/MaxMind-DB-test-decoder.mmdb")
188+
c.Assert(err, IsNil)
189+
190+
err = reader.Lookup(net.ParseIP("::1.1.1.0"), &result)
191+
c.Assert(err.Error(), Equals, "maxminddb: cannot unmarshal map into type maxminddb.TestInterface")
192+
}
193+
194+
type BoolInterface interface {
195+
true() bool
196+
}
197+
198+
type Bool bool
199+
200+
func (b Bool) true() bool {
201+
return bool(b)
202+
}
203+
204+
type ValueTypeTestType struct {
205+
Boolean BoolInterface `maxminddb:"boolean"`
206+
}
207+
208+
func (s *MySuite) TesValueTypeInterface(c *C) {
209+
var result ValueTypeTestType
210+
result.Boolean = Bool(false)
211+
212+
reader, err := Open("test-data/test-data/MaxMind-DB-test-decoder.mmdb")
213+
c.Assert(err, IsNil)
214+
c.Assert(reader.Lookup(net.ParseIP("::1.1.1.0"), &result), IsNil)
182215

216+
c.Assert(result.Boolean.true(), Equals, true)
183217
}
184218

185219
type NestedMapX struct {

0 commit comments

Comments
 (0)