Skip to content

Support custom deserializer #74

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Aug 20, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 135 additions & 0 deletions decoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,26 @@ func (d *decoder) decode(offset uint, result reflect.Value, depth int) (uint, er
return d.decodeFromType(typeNum, size, newOffset, result, depth+1)
}

func (d *decoder) decodeToDeserializer(offset uint, dser deserializer, depth int) (uint, error) {
if depth > maximumDataStructureDepth {
return 0, newInvalidDatabaseError("exceeded maximum data structure depth; database is likely corrupt")
}
typeNum, size, newOffset, err := d.decodeCtrlData(offset)
if err != nil {
return 0, err
}

skip, err := dser.ShouldSkip(uintptr(offset))
if err != nil {
return 0, err
}
if skip {
return d.nextValueOffset(offset, 1)
}

return d.decodeFromTypeToDeserializer(typeNum, size, newOffset, dser, depth+1)
}

func (d *decoder) decodeCtrlData(offset uint) (dataType, uint, uint, error) {
newOffset := offset + 1
if offset >= uint(len(d.buffer)) {
Expand Down Expand Up @@ -157,6 +177,68 @@ func (d *decoder) decodeFromType(
}
}

func (d *decoder) decodeFromTypeToDeserializer(
dtype dataType,
size uint,
offset uint,
dser deserializer,
depth int,
) (uint, error) {
// For these types, size has a special meaning
switch dtype {
case _Bool:
v, offset := d.decodeBool(size, offset)
return offset, dser.Bool(v)
case _Map:
return d.decodeMapToDeserializer(size, offset, dser, depth)
case _Pointer:
pointer, newOffset, err := d.decodePointer(size, offset)
if err != nil {
return 0, err
}
_, err = d.decodeToDeserializer(pointer, dser, depth)
return newOffset, err
case _Slice:
return d.decodeSliceToDeserializer(size, offset, dser, depth)
}

// For the remaining types, size is the byte size
if offset+size > uint(len(d.buffer)) {
return 0, newOffsetError()
}
switch dtype {
case _Bytes:
v, offset := d.decodeBytes(size, offset)
return offset, dser.Bytes(v)
case _Float32:
v, offset := d.decodeFloat32(size, offset)
return offset, dser.Float32(v)
case _Float64:
v, offset := d.decodeFloat64(size, offset)
return offset, dser.Float64(v)
case _Int32:
v, offset := d.decodeInt(size, offset)
return offset, dser.Int32(int32(v))
case _String:
v, offset := d.decodeString(size, offset)
return offset, dser.String(v)
case _Uint16:
v, offset := d.decodeUint(size, offset)
return offset, dser.Uint16(uint16(v))
case _Uint32:
v, offset := d.decodeUint(size, offset)
return offset, dser.Uint32(uint32(v))
case _Uint64:
v, offset := d.decodeUint(size, offset)
return offset, dser.Uint64(v)
case _Uint128:
v, offset := d.decodeUint128(size, offset)
return offset, dser.Uint128(v)
default:
return 0, newInvalidDatabaseError("unknown type: %d", dtype)
}
}

func (d *decoder) unmarshalBool(size, offset uint, result reflect.Value) (uint, error) {
if size > 1 {
return 0, newInvalidDatabaseError("the MaxMind DB file's data section contains bad data (bool size of %v)", size)
Expand Down Expand Up @@ -199,6 +281,7 @@ func (d *decoder) indirect(result reflect.Value) reflect.Value {
if result.IsNil() {
result.Set(reflect.New(result.Type().Elem()))
}

result = result.Elem()
}
return result
Expand Down Expand Up @@ -486,6 +569,35 @@ func (d *decoder) decodeMap(
return offset, nil
}

func (d *decoder) decodeMapToDeserializer(
size uint,
offset uint,
dser deserializer,
depth int,
) (uint, error) {
err := dser.StartMap(size)
if err != nil {
return 0, err
}
for i := uint(0); i < size; i++ {
// TODO - implement key/value skipping?
offset, err = d.decodeToDeserializer(offset, dser, depth)
if err != nil {
return 0, err
}

offset, err = d.decodeToDeserializer(offset, dser, depth)
if err != nil {
return 0, err
}
}
err = dser.End()
if err != nil {
return 0, err
}
return offset, nil
}

func (d *decoder) decodePointer(
size uint,
offset uint,
Expand Down Expand Up @@ -538,6 +650,29 @@ func (d *decoder) decodeSlice(
return offset, nil
}

func (d *decoder) decodeSliceToDeserializer(
size uint,
offset uint,
dser deserializer,
depth int,
) (uint, error) {
err := dser.StartSlice(size)
if err != nil {
return 0, err
}
for i := uint(0); i < size; i++ {
offset, err = d.decodeToDeserializer(offset, dser, depth)
if err != nil {
return 0, err
}
}
err = dser.End()
if err != nil {
return 0, err
}
return offset, nil
}

func (d *decoder) decodeString(size, offset uint) (string, uint) {
newOffset := offset + size
return string(d.buffer[offset:newOffset]), newOffset
Expand Down
31 changes: 31 additions & 0 deletions deserializer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package maxminddb

import "math/big"

// deserializer is an interface for a type that deserializes an MaxMind DB
// data record to some other type. This exists as an alternative to the
// standard reflection API.
//
// This is fundamentally different than the Unmarshaler interface that
// several packages provide. A Deserializer will generally create the
// final struct or value rather than unmarshaling to itself.
//
// This interface and the associated unmarshaling code is EXPERIMENTAL!
// It is not currently covered by any Semantic Versioning guarantees.
// Use at your own risk.
type deserializer interface {
ShouldSkip(offset uintptr) (bool, error)
StartSlice(size uint) error
StartMap(size uint) error
End() error
String(string) error
Float64(float64) error
Bytes([]byte) error
Uint16(uint16) error
Uint32(uint32) error
Int32(int32) error
Uint64(uint64) error
Uint128(*big.Int) error
Bool(bool) error
Float32(float32) error
}
119 changes: 119 additions & 0 deletions deserializer.go deserializer_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
package maxminddb

import (
"math/big"
"net"
"testing"

"github.com/stretchr/testify/require"
)

func TestDecodingToDeserializer(t *testing.T) {
reader, err := Open(testFile("MaxMind-DB-test-decoder.mmdb"))
require.NoError(t, err, "unexpected error while opening database: %v", err)

dser := testDeserializer{}
err = reader.Lookup(net.ParseIP("::1.1.1.0"), &dser)
require.NoError(t, err, "unexpected error while doing lookup: %v", err)

checkDecodingToInterface(t, dser.rv)
}

type stackValue struct {
value interface{}
curNum int
}

type testDeserializer struct {
stack []*stackValue
rv interface{}
key *string
}

func (d *testDeserializer) ShouldSkip(offset uintptr) (bool, error) {
return false, nil
}

func (d *testDeserializer) StartSlice(size uint) error {
return d.add(make([]interface{}, size))
}

func (d *testDeserializer) StartMap(size uint) error {
return d.add(map[string]interface{}{})
}

func (d *testDeserializer) End() error {
d.stack = d.stack[:len(d.stack)-1]
return nil
}

func (d *testDeserializer) String(v string) error {
return d.add(v)
}

func (d *testDeserializer) Float64(v float64) error {
return d.add(v)
}

func (d *testDeserializer) Bytes(v []byte) error {
return d.add(v)
}

func (d *testDeserializer) Uint16(v uint16) error {
return d.add(uint64(v))
}

func (d *testDeserializer) Uint32(v uint32) error {
return d.add(uint64(v))
}

func (d *testDeserializer) Int32(v int32) error {
return d.add(int(v))
}

func (d *testDeserializer) Uint64(v uint64) error {
return d.add(v)
}

func (d *testDeserializer) Uint128(v *big.Int) error {
return d.add(v)
}

func (d *testDeserializer) Bool(v bool) error {
return d.add(v)
}

func (d *testDeserializer) Float32(v float32) error {
return d.add(v)
}

func (d *testDeserializer) add(v interface{}) error {
if len(d.stack) == 0 {
d.rv = v
} else {
top := d.stack[len(d.stack)-1]
switch parent := top.value.(type) {
case map[string]interface{}:
if d.key == nil {
key := v.(string)
d.key = &key
} else {
parent[*d.key] = v
d.key = nil
}

case []interface{}:
parent[top.curNum] = v
top.curNum++
default:
}
}

switch v := v.(type) {
case map[string]interface{}, []interface{}:
d.stack = append(d.stack, &stackValue{value: v})
default:
}

return nil
}
8 changes: 0 additions & 8 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,11 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0 h1:4G4v2dO3VZwixGIRoQ5Lfboy6nUhCyYzaqnIAPPhYs4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/stretchr/testify v1.5.0 h1:DMOzIV76tmoDNE9pX6RSN0aDtCYeCg5VueieJaAo1uw=
github.com/stretchr/testify v1.5.0/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
github.com/stretchr/testify v1.5.1 h1:nOGnQDM7FYENwehXlg/kFVnos3rEvtKTjRvOWSzb6H4=
github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0=
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
golang.org/x/sys v0.0.0-20191224085550-c709ea063b76 h1:Dho5nD6R3PcW2SH1or8vS0dszDaXRxIw55lBX7XiE5g=
golang.org/x/sys v0.0.0-20191224085550-c709ea063b76/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
5 changes: 5 additions & 0 deletions reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,11 @@ func (r *Reader) decode(offset uintptr, result interface{}) error {
return errors.New("result param must be a pointer")
}

if dser, ok := result.(deserializer); ok {
_, err := r.decoder.decodeToDeserializer(uint(offset), dser, 0)
return err
}

_, err := r.decoder.decode(uint(offset), rv, 0)
return err
}
Expand Down
4 changes: 4 additions & 0 deletions reader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,10 @@ func TestDecodingToInterface(t *testing.T) {
err = reader.Lookup(net.ParseIP("::1.1.1.0"), &recordInterface)
require.NoError(t, err, "unexpected error while doing lookup: %v", err)

checkDecodingToInterface(t, recordInterface)
}

func checkDecodingToInterface(t *testing.T, recordInterface interface{}) {
record := recordInterface.(map[string]interface{})
assert.Equal(t, []interface{}{uint64(1), uint64(2), uint64(3)}, record["array"])
assert.Equal(t, true, record["boolean"])
Expand Down