Skip to content

Commit 82dbf47

Browse files
Divjot Aroratsedgwick
authored andcommitted
GODRIVER-1750 Ensure contexts are always cancelled during server monitoring (mongodb#654)
1 parent ce5af90 commit 82dbf47

File tree

2 files changed

+40
-0
lines changed

2 files changed

+40
-0
lines changed

x/mongo/driver/topology/server.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -596,6 +596,10 @@ func (s *Server) setupHeartbeatConnection() error {
596596

597597
// Take the lock when assigning the context and connection because they're accessed by cancelCheck.
598598
s.heartbeatLock.Lock()
599+
if s.heartbeatCtxCancel != nil {
600+
// Ensure the previous context is cancelled to avoid a leak.
601+
s.heartbeatCtxCancel()
602+
}
599603
s.heartbeatCtx, s.heartbeatCtxCancel = context.WithCancel(s.globalCtx)
600604
s.conn = conn
601605
s.heartbeatLock.Unlock()

x/mongo/driver/topology/server_test.go

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -507,6 +507,42 @@ func TestServer(t *testing.T) {
507507
assert.Equal(t, s.cfg.heartbeatTimeout, conn.readTimeout, "expected readTimeout to be: %v, got: %v", s.cfg.heartbeatTimeout, conn.readTimeout)
508508
assert.Equal(t, s.cfg.heartbeatTimeout, conn.writeTimeout, "expected writeTimeout to be: %v, got: %v", s.cfg.heartbeatTimeout, conn.writeTimeout)
509509
})
510+
t.Run("heartbeat contexts are not leaked", func(t *testing.T) {
511+
// The context created for heartbeats should be cancelled when it is no longer needed to avoid leaks.
512+
513+
server, err := ConnectServer(
514+
address.Address("invalid"),
515+
nil,
516+
primitive.NewObjectID(),
517+
withMonitoringDisabled(func(bool) bool {
518+
return true
519+
}),
520+
)
521+
assert.Nil(t, err, "ConnectServer error: %v", err)
522+
523+
// Expect check to return an error in the server description because the server address doesn't exist. This is
524+
// OK because we just want to ensure the heartbeat context is created.
525+
desc, err := server.check()
526+
assert.Nil(t, err, "check error: %v", err)
527+
assert.NotNil(t, desc.LastError, "expected server description to contain an error, got nil")
528+
assert.NotNil(t, server.heartbeatCtx, "expected heartbeatCtx to be non-nil, got nil")
529+
assert.Nil(t, server.heartbeatCtx.Err(), "expected heartbeatCtx error to be nil, got %v", server.heartbeatCtx.Err())
530+
531+
// Override heartbeatCtxCancel with a wrapper that records whether or not it was called.
532+
oldCancelFn := server.heartbeatCtxCancel
533+
var previousCtxCancelled bool
534+
server.heartbeatCtxCancel = func() {
535+
previousCtxCancelled = true
536+
oldCancelFn()
537+
}
538+
539+
// The second check call should attempt to create a new heartbeat connection and should cancel the previous
540+
// heartbeatCtx during the process.
541+
desc, err = server.check()
542+
assert.Nil(t, err, "check error: %v", err)
543+
assert.NotNil(t, desc.LastError, "expected server description to contain an error, got nil")
544+
assert.True(t, previousCtxCancelled, "expected check to cancel previous context but did not")
545+
})
510546
}
511547

512548
func includesMetadata(t *testing.T, wm []byte) bool {

0 commit comments

Comments
 (0)