@@ -33,6 +33,8 @@ type DNSForwarder struct {
33
33
34
34
dnsServer * dns.Server
35
35
mux * dns.ServeMux
36
+ tcpServer * dns.Server
37
+ tcpMux * dns.ServeMux
36
38
37
39
mutex sync.RWMutex
38
40
fwdEntries []* ForwarderEntry
@@ -50,22 +52,41 @@ func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewall.Manager
50
52
}
51
53
52
54
func (f * DNSForwarder ) Listen (entries []* ForwarderEntry ) error {
53
- log .Infof ("listen DNS forwarder on address=%s" , f .listenAddress )
54
- mux := dns .NewServeMux ()
55
+ log .Infof ("starting DNS forwarder on address=%s" , f .listenAddress )
55
56
56
- dnsServer := & dns.Server {
57
+ // UDP server
58
+ mux := dns .NewServeMux ()
59
+ f .mux = mux
60
+ f .dnsServer = & dns.Server {
57
61
Addr : f .listenAddress ,
58
62
Net : "udp" ,
59
63
Handler : mux ,
60
64
}
61
- f .dnsServer = dnsServer
62
- f .mux = mux
65
+ // TCP server
66
+ tcpMux := dns .NewServeMux ()
67
+ f .tcpMux = tcpMux
68
+ f .tcpServer = & dns.Server {
69
+ Addr : f .listenAddress ,
70
+ Net : "tcp" ,
71
+ Handler : tcpMux ,
72
+ }
63
73
64
74
f .UpdateDomains (entries )
65
75
66
- return dnsServer .ListenAndServe ()
67
- }
76
+ errCh := make (chan error , 2 )
77
+
78
+ go func () {
79
+ log .Infof ("DNS UDP listener running on %s" , f .listenAddress )
80
+ errCh <- f .dnsServer .ListenAndServe ()
81
+ }()
82
+ go func () {
83
+ log .Infof ("DNS TCP listener running on %s" , f .listenAddress )
84
+ errCh <- f .tcpServer .ListenAndServe ()
85
+ }()
68
86
87
+ // return the first error we get (e.g. bind failure or shutdown)
88
+ return <- errCh
89
+ }
69
90
func (f * DNSForwarder ) UpdateDomains (entries []* ForwarderEntry ) {
70
91
f .mutex .Lock ()
71
92
defer f .mutex .Unlock ()
@@ -77,31 +98,41 @@ func (f *DNSForwarder) UpdateDomains(entries []*ForwarderEntry) {
77
98
}
78
99
79
100
oldDomains := filterDomains (f .fwdEntries )
80
-
81
101
for _ , d := range oldDomains {
82
102
f .mux .HandleRemove (d .PunycodeString ())
103
+ f .tcpMux .HandleRemove (d .PunycodeString ())
83
104
}
84
105
85
106
newDomains := filterDomains (entries )
86
107
for _ , d := range newDomains {
87
- f .mux .HandleFunc (d .PunycodeString (), f .handleDNSQuery )
108
+ f .mux .HandleFunc (d .PunycodeString (), f .handleDNSQueryUDP )
109
+ f .tcpMux .HandleFunc (d .PunycodeString (), f .handleDNSQueryTCP )
88
110
}
89
111
90
112
f .fwdEntries = entries
91
-
92
113
log .Debugf ("Updated domains from %v to %v" , oldDomains , newDomains )
93
114
}
94
115
95
116
func (f * DNSForwarder ) Close (ctx context.Context ) error {
96
- if f .dnsServer == nil {
97
- return nil
117
+ var result * multierror.Error
118
+
119
+ if f .dnsServer != nil {
120
+ if err := f .dnsServer .ShutdownContext (ctx ); err != nil {
121
+ result = multierror .Append (result , fmt .Errorf ("UDP shutdown: %w" , err ))
122
+ }
123
+ }
124
+ if f .tcpServer != nil {
125
+ if err := f .tcpServer .ShutdownContext (ctx ); err != nil {
126
+ result = multierror .Append (result , fmt .Errorf ("TCP shutdown: %w" , err ))
127
+ }
98
128
}
99
- return f .dnsServer .ShutdownContext (ctx )
129
+
130
+ return nberrors .FormatErrorOrNil (result )
100
131
}
101
132
102
- func (f * DNSForwarder ) handleDNSQuery (w dns.ResponseWriter , query * dns.Msg ) {
133
+ func (f * DNSForwarder ) handleDNSQuery (w dns.ResponseWriter , query * dns.Msg ) * dns. Msg {
103
134
if len (query .Question ) == 0 {
104
- return
135
+ return nil
105
136
}
106
137
question := query .Question [0 ]
107
138
log .Tracef ("received DNS request for DNS forwarder: domain=%v type=%v class=%v" ,
@@ -123,20 +154,53 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) {
123
154
if err := w .WriteMsg (resp ); err != nil {
124
155
log .Errorf ("failed to write DNS response: %v" , err )
125
156
}
126
- return
157
+ return nil
127
158
}
128
159
129
160
ctx , cancel := context .WithTimeout (context .Background (), upstreamTimeout )
130
161
defer cancel ()
131
162
ips , err := net .DefaultResolver .LookupNetIP (ctx , network , domain )
132
163
if err != nil {
133
- f .handleDNSError (w , resp , domain , err )
134
- return
164
+ f .handleDNSError (w , query , resp , domain , err )
165
+ return nil
135
166
}
136
167
137
168
f .updateInternalState (domain , ips )
138
169
f .addIPsToResponse (resp , domain , ips )
139
170
171
+ return resp
172
+ }
173
+
174
+ func (f * DNSForwarder ) handleDNSQueryUDP (w dns.ResponseWriter , query * dns.Msg ) {
175
+
176
+ resp := f .handleDNSQuery (w , query )
177
+ if resp == nil {
178
+ return
179
+ }
180
+
181
+ opt := query .IsEdns0 ()
182
+ maxSize := dns .MinMsgSize
183
+ if opt != nil {
184
+ // client advertised a larger EDNS0 buffer
185
+ maxSize = int (opt .UDPSize ())
186
+ }
187
+
188
+ // if our response is too big, truncate and set the TC bit
189
+ if resp .Len () > maxSize {
190
+ resp .Truncate (maxSize )
191
+ }
192
+
193
+ if err := w .WriteMsg (resp ); err != nil {
194
+ log .Errorf ("failed to write DNS response: %v" , err )
195
+ }
196
+ }
197
+
198
+ func (f * DNSForwarder ) handleDNSQueryTCP (w dns.ResponseWriter , query * dns.Msg ) {
199
+ resp := f .handleDNSQuery (w , query )
200
+ if resp == nil {
201
+ return
202
+ }
203
+
140
204
if err := w .WriteMsg (resp ); err != nil {
141
205
log .Errorf ("failed to write DNS response: %v" , err )
142
206
}
@@ -179,7 +243,7 @@ func (f *DNSForwarder) updateFirewall(matchingEntries []*ForwarderEntry, prefixe
179
243
}
180
244
181
245
// handleDNSError processes DNS lookup errors and sends an appropriate error response
182
- func (f * DNSForwarder ) handleDNSError (w dns.ResponseWriter , resp * dns.Msg , domain string , err error ) {
246
+ func (f * DNSForwarder ) handleDNSError (w dns.ResponseWriter , query , resp * dns.Msg , domain string , err error ) {
183
247
var dnsErr * net.DNSError
184
248
185
249
switch {
@@ -191,7 +255,7 @@ func (f *DNSForwarder) handleDNSError(w dns.ResponseWriter, resp *dns.Msg, domai
191
255
}
192
256
193
257
if dnsErr .Server != "" {
194
- log .Warnf ("failed to resolve query for domain=%s server=%s: %v" , domain , dnsErr .Server , err )
258
+ log .Warnf ("failed to resolve query for type=%s domain=%s server=%s: %v" , dns . TypeToString [ query . Question [ 0 ]. Qtype ] , domain , dnsErr .Server , err )
195
259
} else {
196
260
log .Warnf (errResolveFailed , domain , err )
197
261
}
0 commit comments