Skip to content

Commit 29b591c

Browse files
author
Divjot Arora
authored
GODRIVER-1750 Ensure contexts are always cancelled during server monitoring (#654)
1 parent b47cf2e commit 29b591c

File tree

2 files changed

+40
-0
lines changed

2 files changed

+40
-0
lines changed

x/mongo/driver/topology/server.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -617,6 +617,10 @@ func (s *Server) setupHeartbeatConnection() error {
617617

618618
// Take the lock when assigning the context and connection because they're accessed by cancelCheck.
619619
s.heartbeatLock.Lock()
620+
if s.heartbeatCtxCancel != nil {
621+
// Ensure the previous context is cancelled to avoid a leak.
622+
s.heartbeatCtxCancel()
623+
}
620624
s.heartbeatCtx, s.heartbeatCtxCancel = context.WithCancel(s.globalCtx)
621625
s.conn = conn
622626
s.heartbeatLock.Unlock()

x/mongo/driver/topology/server_test.go

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -627,6 +627,42 @@ func TestServer(t *testing.T) {
627627
assert.Equal(t, s.cfg.heartbeatTimeout, conn.readTimeout, "expected readTimeout to be: %v, got: %v", s.cfg.heartbeatTimeout, conn.readTimeout)
628628
assert.Equal(t, s.cfg.heartbeatTimeout, conn.writeTimeout, "expected writeTimeout to be: %v, got: %v", s.cfg.heartbeatTimeout, conn.writeTimeout)
629629
})
630+
t.Run("heartbeat contexts are not leaked", func(t *testing.T) {
631+
// The context created for heartbeats should be cancelled when it is no longer needed to avoid leaks.
632+
633+
server, err := ConnectServer(
634+
address.Address("invalid"),
635+
nil,
636+
primitive.NewObjectID(),
637+
withMonitoringDisabled(func(bool) bool {
638+
return true
639+
}),
640+
)
641+
assert.Nil(t, err, "ConnectServer error: %v", err)
642+
643+
// Expect check to return an error in the server description because the server address doesn't exist. This is
644+
// OK because we just want to ensure the heartbeat context is created.
645+
desc, err := server.check()
646+
assert.Nil(t, err, "check error: %v", err)
647+
assert.NotNil(t, desc.LastError, "expected server description to contain an error, got nil")
648+
assert.NotNil(t, server.heartbeatCtx, "expected heartbeatCtx to be non-nil, got nil")
649+
assert.Nil(t, server.heartbeatCtx.Err(), "expected heartbeatCtx error to be nil, got %v", server.heartbeatCtx.Err())
650+
651+
// Override heartbeatCtxCancel with a wrapper that records whether or not it was called.
652+
oldCancelFn := server.heartbeatCtxCancel
653+
var previousCtxCancelled bool
654+
server.heartbeatCtxCancel = func() {
655+
previousCtxCancelled = true
656+
oldCancelFn()
657+
}
658+
659+
// The second check call should attempt to create a new heartbeat connection and should cancel the previous
660+
// heartbeatCtx during the process.
661+
desc, err = server.check()
662+
assert.Nil(t, err, "check error: %v", err)
663+
assert.NotNil(t, desc.LastError, "expected server description to contain an error, got nil")
664+
assert.True(t, previousCtxCancelled, "expected check to cancel previous context but did not")
665+
})
630666
}
631667

632668
func includesMetadata(t *testing.T, wm []byte) bool {

0 commit comments

Comments
 (0)