Skip to content

Commit 10529d0

Browse files
committed
Inline traversal functions
In order to make it easier to refactor and optimize.
1 parent c8040fe commit 10529d0

File tree

1 file changed

+90
-115
lines changed

1 file changed

+90
-115
lines changed

traverse.go

Lines changed: 90 additions & 115 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,7 @@ type netNode struct {
1414
pointer uint
1515
}
1616

17-
// networks represents a set of subnets that we are iterating over.
18-
type networks struct {
19-
err error
20-
reader *Reader
21-
nodes []netNode
22-
lastNode netNode
17+
type networkOptions struct {
2318
includeAliasedNetworks bool
2419
}
2520

@@ -29,12 +24,12 @@ var (
2924
)
3025

3126
// NetworksOption are options for Networks and NetworksWithin.
32-
type NetworksOption func(*networks)
27+
type NetworksOption func(*networkOptions)
3328

3429
// IncludeAliasedNetworks is an option for Networks and NetworksWithin
3530
// that makes them iterate over aliases of the IPv4 subtree in an IPv6
3631
// database, e.g., ::ffff:0:0/96, 2001::/32, and 2002::/16.
37-
func IncludeAliasedNetworks(networks *networks) {
32+
func IncludeAliasedNetworks(networks *networkOptions) {
3833
networks.includeAliasedNetworks = true
3934
}
4035

@@ -63,131 +58,111 @@ func (r *Reader) Networks(options ...NetworksOption) iter.Seq[Result] {
6358
// If the provided prefix is contained within a network in the database, the
6459
// iterator will iterate over exactly one network, the containing network.
6560
func (r *Reader) NetworksWithin(prefix netip.Prefix, options ...NetworksOption) iter.Seq[Result] {
66-
n := r.networksWithin(prefix, options...)
6761
return func(yield func(Result) bool) {
68-
for n.next() {
69-
if n.err != nil {
70-
yield(Result{err: n.err})
71-
return
72-
}
73-
74-
ip := n.lastNode.ip
75-
if isInIPv4Subtree(ip) {
76-
ip = v6ToV4(ip)
77-
}
78-
79-
offset, err := r.resolveDataPointer(n.lastNode.pointer)
80-
ok := yield(Result{
81-
decoder: r.decoder,
82-
ip: ip,
83-
offset: uint(offset),
84-
prefixLen: uint8(n.lastNode.bit),
85-
err: err,
62+
if r.Metadata.IPVersion == 4 && prefix.Addr().Is6() {
63+
yield(Result{
64+
err: fmt.Errorf(
65+
"error getting networks with '%s': you attempted to use an IPv6 network in an IPv4-only database",
66+
prefix,
67+
),
8668
})
87-
if !ok {
88-
return
89-
}
90-
}
91-
if n.err != nil {
92-
yield(Result{err: n.err})
69+
return
9370
}
94-
}
95-
}
9671

97-
func (r *Reader) networksWithin(prefix netip.Prefix, options ...NetworksOption) *networks {
98-
if r.Metadata.IPVersion == 4 && prefix.Addr().Is6() {
99-
return &networks{
100-
err: fmt.Errorf(
101-
"error getting networks with '%s': you attempted to use an IPv6 network in an IPv4-only database",
102-
prefix,
103-
),
72+
n := &networkOptions{}
73+
for _, option := range options {
74+
option(n)
10475
}
105-
}
106-
107-
networks := &networks{reader: r}
108-
for _, option := range options {
109-
option(networks)
110-
}
11176

112-
ip := prefix.Addr()
113-
netIP := ip
114-
stopBit := prefix.Bits()
115-
if ip.Is4() {
116-
netIP = v4ToV16(ip)
117-
stopBit += 96
118-
}
119-
120-
pointer, bit := r.traverseTree(ip, 0, stopBit)
77+
ip := prefix.Addr()
78+
netIP := ip
79+
stopBit := prefix.Bits()
80+
if ip.Is4() {
81+
netIP = v4ToV16(ip)
82+
stopBit += 96
83+
}
12184

122-
prefix, err := netIP.Prefix(bit)
123-
if err != nil {
124-
networks.err = fmt.Errorf("prefixing %s with %d", netIP, bit)
125-
}
85+
pointer, bit := r.traverseTree(ip, 0, stopBit)
12686

127-
networks.nodes = []netNode{
128-
{
129-
ip: prefix.Addr(),
130-
bit: uint(bit),
131-
pointer: pointer,
132-
},
133-
}
87+
prefix, err := netIP.Prefix(bit)
88+
if err != nil {
89+
yield(Result{
90+
err: fmt.Errorf("prefixing %s with %d", netIP, bit),
91+
})
92+
}
13493

135-
return networks
136-
}
94+
nodes := []netNode{
95+
{
96+
ip: prefix.Addr(),
97+
bit: uint(bit),
98+
pointer: pointer,
99+
},
100+
}
137101

138-
// next prepares the next network for reading with the Network method. It
139-
// returns true if there is another network to be processed and false if there
140-
// are no more networks or if there is an error.
141-
func (n *networks) next() bool {
142-
if n.err != nil {
143-
return false
144-
}
145-
for len(n.nodes) > 0 {
146-
node := n.nodes[len(n.nodes)-1]
147-
n.nodes = n.nodes[:len(n.nodes)-1]
148-
149-
for node.pointer != n.reader.Metadata.NodeCount {
150-
// This skips IPv4 aliases without hardcoding the networks that the writer
151-
// currently aliases.
152-
if !n.includeAliasedNetworks && n.reader.ipv4Start != 0 &&
153-
node.pointer == n.reader.ipv4Start && !isInIPv4Subtree(node.ip) {
154-
break
155-
}
102+
for len(nodes) > 0 {
103+
node := nodes[len(nodes)-1]
104+
nodes = nodes[:len(nodes)-1]
156105

157-
if node.pointer > n.reader.Metadata.NodeCount {
158-
n.lastNode = node
159-
return true
160-
}
161-
ipRight := node.ip.As16()
162-
if len(ipRight) <= int(node.bit>>3) {
163-
displayAddr := node.ip
164-
displayBits := node.bit
165-
if isInIPv4Subtree(node.ip) {
166-
displayAddr = v6ToV4(displayAddr)
167-
displayBits -= 96
106+
for node.pointer != r.Metadata.NodeCount {
107+
// This skips IPv4 aliases without hardcoding the networks that the writer
108+
// currently aliases.
109+
if !n.includeAliasedNetworks && r.ipv4Start != 0 &&
110+
node.pointer == r.ipv4Start && !isInIPv4Subtree(node.ip) {
111+
break
168112
}
169113

170-
n.err = newInvalidDatabaseError(
171-
"invalid search tree at %s/%d", displayAddr, displayBits)
172-
return false
173-
}
174-
ipRight[node.bit>>3] |= 1 << (7 - (node.bit % 8))
114+
if node.pointer > r.Metadata.NodeCount {
115+
ip := node.ip
116+
if isInIPv4Subtree(ip) {
117+
ip = v6ToV4(ip)
118+
}
119+
120+
offset, err := r.resolveDataPointer(node.pointer)
121+
ok := yield(Result{
122+
decoder: r.decoder,
123+
ip: ip,
124+
offset: uint(offset),
125+
prefixLen: uint8(node.bit),
126+
err: err,
127+
})
128+
if !ok {
129+
return
130+
}
131+
break
132+
}
133+
ipRight := node.ip.As16()
134+
if len(ipRight) <= int(node.bit>>3) {
135+
displayAddr := node.ip
136+
displayBits := node.bit
137+
if isInIPv4Subtree(node.ip) {
138+
displayAddr = v6ToV4(displayAddr)
139+
displayBits -= 96
140+
}
141+
142+
yield(Result{
143+
ip: displayAddr,
144+
prefixLen: uint8(node.bit),
145+
err: newInvalidDatabaseError(
146+
"invalid search tree at %s/%d", displayAddr, displayBits),
147+
})
148+
return
149+
}
150+
ipRight[node.bit>>3] |= 1 << (7 - (node.bit % 8))
175151

176-
offset := node.pointer * n.reader.nodeOffsetMult
177-
rightPointer := n.reader.nodeReader.readRight(offset)
152+
offset := node.pointer * r.nodeOffsetMult
153+
rightPointer := r.nodeReader.readRight(offset)
178154

179-
node.bit++
180-
n.nodes = append(n.nodes, netNode{
181-
pointer: rightPointer,
182-
ip: netip.AddrFrom16(ipRight),
183-
bit: node.bit,
184-
})
155+
node.bit++
156+
nodes = append(nodes, netNode{
157+
pointer: rightPointer,
158+
ip: netip.AddrFrom16(ipRight),
159+
bit: node.bit,
160+
})
185161

186-
node.pointer = n.reader.nodeReader.readLeft(offset)
162+
node.pointer = r.nodeReader.readLeft(offset)
163+
}
187164
}
188165
}
189-
190-
return false
191166
}
192167

193168
var ipv4SubtreeBoundary = netip.MustParseAddr("::255.255.255.255").Next()

0 commit comments

Comments
 (0)