Skip to content

Commit 3cf67b9

Browse files
author
Divjot Arora
committed
GODRIVER-1879 Apply connectTimeoutMS to TLS handshake (#594)
1 parent 47f87bd commit 3cf67b9

File tree

4 files changed

+197
-11
lines changed

4 files changed

+197
-11
lines changed

x/mongo/driver/topology/connection.go

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -104,10 +104,28 @@ func (c *connection) connect(ctx context.Context) {
104104
}
105105
defer close(c.connectDone)
106106

107+
// Create separate contexts for dialing a connection and doing the MongoDB/auth handshakes.
108+
//
109+
// handshakeCtx is simply a cancellable version of ctx because there's no default timeout that needs to be applied
110+
// to the full handshake. The cancellation allows consumers to bail out early when dialing a connection if it's no
111+
// longer required. This is done while holding connectContextMutex because it accesses the shared cancelConnectContext field.
112+
//
113+
// dialCtx is equal to handshakeCtx if connectTimeoutMS=0. Otherwise, it is derived from handshakeCtx so the
114+
// cancellation still applies but with an added timeout to ensure the connectTimeoutMS option is applied to socket
115+
// establishment and the TLS handshake as a whole. This is created outside of the connectContextMutex lock to avoid
116+
// holding the lock longer than necessary.
107117
c.connectContextMutex.Lock()
108-
ctx, c.cancelConnectContext = context.WithCancel(ctx)
118+
var handshakeCtx context.Context
119+
handshakeCtx, c.cancelConnectContext = context.WithCancel(ctx)
109120
c.connectContextMutex.Unlock()
110121

122+
dialCtx := handshakeCtx
123+
var dialCancel context.CancelFunc
124+
if c.config.connectTimeout != 0 {
125+
dialCtx, dialCancel = context.WithTimeout(handshakeCtx, c.config.connectTimeout)
126+
defer dialCancel()
127+
}
128+
111129
defer func() {
112130
var cancelFn context.CancelFunc
113131

@@ -126,7 +144,7 @@ func (c *connection) connect(ctx context.Context) {
126144
// Assign the result of DialContext to a temporary net.Conn to ensure that c.nc is not set in an error case.
127145
var err error
128146
var tempNc net.Conn
129-
tempNc, err = c.config.dialer.DialContext(ctx, c.addr.Network(), c.addr.String())
147+
tempNc, err = c.config.dialer.DialContext(dialCtx, c.addr.Network(), c.addr.String())
130148
if err != nil {
131149
c.processInitializationError(err)
132150
return
@@ -142,7 +160,7 @@ func (c *connection) connect(ctx context.Context) {
142160
Cache: c.config.ocspCache,
143161
DisableEndpointChecking: c.config.disableOCSPEndpointCheck,
144162
}
145-
tlsNc, err := configureTLS(ctx, c.config.tlsConnectionSource, c.nc, c.addr, tlsConfig, ocspOpts)
163+
tlsNc, err := configureTLS(dialCtx, c.config.tlsConnectionSource, c.nc, c.addr, tlsConfig, ocspOpts)
146164
if err != nil {
147165
c.processInitializationError(err)
148166
return
@@ -160,10 +178,10 @@ func (c *connection) connect(ctx context.Context) {
160178

161179
handshakeStartTime := time.Now()
162180
handshakeConn := initConnection{c}
163-
c.desc, err = handshaker.GetDescription(ctx, c.addr, handshakeConn)
181+
c.desc, err = handshaker.GetDescription(handshakeCtx, c.addr, handshakeConn)
164182
if err == nil {
165183
c.isMasterRTT = time.Since(handshakeStartTime)
166-
err = handshaker.FinishHandshake(ctx, handshakeConn)
184+
err = handshaker.FinishHandshake(handshakeCtx, handshakeConn)
167185
}
168186
if err != nil {
169187
c.processInitializationError(err)

x/mongo/driver/topology/connection_options.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ func newConnectionConfig(opts ...ConnectionOption) (*connectionConfig, error) {
6969
}
7070

7171
if cfg.dialer == nil {
72-
cfg.dialer = &net.Dialer{Timeout: cfg.connectTimeout}
72+
cfg.dialer = &net.Dialer{}
7373
}
7474

7575
return cfg, nil

x/mongo/driver/topology/connection_test.go

Lines changed: 159 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ func TestConnection(t *testing.T) {
195195
for _, tc := range testCases {
196196
t.Run(tc.name, func(t *testing.T) {
197197
var sentCfg *tls.Config
198-
var testTLSConnectionSource tlsConnectionSourceFn = func(nc net.Conn, cfg *tls.Config) *tls.Conn {
198+
var testTLSConnectionSource tlsConnectionSourceFn = func(nc net.Conn, cfg *tls.Config) tlsConn {
199199
sentCfg = cfg
200200
return tls.Client(nc, cfg)
201201
}
@@ -228,6 +228,143 @@ func TestConnection(t *testing.T) {
228228
}
229229
})
230230
})
231+
t.Run("connectTimeout is applied correctly", func(t *testing.T) {
232+
testCases := []struct {
233+
name string
234+
contextTimeout time.Duration
235+
connectTimeout time.Duration
236+
maxConnectTime time.Duration
237+
}{
238+
// The timeout to dial a connection should be min(context timeout, connectTimeoutMS), so 1ms for
239+
// both of the tests declared below. Both tests also specify a 10ms max connect time to provide
240+
// a large buffer for lag and avoid test flakiness.
241+
242+
{"context timeout is lower", 1 * time.Millisecond, 100 * time.Millisecond, 10 * time.Millisecond},
243+
{"connect timeout is lower", 100 * time.Millisecond, 1 * time.Millisecond, 10 * time.Millisecond},
244+
}
245+
246+
for _, tc := range testCases {
247+
t.Run("timeout applied to socket establishment: "+tc.name, func(t *testing.T) {
248+
// Ensure the initial connection dial can be timed out and the connection propagates the error
249+
// from the dialer in this case.
250+
251+
connOpts := []ConnectionOption{
252+
WithDialer(func(Dialer) Dialer {
253+
return DialerFunc(func(ctx context.Context, _, _ string) (net.Conn, error) {
254+
<-ctx.Done()
255+
return nil, ctx.Err()
256+
})
257+
}),
258+
WithConnectTimeout(func(time.Duration) time.Duration {
259+
return tc.connectTimeout
260+
}),
261+
}
262+
conn, err := newConnection("", connOpts...)
263+
assert.Nil(t, err, "newConnection error: %v", err)
264+
265+
ctx, cancel := context.WithTimeout(context.Background(), tc.contextTimeout)
266+
defer cancel()
267+
var connectErr error
268+
callback := func() {
269+
conn.connect(ctx)
270+
connectErr = conn.wait()
271+
}
272+
assert.Soon(t, callback, tc.maxConnectTime)
273+
274+
ce, ok := connectErr.(ConnectionError)
275+
assert.True(t, ok, "expected error %v to be of type %T", connectErr, ConnectionError{})
276+
assert.Equal(t, context.DeadlineExceeded, ce.Unwrap(), "expected wrapped error to be %v, got %v",
277+
context.DeadlineExceeded, ce.Unwrap())
278+
})
279+
t.Run("timeout applied to TLS handshake: "+tc.name, func(t *testing.T) {
280+
// Ensure the TLS handshake can be timed out and the connection propagates the error from the
281+
// tlsConn in this case.
282+
283+
var hangingTLSConnectionSource tlsConnectionSourceFn = func(nc net.Conn, cfg *tls.Config) tlsConn {
284+
tlsConn := tls.Client(nc, cfg)
285+
return newHangingTLSConn(tlsConn, tc.maxConnectTime)
286+
}
287+
288+
connOpts := []ConnectionOption{
289+
WithConnectTimeout(func(time.Duration) time.Duration {
290+
return tc.connectTimeout
291+
}),
292+
WithDialer(func(Dialer) Dialer {
293+
return DialerFunc(func(context.Context, string, string) (net.Conn, error) {
294+
return &net.TCPConn{}, nil
295+
})
296+
}),
297+
WithTLSConfig(func(*tls.Config) *tls.Config {
298+
return &tls.Config{}
299+
}),
300+
withTLSConnectionSource(func(tlsConnectionSource) tlsConnectionSource {
301+
return hangingTLSConnectionSource
302+
}),
303+
}
304+
conn, err := newConnection("", connOpts...)
305+
assert.Nil(t, err, "newConnection error: %v", err)
306+
307+
ctx, cancel := context.WithTimeout(context.Background(), tc.contextTimeout)
308+
defer cancel()
309+
var connectErr error
310+
callback := func() {
311+
conn.connect(ctx)
312+
connectErr = conn.wait()
313+
}
314+
assert.Soon(t, callback, tc.maxConnectTime)
315+
316+
ce, ok := connectErr.(ConnectionError)
317+
assert.True(t, ok, "expected error %v to be of type %T", connectErr, ConnectionError{})
318+
assert.Equal(t, context.DeadlineExceeded, ce.Unwrap(), "expected wrapped error to be %v, got %v",
319+
context.DeadlineExceeded, ce.Unwrap())
320+
})
321+
t.Run("timeout is not applied to handshaker: "+tc.name, func(t *testing.T) {
322+
// Ensure that no additional timeout is applied to the handshake after the connection has been
323+
// established.
324+
325+
var getInfoCtx, finishCtx context.Context
326+
handshaker := &testHandshaker{
327+
getDescription: func(ctx context.Context, _ address.Address, _ driver.Connection) (description.Server, error) {
328+
getInfoCtx = ctx
329+
return description.Server{}, nil
330+
},
331+
finishHandshake: func(ctx context.Context, _ driver.Connection) error {
332+
finishCtx = ctx
333+
return nil
334+
},
335+
}
336+
337+
connOpts := []ConnectionOption{
338+
WithConnectTimeout(func(time.Duration) time.Duration {
339+
return tc.connectTimeout
340+
}),
341+
WithDialer(func(Dialer) Dialer {
342+
return DialerFunc(func(context.Context, string, string) (net.Conn, error) {
343+
return &net.TCPConn{}, nil
344+
})
345+
}),
346+
WithHandshaker(func(Handshaker) Handshaker {
347+
return handshaker
348+
}),
349+
}
350+
conn, err := newConnection("", connOpts...)
351+
assert.Nil(t, err, "newConnection error: %v", err)
352+
353+
bgCtx := context.Background()
354+
conn.connect(bgCtx)
355+
err = conn.wait()
356+
assert.Nil(t, err, "connect error: %v", err)
357+
358+
assertNoContextTimeout := func(t *testing.T, ctx context.Context) {
359+
t.Helper()
360+
dl, ok := ctx.Deadline()
361+
assert.False(t, ok, "expected context to have no deadline, but got deadline %v", dl)
362+
}
363+
assertNoContextTimeout(t, getInfoCtx)
364+
assertNoContextTimeout(t, finishCtx)
365+
})
366+
}
367+
})
231368
})
232369
t.Run("writeWireMessage", func(t *testing.T) {
233370
t.Run("closed connection", func(t *testing.T) {
@@ -689,3 +826,24 @@ func (d *dialer) lenclosed() int {
689826
defer d.Unlock()
690827
return len(d.closed)
691828
}
829+
830+
// hangingTLSConn is an implementation of tlsConn that wraps the tls.Conn type and overrides the Handshake function to
831+
// sleep for a fixed amount of time before delegating to the wrapped tls.Conn's Handshake.
832+
type hangingTLSConn struct {
833+
*tls.Conn
834+
sleepTime time.Duration
835+
}
836+
837+
var _ tlsConn = (*hangingTLSConn)(nil)
838+
839+
func newHangingTLSConn(conn *tls.Conn, sleepTime time.Duration) *hangingTLSConn {
840+
return &hangingTLSConn{
841+
Conn: conn,
842+
sleepTime: sleepTime,
843+
}
844+
}
845+
846+
func (h *hangingTLSConn) Handshake() error {
847+
time.Sleep(h.sleepTime)
848+
return h.Conn.Handshake()
849+
}

x/mongo/driver/topology/tls_connection_source.go

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,16 +11,26 @@ import (
1111
"net"
1212
)
1313

14+
type tlsConn interface {
15+
net.Conn
16+
Handshake() error
17+
ConnectionState() tls.ConnectionState
18+
}
19+
20+
var _ tlsConn = (*tls.Conn)(nil)
21+
1422
type tlsConnectionSource interface {
15-
Client(net.Conn, *tls.Config) *tls.Conn
23+
Client(net.Conn, *tls.Config) tlsConn
1624
}
1725

18-
type tlsConnectionSourceFn func(net.Conn, *tls.Config) *tls.Conn
26+
type tlsConnectionSourceFn func(net.Conn, *tls.Config) tlsConn
27+
28+
var _ tlsConnectionSource = (tlsConnectionSourceFn)(nil)
1929

20-
func (t tlsConnectionSourceFn) Client(nc net.Conn, cfg *tls.Config) *tls.Conn {
30+
func (t tlsConnectionSourceFn) Client(nc net.Conn, cfg *tls.Config) tlsConn {
2131
return t(nc, cfg)
2232
}
2333

24-
var defaultTLSConnectionSource tlsConnectionSourceFn = func(nc net.Conn, cfg *tls.Config) *tls.Conn {
34+
var defaultTLSConnectionSource tlsConnectionSourceFn = func(nc net.Conn, cfg *tls.Config) tlsConn {
2535
return tls.Client(nc, cfg)
2636
}

0 commit comments

Comments
 (0)