Skip to content

Commit 2f34e98

Browse files
authored
[client] Add TCP support to DNS forwarder service listener (#3790)
[client] Add TCP support to DNS forwarder service listener
1 parent d5b52e8 commit 2f34e98

File tree

2 files changed

+98
-20
lines changed

2 files changed

+98
-20
lines changed

client/internal/dnsfwd/forwarder.go

Lines changed: 84 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ type DNSForwarder struct {
3333

3434
dnsServer *dns.Server
3535
mux *dns.ServeMux
36+
tcpServer *dns.Server
37+
tcpMux *dns.ServeMux
3638

3739
mutex sync.RWMutex
3840
fwdEntries []*ForwarderEntry
@@ -50,22 +52,41 @@ func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewall.Manager
5052
}
5153

5254
func (f *DNSForwarder) Listen(entries []*ForwarderEntry) error {
53-
log.Infof("listen DNS forwarder on address=%s", f.listenAddress)
54-
mux := dns.NewServeMux()
55+
log.Infof("starting DNS forwarder on address=%s", f.listenAddress)
5556

56-
dnsServer := &dns.Server{
57+
// UDP server
58+
mux := dns.NewServeMux()
59+
f.mux = mux
60+
f.dnsServer = &dns.Server{
5761
Addr: f.listenAddress,
5862
Net: "udp",
5963
Handler: mux,
6064
}
61-
f.dnsServer = dnsServer
62-
f.mux = mux
65+
// TCP server
66+
tcpMux := dns.NewServeMux()
67+
f.tcpMux = tcpMux
68+
f.tcpServer = &dns.Server{
69+
Addr: f.listenAddress,
70+
Net: "tcp",
71+
Handler: tcpMux,
72+
}
6373

6474
f.UpdateDomains(entries)
6575

66-
return dnsServer.ListenAndServe()
67-
}
76+
errCh := make(chan error, 2)
77+
78+
go func() {
79+
log.Infof("DNS UDP listener running on %s", f.listenAddress)
80+
errCh <- f.dnsServer.ListenAndServe()
81+
}()
82+
go func() {
83+
log.Infof("DNS TCP listener running on %s", f.listenAddress)
84+
errCh <- f.tcpServer.ListenAndServe()
85+
}()
6886

87+
// return the first error we get (e.g. bind failure or shutdown)
88+
return <-errCh
89+
}
6990
func (f *DNSForwarder) UpdateDomains(entries []*ForwarderEntry) {
7091
f.mutex.Lock()
7192
defer f.mutex.Unlock()
@@ -77,31 +98,41 @@ func (f *DNSForwarder) UpdateDomains(entries []*ForwarderEntry) {
7798
}
7899

79100
oldDomains := filterDomains(f.fwdEntries)
80-
81101
for _, d := range oldDomains {
82102
f.mux.HandleRemove(d.PunycodeString())
103+
f.tcpMux.HandleRemove(d.PunycodeString())
83104
}
84105

85106
newDomains := filterDomains(entries)
86107
for _, d := range newDomains {
87-
f.mux.HandleFunc(d.PunycodeString(), f.handleDNSQuery)
108+
f.mux.HandleFunc(d.PunycodeString(), f.handleDNSQueryUDP)
109+
f.tcpMux.HandleFunc(d.PunycodeString(), f.handleDNSQueryTCP)
88110
}
89111

90112
f.fwdEntries = entries
91-
92113
log.Debugf("Updated domains from %v to %v", oldDomains, newDomains)
93114
}
94115

95116
func (f *DNSForwarder) Close(ctx context.Context) error {
96-
if f.dnsServer == nil {
97-
return nil
117+
var result *multierror.Error
118+
119+
if f.dnsServer != nil {
120+
if err := f.dnsServer.ShutdownContext(ctx); err != nil {
121+
result = multierror.Append(result, fmt.Errorf("UDP shutdown: %w", err))
122+
}
123+
}
124+
if f.tcpServer != nil {
125+
if err := f.tcpServer.ShutdownContext(ctx); err != nil {
126+
result = multierror.Append(result, fmt.Errorf("TCP shutdown: %w", err))
127+
}
98128
}
99-
return f.dnsServer.ShutdownContext(ctx)
129+
130+
return nberrors.FormatErrorOrNil(result)
100131
}
101132

102-
func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) {
133+
func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) *dns.Msg {
103134
if len(query.Question) == 0 {
104-
return
135+
return nil
105136
}
106137
question := query.Question[0]
107138
log.Tracef("received DNS request for DNS forwarder: domain=%v type=%v class=%v",
@@ -123,20 +154,53 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) {
123154
if err := w.WriteMsg(resp); err != nil {
124155
log.Errorf("failed to write DNS response: %v", err)
125156
}
126-
return
157+
return nil
127158
}
128159

129160
ctx, cancel := context.WithTimeout(context.Background(), upstreamTimeout)
130161
defer cancel()
131162
ips, err := net.DefaultResolver.LookupNetIP(ctx, network, domain)
132163
if err != nil {
133-
f.handleDNSError(w, resp, domain, err)
134-
return
164+
f.handleDNSError(w, query, resp, domain, err)
165+
return nil
135166
}
136167

137168
f.updateInternalState(domain, ips)
138169
f.addIPsToResponse(resp, domain, ips)
139170

171+
return resp
172+
}
173+
174+
func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) {
175+
176+
resp := f.handleDNSQuery(w, query)
177+
if resp == nil {
178+
return
179+
}
180+
181+
opt := query.IsEdns0()
182+
maxSize := dns.MinMsgSize
183+
if opt != nil {
184+
// client advertised a larger EDNS0 buffer
185+
maxSize = int(opt.UDPSize())
186+
}
187+
188+
// if our response is too big, truncate and set the TC bit
189+
if resp.Len() > maxSize {
190+
resp.Truncate(maxSize)
191+
}
192+
193+
if err := w.WriteMsg(resp); err != nil {
194+
log.Errorf("failed to write DNS response: %v", err)
195+
}
196+
}
197+
198+
func (f *DNSForwarder) handleDNSQueryTCP(w dns.ResponseWriter, query *dns.Msg) {
199+
resp := f.handleDNSQuery(w, query)
200+
if resp == nil {
201+
return
202+
}
203+
140204
if err := w.WriteMsg(resp); err != nil {
141205
log.Errorf("failed to write DNS response: %v", err)
142206
}
@@ -179,7 +243,7 @@ func (f *DNSForwarder) updateFirewall(matchingEntries []*ForwarderEntry, prefixe
179243
}
180244

181245
// handleDNSError processes DNS lookup errors and sends an appropriate error response
182-
func (f *DNSForwarder) handleDNSError(w dns.ResponseWriter, resp *dns.Msg, domain string, err error) {
246+
func (f *DNSForwarder) handleDNSError(w dns.ResponseWriter, query, resp *dns.Msg, domain string, err error) {
183247
var dnsErr *net.DNSError
184248

185249
switch {
@@ -191,7 +255,7 @@ func (f *DNSForwarder) handleDNSError(w dns.ResponseWriter, resp *dns.Msg, domai
191255
}
192256

193257
if dnsErr.Server != "" {
194-
log.Warnf("failed to resolve query for domain=%s server=%s: %v", domain, dnsErr.Server, err)
258+
log.Warnf("failed to resolve query for type=%s domain=%s server=%s: %v", dns.TypeToString[query.Question[0].Qtype], domain, dnsErr.Server, err)
195259
} else {
196260
log.Warnf(errResolveFailed, domain, err)
197261
}

client/internal/dnsfwd/manager.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ type Manager struct {
3333
statusRecorder *peer.Status
3434

3535
fwRules []firewall.Rule
36+
tcpRules []firewall.Rule
3637
dnsForwarder *DNSForwarder
3738
}
3839

@@ -107,6 +108,13 @@ func (m *Manager) allowDNSFirewall() error {
107108
}
108109
m.fwRules = dnsRules
109110

111+
tcpRules, err := m.firewall.AddPeerFiltering(nil, net.IP{0, 0, 0, 0}, firewall.ProtocolTCP, nil, dport, firewall.ActionAccept, "")
112+
if err != nil {
113+
log.Errorf("failed to add allow DNS router rules, err: %v", err)
114+
return err
115+
}
116+
m.tcpRules = tcpRules
117+
110118
return nil
111119
}
112120

@@ -117,7 +125,13 @@ func (m *Manager) dropDNSFirewall() error {
117125
mErr = multierror.Append(mErr, fmt.Errorf("failed to delete DNS router rules, err: %v", err))
118126
}
119127
}
128+
for _, rule := range m.tcpRules {
129+
if err := m.firewall.DeletePeerRule(rule); err != nil {
130+
mErr = multierror.Append(mErr, fmt.Errorf("failed to delete DNS router rules, err: %v", err))
131+
}
132+
}
120133

121134
m.fwRules = nil
135+
m.tcpRules = nil
122136
return nberrors.FormatErrorOrNil(mErr)
123137
}

0 commit comments

Comments
 (0)