@@ -195,7 +195,7 @@ func TestConnection(t *testing.T) {
195
195
for _ , tc := range testCases {
196
196
t .Run (tc .name , func (t * testing.T ) {
197
197
var sentCfg * tls.Config
198
- var testTLSConnectionSource tlsConnectionSourceFn = func (nc net.Conn , cfg * tls.Config ) * tls. Conn {
198
+ var testTLSConnectionSource tlsConnectionSourceFn = func (nc net.Conn , cfg * tls.Config ) tlsConn {
199
199
sentCfg = cfg
200
200
return tls .Client (nc , cfg )
201
201
}
@@ -228,6 +228,143 @@ func TestConnection(t *testing.T) {
228
228
}
229
229
})
230
230
})
231
+ t .Run ("connectTimeout is applied correctly" , func (t * testing.T ) {
232
+ testCases := []struct {
233
+ name string
234
+ contextTimeout time.Duration
235
+ connectTimeout time.Duration
236
+ maxConnectTime time.Duration
237
+ }{
238
+ // The timeout to dial a connection should be min(context timeout, connectTimeoutMS), so 1ms for
239
+ // both of the tests declared below. Both tests also specify a 10ms max connect time to provide
240
+ // a large buffer for lag and avoid test flakiness.
241
+
242
+ {"context timeout is lower" , 1 * time .Millisecond , 100 * time .Millisecond , 10 * time .Millisecond },
243
+ {"connect timeout is lower" , 100 * time .Millisecond , 1 * time .Millisecond , 10 * time .Millisecond },
244
+ }
245
+
246
+ for _ , tc := range testCases {
247
+ t .Run ("timeout applied to socket establishment: " + tc .name , func (t * testing.T ) {
248
+ // Ensure the initial connection dial can be timed out and the connection propagates the error
249
+ // from the dialer in this case.
250
+
251
+ connOpts := []ConnectionOption {
252
+ WithDialer (func (Dialer ) Dialer {
253
+ return DialerFunc (func (ctx context.Context , _ , _ string ) (net.Conn , error ) {
254
+ <- ctx .Done ()
255
+ return nil , ctx .Err ()
256
+ })
257
+ }),
258
+ WithConnectTimeout (func (time.Duration ) time.Duration {
259
+ return tc .connectTimeout
260
+ }),
261
+ }
262
+ conn , err := newConnection ("" , connOpts ... )
263
+ assert .Nil (t , err , "newConnection error: %v" , err )
264
+
265
+ ctx , cancel := context .WithTimeout (context .Background (), tc .contextTimeout )
266
+ defer cancel ()
267
+ var connectErr error
268
+ callback := func () {
269
+ conn .connect (ctx )
270
+ connectErr = conn .wait ()
271
+ }
272
+ assert .Soon (t , callback , tc .maxConnectTime )
273
+
274
+ ce , ok := connectErr .(ConnectionError )
275
+ assert .True (t , ok , "expected error %v to be of type %T" , connectErr , ConnectionError {})
276
+ assert .Equal (t , context .DeadlineExceeded , ce .Unwrap (), "expected wrapped error to be %v, got %v" ,
277
+ context .DeadlineExceeded , ce .Unwrap ())
278
+ })
279
+ t .Run ("timeout applied to TLS handshake: " + tc .name , func (t * testing.T ) {
280
+ // Ensure the TLS handshake can be timed out and the connection propagates the error from the
281
+ // tlsConn in this case.
282
+
283
+ var hangingTLSConnectionSource tlsConnectionSourceFn = func (nc net.Conn , cfg * tls.Config ) tlsConn {
284
+ tlsConn := tls .Client (nc , cfg )
285
+ return newHangingTLSConn (tlsConn , tc .maxConnectTime )
286
+ }
287
+
288
+ connOpts := []ConnectionOption {
289
+ WithConnectTimeout (func (time.Duration ) time.Duration {
290
+ return tc .connectTimeout
291
+ }),
292
+ WithDialer (func (Dialer ) Dialer {
293
+ return DialerFunc (func (context.Context , string , string ) (net.Conn , error ) {
294
+ return & net.TCPConn {}, nil
295
+ })
296
+ }),
297
+ WithTLSConfig (func (* tls.Config ) * tls.Config {
298
+ return & tls.Config {}
299
+ }),
300
+ withTLSConnectionSource (func (tlsConnectionSource ) tlsConnectionSource {
301
+ return hangingTLSConnectionSource
302
+ }),
303
+ }
304
+ conn , err := newConnection ("" , connOpts ... )
305
+ assert .Nil (t , err , "newConnection error: %v" , err )
306
+
307
+ ctx , cancel := context .WithTimeout (context .Background (), tc .contextTimeout )
308
+ defer cancel ()
309
+ var connectErr error
310
+ callback := func () {
311
+ conn .connect (ctx )
312
+ connectErr = conn .wait ()
313
+ }
314
+ assert .Soon (t , callback , tc .maxConnectTime )
315
+
316
+ ce , ok := connectErr .(ConnectionError )
317
+ assert .True (t , ok , "expected error %v to be of type %T" , connectErr , ConnectionError {})
318
+ assert .Equal (t , context .DeadlineExceeded , ce .Unwrap (), "expected wrapped error to be %v, got %v" ,
319
+ context .DeadlineExceeded , ce .Unwrap ())
320
+ })
321
+ t .Run ("timeout is not applied to handshaker: " + tc .name , func (t * testing.T ) {
322
+ // Ensure that no additional timeout is applied to the handshake after the connection has been
323
+ // established.
324
+
325
+ var getInfoCtx , finishCtx context.Context
326
+ handshaker := & testHandshaker {
327
+ getDescription : func (ctx context.Context , _ address.Address , _ driver.Connection ) (description.Server , error ) {
328
+ getInfoCtx = ctx
329
+ return description.Server {}, nil
330
+ },
331
+ finishHandshake : func (ctx context.Context , _ driver.Connection ) error {
332
+ finishCtx = ctx
333
+ return nil
334
+ },
335
+ }
336
+
337
+ connOpts := []ConnectionOption {
338
+ WithConnectTimeout (func (time.Duration ) time.Duration {
339
+ return tc .connectTimeout
340
+ }),
341
+ WithDialer (func (Dialer ) Dialer {
342
+ return DialerFunc (func (context.Context , string , string ) (net.Conn , error ) {
343
+ return & net.TCPConn {}, nil
344
+ })
345
+ }),
346
+ WithHandshaker (func (Handshaker ) Handshaker {
347
+ return handshaker
348
+ }),
349
+ }
350
+ conn , err := newConnection ("" , connOpts ... )
351
+ assert .Nil (t , err , "newConnection error: %v" , err )
352
+
353
+ bgCtx := context .Background ()
354
+ conn .connect (bgCtx )
355
+ err = conn .wait ()
356
+ assert .Nil (t , err , "connect error: %v" , err )
357
+
358
+ assertNoContextTimeout := func (t * testing.T , ctx context.Context ) {
359
+ t .Helper ()
360
+ dl , ok := ctx .Deadline ()
361
+ assert .False (t , ok , "expected context to have no deadline, but got deadline %v" , dl )
362
+ }
363
+ assertNoContextTimeout (t , getInfoCtx )
364
+ assertNoContextTimeout (t , finishCtx )
365
+ })
366
+ }
367
+ })
231
368
})
232
369
t .Run ("writeWireMessage" , func (t * testing.T ) {
233
370
t .Run ("closed connection" , func (t * testing.T ) {
@@ -689,3 +826,24 @@ func (d *dialer) lenclosed() int {
689
826
defer d .Unlock ()
690
827
return len (d .closed )
691
828
}
829
+
830
+ // hangingTLSConn is an implementation of tlsConn that wraps the tls.Conn type and overrides the Handshake function to
831
+ // sleep for a fixed amount of time.
832
+ type hangingTLSConn struct {
833
+ * tls.Conn
834
+ sleepTime time.Duration
835
+ }
836
+
837
+ var _ tlsConn = (* hangingTLSConn )(nil )
838
+
839
+ func newHangingTLSConn (conn * tls.Conn , sleepTime time.Duration ) * hangingTLSConn {
840
+ return & hangingTLSConn {
841
+ Conn : conn ,
842
+ sleepTime : sleepTime ,
843
+ }
844
+ }
845
+
846
+ func (h * hangingTLSConn ) Handshake () error {
847
+ time .Sleep (h .sleepTime )
848
+ return h .Conn .Handshake ()
849
+ }
0 commit comments