Skip to content

Commit 218a07c

Browse files
author
Divjot Arora
authored
GODRIVER-1535 Fix session IDs batching in Disconnect (#343)
1 parent 16d2050 commit 218a07c

File tree

2 files changed

+108
-22
lines changed

2 files changed

+108
-22
lines changed

mongo/client.go

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ import (
1010
"context"
1111
"crypto/tls"
1212
"errors"
13-
"strconv"
1413
"strings"
1514
"time"
1615

@@ -34,11 +33,14 @@ import (
3433
)
3534

3635
const defaultLocalThreshold = 15 * time.Millisecond
37-
const batchSize = 10000
3836

39-
// keyVaultCollOpts specifies options used to communicate with the key vault collection
40-
var keyVaultCollOpts = options.Collection().SetReadConcern(readconcern.Majority()).
41-
SetWriteConcern(writeconcern.New(writeconcern.WMajority()))
37+
var (
38+
// keyVaultCollOpts specifies options used to communicate with the key vault collection
39+
keyVaultCollOpts = options.Collection().SetReadConcern(readconcern.Majority()).
40+
SetWriteConcern(writeconcern.New(writeconcern.WMajority()))
41+
42+
endSessionsBatchSize = 10000
43+
)
4244

4345
// Client is a handle representing a pool of connections to a MongoDB deployment. It is safe for concurrent use by
4446
// multiple goroutines.
@@ -288,29 +290,27 @@ func (c *Client) endSessions(ctx context.Context) {
288290
return
289291
}
290292

291-
ids := c.sessionPool.IDSlice()
292-
idx, idArray := bsoncore.AppendArrayStart(nil)
293-
for i, id := range ids {
294-
idArray = bsoncore.AppendDocumentElement(idArray, strconv.Itoa(i), id)
295-
}
296-
idArray, _ = bsoncore.AppendArrayEnd(idArray, idx)
297-
298-
op := operation.NewEndSessions(idArray).ClusterClock(c.clock).Deployment(c.deployment).
293+
sessionIDs := c.sessionPool.IDSlice()
294+
op := operation.NewEndSessions(nil).ClusterClock(c.clock).Deployment(c.deployment).
299295
ServerSelector(description.ReadPrefSelector(readpref.PrimaryPreferred())).CommandMonitor(c.monitor).
300296
Database("admin").Crypt(c.crypt)
301297

302-
idx, idArray = bsoncore.AppendArrayStart(nil)
303-
totalNumIDs := len(ids)
298+
totalNumIDs := len(sessionIDs)
299+
var currentBatch []bsoncore.Document
304300
for i := 0; i < totalNumIDs; i++ {
305-
idArray = bsoncore.AppendDocumentElement(idArray, strconv.Itoa(i), ids[i])
306-
if ((i+1)%batchSize) == 0 || i == totalNumIDs-1 {
307-
idArray, _ = bsoncore.AppendArrayEnd(idArray, idx)
308-
_ = op.SessionIDs(idArray).Execute(ctx)
309-
idArray = idArray[:0]
310-
idx = 0
301+
currentBatch = append(currentBatch, sessionIDs[i])
302+
303+
// If we are at the end of a batch or the end of the overall IDs array, execute the operation.
304+
if ((i+1)%endSessionsBatchSize) == 0 || i == totalNumIDs-1 {
305+
// Ignore all errors when ending sessions.
306+
_, marshalVal, err := bson.MarshalValue(currentBatch)
307+
if err == nil {
308+
_ = op.SessionIDs(marshalVal).Execute(ctx)
309+
}
310+
311+
currentBatch = currentBatch[:0]
311312
}
312313
}
313-
314314
}
315315

316316
func (c *Client) configure(opts *options.ClientOptions) error {

mongo/client_test.go

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,13 @@ package mongo
99
import (
1010
"context"
1111
"errors"
12+
"math"
1213
"testing"
1314
"time"
1415

1516
"go.mongodb.org/mongo-driver/bson"
17+
"go.mongodb.org/mongo-driver/event"
18+
"go.mongodb.org/mongo-driver/internal/testutil"
1619
"go.mongodb.org/mongo-driver/internal/testutil/assert"
1720
"go.mongodb.org/mongo-driver/mongo/options"
1821
"go.mongodb.org/mongo-driver/mongo/readconcern"
@@ -257,4 +260,87 @@ func TestClient(t *testing.T) {
257260
assert.Equal(t, uri, got, "expected GetURI to return %v, got %v", uri, got)
258261
})
259262
})
263+
t.Run("endSessions", func(t *testing.T) {
264+
cs := testutil.ConnString(t)
265+
originalBatchSize := endSessionsBatchSize
266+
endSessionsBatchSize = 2
267+
defer func() {
268+
endSessionsBatchSize = originalBatchSize
269+
}()
270+
271+
testCases := []struct {
272+
name string
273+
numSessions int
274+
eventBatchSizes []int
275+
}{
276+
{"number of sessions divides evenly", endSessionsBatchSize * 2, []int{endSessionsBatchSize, endSessionsBatchSize}},
277+
{"number of sessions does not divide evenly", endSessionsBatchSize + 1, []int{endSessionsBatchSize, 1}},
278+
}
279+
for _, tc := range testCases {
280+
t.Run(tc.name, func(t *testing.T) {
281+
// Setup a client and skip the test based on server version.
282+
var started []*event.CommandStartedEvent
283+
var failureReasons []string
284+
cmdMonitor := &event.CommandMonitor{
285+
Started: func(_ context.Context, evt *event.CommandStartedEvent) {
286+
if evt.CommandName == "endSessions" {
287+
started = append(started, evt)
288+
}
289+
},
290+
Failed: func(_ context.Context, evt *event.CommandFailedEvent) {
291+
if evt.CommandName == "endSessions" {
292+
failureReasons = append(failureReasons, evt.Failure)
293+
}
294+
},
295+
}
296+
clientOpts := options.Client().ApplyURI(cs.Original).SetReadPreference(readpref.Primary()).
297+
SetWriteConcern(writeconcern.New(writeconcern.WMajority())).SetMonitor(cmdMonitor)
298+
client, err := Connect(bgCtx, clientOpts)
299+
assert.Nil(t, err, "Connect error: %v", err)
300+
defer func() {
301+
_ = client.Disconnect(bgCtx)
302+
}()
303+
304+
serverVersion, err := getServerVersion(client.Database("admin"))
305+
assert.Nil(t, err, "getServerVersion error: %v", err)
306+
if compareVersions(t, serverVersion, "3.6.0") < 1 {
307+
t.Skip("skipping server version < 3.6")
308+
}
309+
310+
coll := client.Database("foo").Collection("bar")
311+
defer func() {
312+
_ = coll.Drop(bgCtx)
313+
}()
314+
315+
// Do an application operation and create the number of sessions specified by the test.
316+
_, err = coll.CountDocuments(bgCtx, bson.D{})
317+
assert.Nil(t, err, "CountDocuments error: %v", err)
318+
var sessions []Session
319+
for i := 0; i < tc.numSessions; i++ {
320+
sess, err := client.StartSession()
321+
assert.Nil(t, err, "StartSession error at index %d: %v", i, err)
322+
sessions = append(sessions, sess)
323+
}
324+
for _, sess := range sessions {
325+
sess.EndSession(bgCtx)
326+
}
327+
328+
client.endSessions(bgCtx)
329+
divisionResult := float64(tc.numSessions) / float64(endSessionsBatchSize)
330+
numEventsExpected := int(math.Ceil(divisionResult))
331+
assert.Equal(t, len(started), numEventsExpected, "expected %d started events, got %d", numEventsExpected,
332+
len(started))
333+
assert.Equal(t, len(failureReasons), 0, "endSessions errors: %v", failureReasons)
334+
335+
for i := 0; i < numEventsExpected; i++ {
336+
sentArray := started[i].Command.Lookup("endSessions").Array()
337+
values, _ := sentArray.Values()
338+
expectedNumValues := tc.eventBatchSizes[i]
339+
assert.Equal(t, len(values), expectedNumValues,
340+
"batch size mismatch at index %d; expected %d sessions in batch, got %d", i, expectedNumValues,
341+
len(values))
342+
}
343+
})
344+
}
345+
})
260346
}

0 commit comments

Comments
 (0)