Skip to content

Commit ede8d26

Browse files
committed
Add support for getting record network
1 parent 8774e16 commit ede8d26

File tree

2 files changed

+154
-15
lines changed

2 files changed

+154
-15
lines changed

reader.go

Lines changed: 41 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -110,13 +110,36 @@ func (r *Reader) Lookup(ipAddress net.IP, result interface{}) error {
110110
if r.buffer == nil {
111111
return errors.New("cannot call Lookup on a closed database")
112112
}
113-
pointer, err := r.lookupPointer(ipAddress)
113+
pointer, _, _, err := r.lookupPointer(ipAddress)
114114
if pointer == 0 || err != nil {
115115
return err
116116
}
117117
return r.retrieveData(pointer, result)
118118
}
119119

120+
// LookupNetwork retrieves the database record for ipAddress and stores it in
121+
// the value pointed to be result. The network returned is the network
122+
// associated with the data record in the database. The ok return value
123+
// indicates whether the database contained a record for the ipAddress.
124+
//
125+
// If result is nil or not a pointer, an error is returned. If the data in the
126+
// database record cannot be stored in result because of type differences, an
127+
// UnmarshalTypeError is returned. If the database is invalid or otherwise
128+
// cannot be read, an InvalidDatabaseError is returned.
129+
func (r *Reader) LookupNetwork(ipAddress net.IP, result interface{}) (network *net.IPNet, ok bool, err error) {
130+
if r.buffer == nil {
131+
return nil, false, errors.New("cannot call Lookup on a closed database")
132+
}
133+
pointer, prefixLength, ipAddress, err := r.lookupPointer(ipAddress)
134+
135+
network = r.cidr(ipAddress, prefixLength)
136+
if pointer == 0 || err != nil {
137+
return network, false, err
138+
}
139+
140+
return network, true, r.retrieveData(pointer, result)
141+
}
142+
120143
// LookupOffset maps an argument net.IP to a corresponding record offset in the
121144
// database. NotFound is returned if no such record is found, and a record may
122145
// otherwise be extracted by passing the returned offset to Decode. LookupOffset
@@ -126,13 +149,20 @@ func (r *Reader) LookupOffset(ipAddress net.IP) (uintptr, error) {
126149
if r.buffer == nil {
127150
return 0, errors.New("cannot call LookupOffset on a closed database")
128151
}
129-
pointer, err := r.lookupPointer(ipAddress)
152+
pointer, _, _, err := r.lookupPointer(ipAddress)
130153
if pointer == 0 || err != nil {
131154
return NotFound, err
132155
}
133156
return r.resolveDataPointer(pointer)
134157
}
135158

159+
func (r *Reader) cidr(ipAddress net.IP, prefixLength int) *net.IPNet {
160+
ipBitLength := len(ipAddress) * 8
161+
mask := net.CIDRMask(prefixLength, ipBitLength)
162+
163+
return &net.IPNet{IP: ipAddress.Mask(mask), Mask: mask}
164+
}
165+
136166
// Decode the record at |offset| into |result|. The result value pointed to
137167
// must be a data value that corresponds to a record in the database. This may
138168
// include a struct representation of the data, a map capable of holding the
@@ -166,24 +196,19 @@ func (r *Reader) decode(offset uintptr, result interface{}) error {
166196
return err
167197
}
168198

169-
func (r *Reader) lookupPointer(ipAddress net.IP) (uint, error) {
199+
func (r *Reader) lookupPointer(ipAddress net.IP) (uint, int, net.IP, error) {
170200
if ipAddress == nil {
171-
return 0, errors.New("ipAddress passed to Lookup cannot be nil")
201+
return 0, 0, ipAddress, errors.New("ipAddress passed to Lookup cannot be nil")
172202
}
173203

174204
ipV4Address := ipAddress.To4()
175205
if ipV4Address != nil {
176206
ipAddress = ipV4Address
177207
}
178208
if len(ipAddress) == 16 && r.Metadata.IPVersion == 4 {
179-
return 0, fmt.Errorf("error looking up '%s': you attempted to look up an IPv6 address in an IPv4-only database", ipAddress.String())
209+
return 0, 0, ipAddress, fmt.Errorf("error looking up '%s': you attempted to look up an IPv6 address in an IPv4-only database", ipAddress.String())
180210
}
181211

182-
return r.findAddressInTree(ipAddress)
183-
}
184-
185-
func (r *Reader) findAddressInTree(ipAddress net.IP) (uint, error) {
186-
187212
bitCount := uint(len(ipAddress) * 8)
188213

189214
var node uint
@@ -193,23 +218,24 @@ func (r *Reader) findAddressInTree(ipAddress net.IP) (uint, error) {
193218

194219
nodeCount := r.Metadata.NodeCount
195220

196-
for i := uint(0); i < bitCount && node < nodeCount; i++ {
221+
i := uint(0)
222+
for ; i < bitCount && node < nodeCount; i++ {
197223
bit := uint(1) & (uint(ipAddress[i>>3]) >> (7 - (i % 8)))
198224

199225
var err error
200226
node, err = r.readNode(node, bit)
201227
if err != nil {
202-
return 0, err
228+
return 0, int(i), ipAddress, err
203229
}
204230
}
205231
if node == nodeCount {
206232
// Record is empty
207-
return 0, nil
233+
return 0, int(i), ipAddress, nil
208234
} else if node > nodeCount {
209-
return node, nil
235+
return node, int(i), ipAddress, nil
210236
}
211237

212-
return 0, newInvalidDatabaseError("invalid node in search tree")
238+
return 0, int(i), ipAddress, newInvalidDatabaseError("invalid node in search tree")
213239
}
214240

215241
func (r *Reader) readNode(nodeNumber uint, index uint) (uint, error) {

reader_test.go

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,119 @@ func TestReaderBytes(t *testing.T) {
5151
}
5252
}
5353

54+
func TestLookupNetwork(t *testing.T) {
55+
bigInt := new(big.Int)
56+
bigInt.SetString("1329227995784915872903807060280344576", 10)
57+
decoderRecord := map[string]interface{}{"array": []interface{}{uint64(1),
58+
uint64(2),
59+
uint64(3)},
60+
"boolean": true,
61+
"bytes": []uint8{
62+
0x0,
63+
0x0,
64+
0x0,
65+
0x2a,
66+
},
67+
"double": 42.123456,
68+
"float": float32(1.1),
69+
"int32": -268435456,
70+
"map": map[string]interface{}{
71+
"mapX": map[string]interface{}{
72+
"arrayX": []interface{}{
73+
uint64(0x7),
74+
uint64(0x8),
75+
uint64(0x9)},
76+
"utf8_stringX": "hello",
77+
},
78+
},
79+
"uint128": bigInt,
80+
"uint16": uint64(0x64),
81+
"uint32": uint64(0x10000000),
82+
"uint64": uint64(0x1000000000000000),
83+
"utf8_string": "unicode! ☯ - ♫",
84+
}
85+
86+
tests := []struct {
87+
IP net.IP
88+
DBFile string
89+
ExpectedCIDR string
90+
ExpectedRecord interface{}
91+
ExpectedOK bool
92+
}{
93+
// XXX - add test of IPv4 lookup in IPv6 database with no IPv4 subtree
94+
{
95+
IP: net.ParseIP("1.1.1.1"),
96+
DBFile: "MaxMind-DB-test-ipv6-32.mmdb",
97+
ExpectedCIDR: "1.0.0.0/8",
98+
ExpectedRecord: nil,
99+
ExpectedOK: false,
100+
},
101+
{
102+
IP: net.ParseIP("::1:ffff:ffff"),
103+
DBFile: "MaxMind-DB-test-ipv6-24.mmdb",
104+
ExpectedCIDR: "::1:ffff:ffff/128",
105+
ExpectedRecord: map[string]interface{}{"ip": "::1:ffff:ffff"},
106+
ExpectedOK: true,
107+
},
108+
{
109+
IP: net.ParseIP("::2:0:1"),
110+
DBFile: "MaxMind-DB-test-ipv6-24.mmdb",
111+
ExpectedCIDR: "::2:0:0/122",
112+
ExpectedRecord: map[string]interface{}{"ip": "::2:0:0"},
113+
ExpectedOK: true,
114+
},
115+
{
116+
IP: net.ParseIP("1.1.1.1"),
117+
DBFile: "MaxMind-DB-test-ipv4-24.mmdb",
118+
ExpectedCIDR: "1.1.1.1/32",
119+
ExpectedRecord: map[string]interface{}{"ip": "1.1.1.1"},
120+
ExpectedOK: true,
121+
},
122+
{
123+
IP: net.ParseIP("1.1.1.3"),
124+
DBFile: "MaxMind-DB-test-ipv4-24.mmdb",
125+
ExpectedCIDR: "1.1.1.2/31",
126+
ExpectedRecord: map[string]interface{}{"ip": "1.1.1.2"},
127+
ExpectedOK: true,
128+
},
129+
{
130+
IP: net.ParseIP("1.1.1.3"),
131+
DBFile: "MaxMind-DB-test-decoder.mmdb",
132+
ExpectedCIDR: "1.1.1.0/24",
133+
ExpectedRecord: decoderRecord,
134+
ExpectedOK: true,
135+
},
136+
{
137+
IP: net.ParseIP("::ffff:1.1.1.128"),
138+
DBFile: "MaxMind-DB-test-decoder.mmdb",
139+
ExpectedCIDR: "1.1.1.0/24",
140+
ExpectedRecord: decoderRecord,
141+
ExpectedOK: true,
142+
},
143+
{
144+
IP: net.ParseIP("::1.1.1.128"),
145+
DBFile: "MaxMind-DB-test-decoder.mmdb",
146+
ExpectedCIDR: "::101:100/120",
147+
ExpectedRecord: decoderRecord,
148+
ExpectedOK: true,
149+
},
150+
}
151+
152+
for _, test := range tests {
153+
t.Run(fmt.Sprintf("%s - %s", test.DBFile, test.IP), func(t *testing.T) {
154+
var record interface{}
155+
reader, err := Open(testFile(test.DBFile))
156+
require.NoError(t, err)
157+
158+
network, ok, err := reader.LookupNetwork(test.IP, &record)
159+
require.NoError(t, err)
160+
assert.Equal(t, test.ExpectedOK, ok)
161+
assert.Equal(t, test.ExpectedCIDR, network.String())
162+
assert.Equal(t, test.ExpectedRecord, record)
163+
})
164+
}
165+
}
166+
54167
func TestDecodingToInterface(t *testing.T) {
55168
reader, err := Open(testFile("MaxMind-DB-test-decoder.mmdb"))
56169
require.NoError(t, err, "unexpected error while opening database: %v", err)

0 commit comments

Comments
 (0)