Skip to content

Commit d1c282d

Browse files
committed
Proof of concept of NetworksWithin
1 parent 6a033e6 commit d1c282d

File tree

3 files changed

+104
-10
lines changed

3 files changed

+104
-10
lines changed

reader.go

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,20 @@ func (r *Reader) lookupPointer(ip net.IP) (uint, int, net.IP, error) {
249249
if bitCount == 32 {
250250
node = r.ipv4Start
251251
}
252+
node, prefixLength := r.traverseTree(ip, node, bitCount)
252253

254+
nodeCount := r.Metadata.NodeCount
255+
if node == nodeCount {
256+
// Record is empty
257+
return 0, prefixLength, ip, nil
258+
} else if node > nodeCount {
259+
return node, prefixLength, ip, nil
260+
}
261+
262+
return 0, prefixLength, ip, newInvalidDatabaseError("invalid node in search tree")
263+
}
264+
265+
func (r *Reader) traverseTree(ip net.IP, node uint, bitCount uint) (uint, int) {
253266
nodeCount := r.Metadata.NodeCount
254267

255268
i := uint(0)
@@ -263,14 +276,8 @@ func (r *Reader) lookupPointer(ip net.IP) (uint, int, net.IP, error) {
263276
node = r.nodeReader.readRight(offset)
264277
}
265278
}
266-
if node == nodeCount {
267-
// Record is empty
268-
return 0, int(i), ip, nil
269-
} else if node > nodeCount {
270-
return node, int(i), ip, nil
271-
}
272279

273-
return 0, int(i), ip, newInvalidDatabaseError("invalid node in search tree")
280+
return node, int(i)
274281
}
275282

276283
func (r *Reader) retrieveData(pointer uint, result interface{}) error {

traverse.go

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,22 +17,36 @@ type Networks struct {
1717
err error
1818
}
1919

20+
var allIPv4 = &net.IPNet{IP: make(net.IP, 4), Mask: net.CIDRMask(0, 32)}
21+
var allIPv6 = &net.IPNet{IP: make(net.IP, 16), Mask: net.CIDRMask(0, 128)}
22+
2023
// Networks returns an iterator that can be used to traverse all networks in
2124
// the database.
2225
//
2326
// Please note that a MaxMind DB may map IPv4 networks into several locations
2427
// in an IPv6 database. This iterator will iterate over all of these
2528
// locations separately.
2629
func (r *Reader) Networks() *Networks {
27-
s := 4
30+
var networks *Networks
2831
if r.Metadata.IPVersion == 6 {
29-
s = 16
32+
networks = r.NetworksWithin(allIPv6)
33+
} else {
34+
networks = r.NetworksWithin(allIPv4)
3035
}
36+
37+
return networks
38+
}
39+
40+
func (r *Reader) NetworksWithin(network *net.IPNet) *Networks {
41+
prefixLength, _ := network.Mask.Size()
42+
pointer, bit := r.traverseTree(network.IP, 0, uint(prefixLength))
3143
return &Networks{
3244
reader: r,
3345
nodes: []netNode{
3446
{
35-
ip: make(net.IP, s),
47+
ip: network.IP,
48+
bit: uint(bit),
49+
pointer: pointer,
3650
},
3751
},
3852
}

traverse_test.go

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package maxminddb
22

33
import (
44
"fmt"
5+
"net"
56
"testing"
67

78
"github.com/stretchr/testify/assert"
@@ -46,3 +47,75 @@ func TestNetworksWithInvalidSearchTree(t *testing.T) {
4647
assert.NotNil(t, n.Err(), "no error received when traversing an broken search tree")
4748
assert.Equal(t, n.Err().Error(), "invalid search tree at 128.128.128.128/32")
4849
}
50+
func TestNetworksWithin(t *testing.T) {
51+
_, network, error := net.ParseCIDR("1.1.1.0/24")
52+
53+
assert.Nil(t, error)
54+
55+
for _, recordSize := range []uint{24, 28, 32} {
56+
fileName := testFile(fmt.Sprintf("MaxMind-DB-test-ipv4-%d.mmdb", recordSize))
57+
reader, err := Open(fileName)
58+
require.Nil(t, err, "unexpected error while opening database: %v", err)
59+
defer reader.Close()
60+
61+
n := reader.NetworksWithin(network)
62+
var innerIPs []string
63+
64+
for n.Next() {
65+
record := struct {
66+
IP string `maxminddb:"ip"`
67+
}{}
68+
network, err := n.Network(&record)
69+
assert.Nil(t, err)
70+
assert.Equal(t, record.IP, network.IP.String(),
71+
"expected %s got %s", record.IP, network.IP.String(),
72+
)
73+
innerIPs = append(innerIPs, record.IP)
74+
}
75+
76+
expectedIPs := []string{
77+
"1.1.1.1",
78+
"1.1.1.2",
79+
"1.1.1.4",
80+
"1.1.1.8",
81+
"1.1.1.16",
82+
"1.1.1.32",
83+
}
84+
85+
assert.Equal(t, expectedIPs, innerIPs)
86+
assert.Nil(t, n.Err())
87+
}
88+
}
89+
90+
func TestNetworksWithinSlash32(t *testing.T) {
91+
_, network, error := net.ParseCIDR("1.1.1.32/32")
92+
93+
assert.Nil(t, error)
94+
95+
for _, recordSize := range []uint{24, 28, 32} {
96+
fileName := testFile(fmt.Sprintf("MaxMind-DB-test-ipv4-%d.mmdb", recordSize))
97+
reader, err := Open(fileName)
98+
require.Nil(t, err, "unexpected error while opening database: %v", err)
99+
defer reader.Close()
100+
101+
n := reader.NetworksWithin(network)
102+
var innerIPs []string
103+
104+
for n.Next() {
105+
record := struct {
106+
IP string `maxminddb:"ip"`
107+
}{}
108+
network, err := n.Network(&record)
109+
assert.Nil(t, err)
110+
assert.Equal(t, record.IP, network.IP.String(),
111+
"expected %s got %s", record.IP, network.IP.String(),
112+
)
113+
innerIPs = append(innerIPs, record.IP)
114+
}
115+
116+
expectedIPs := []string([]string{"1.1.1.32"})
117+
118+
assert.Equal(t, expectedIPs, innerIPs)
119+
assert.Nil(t, n.Err())
120+
}
121+
}

0 commit comments

Comments
 (0)