Skip to content

Commit ed18ae6

Browse files
author
Divjot Arora
committed
GODRIVER-1535 Fix session IDs batching in Disconnect
1 parent c0be473 commit ed18ae6

File tree

2 files changed

+109
-24
lines changed

2 files changed

+109
-24
lines changed

mongo/client.go

Lines changed: 23 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ import (
1010
"context"
1111
"crypto/tls"
1212
"errors"
13-
"strconv"
1413
"strings"
1514
"time"
1615

@@ -21,6 +20,7 @@ import (
2120
"go.mongodb.org/mongo-driver/mongo/readconcern"
2221
"go.mongodb.org/mongo-driver/mongo/readpref"
2322
"go.mongodb.org/mongo-driver/mongo/writeconcern"
23+
"go.mongodb.org/mongo-driver/x/bsonx"
2424
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
2525
"go.mongodb.org/mongo-driver/x/mongo/driver"
2626
"go.mongodb.org/mongo-driver/x/mongo/driver/auth"
@@ -33,11 +33,14 @@ import (
3333
)
3434

3535
const defaultLocalThreshold = 15 * time.Millisecond
36-
const batchSize = 10000
3736

38-
// keyVaultCollOpts specifies options used to communicate with the key vault collection
39-
var keyVaultCollOpts = options.Collection().SetReadConcern(readconcern.Majority()).
40-
SetWriteConcern(writeconcern.New(writeconcern.WMajority()))
37+
var (
38+
// keyVaultCollOpts specifies options used to communicate with the key vault collection
39+
keyVaultCollOpts = options.Collection().SetReadConcern(readconcern.Majority()).
40+
SetWriteConcern(writeconcern.New(writeconcern.WMajority()))
41+
42+
endSessionsBatchSize = 10000
43+
)
4144

4245
// Client is a handle representing a pool of connections to a MongoDB deployment. It is safe for concurrent use by
4346
// multiple goroutines.
@@ -287,31 +290,27 @@ func (c *Client) endSessions(ctx context.Context) {
287290
return
288291
}
289292

290-
ids := c.sessionPool.IDSlice()
291-
idx, idArray := bsoncore.AppendArrayStart(nil)
292-
for i, id := range ids {
293-
idDoc, _ := id.MarshalBSON()
294-
idArray = bsoncore.AppendDocumentElement(idArray, strconv.Itoa(i), idDoc)
295-
}
296-
idArray, _ = bsoncore.AppendArrayEnd(idArray, idx)
297-
298-
op := operation.NewEndSessions(idArray).ClusterClock(c.clock).Deployment(c.deployment).
293+
sessionIDs := c.sessionPool.IDSlice()
294+
op := operation.NewEndSessions(nil).ClusterClock(c.clock).Deployment(c.deployment).
299295
ServerSelector(description.ReadPrefSelector(readpref.PrimaryPreferred())).CommandMonitor(c.monitor).
300296
Database("admin").Crypt(c.crypt)
301297

302-
idx, idArray = bsoncore.AppendArrayStart(nil)
303-
totalNumIDs := len(ids)
298+
totalNumIDs := len(sessionIDs)
299+
var currentBatch []bsonx.Doc
304300
for i := 0; i < totalNumIDs; i++ {
305-
idDoc, _ := ids[i].MarshalBSON()
306-
idArray = bsoncore.AppendDocumentElement(idArray, strconv.Itoa(i), idDoc)
307-
if ((i+1)%batchSize) == 0 || i == totalNumIDs-1 {
308-
idArray, _ = bsoncore.AppendArrayEnd(idArray, idx)
309-
_ = op.SessionIDs(idArray).Execute(ctx)
310-
idArray = idArray[:0]
311-
idx = 0
301+
currentBatch = append(currentBatch, sessionIDs[i])
302+
303+
// If we are at the end of a batch or the end of the overall IDs array, execute the operation.
304+
if ((i+1)%endSessionsBatchSize) == 0 || i == totalNumIDs-1 {
305+
// Ignore all errors when ending sessions.
306+
_, marshalVal, err := bson.MarshalValue(currentBatch)
307+
if err == nil {
308+
_ = op.SessionIDs(marshalVal).Execute(ctx)
309+
}
310+
311+
currentBatch = currentBatch[:0]
312312
}
313313
}
314-
315314
}
316315

317316
func (c *Client) configure(opts *options.ClientOptions) error {

mongo/client_test.go

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,13 @@ package mongo
99
import (
1010
"context"
1111
"errors"
12+
"math"
1213
"testing"
1314
"time"
1415

1516
"go.mongodb.org/mongo-driver/bson"
17+
"go.mongodb.org/mongo-driver/event"
18+
"go.mongodb.org/mongo-driver/internal/testutil"
1619
"go.mongodb.org/mongo-driver/internal/testutil/assert"
1720
"go.mongodb.org/mongo-driver/mongo/options"
1821
"go.mongodb.org/mongo-driver/mongo/readconcern"
@@ -257,4 +260,87 @@ func TestClient(t *testing.T) {
257260
assert.Equal(t, uri, got, "expected GetURI to return %v, got %v", uri, got)
258261
})
259262
})
263+
t.Run("endSessions", func(t *testing.T) {
264+
cs := testutil.ConnString(t)
265+
originalBatchSize := endSessionsBatchSize
266+
endSessionsBatchSize = 2
267+
defer func() {
268+
endSessionsBatchSize = originalBatchSize
269+
}()
270+
271+
testCases := []struct {
272+
name string
273+
numSessions int
274+
eventBatchSizes []int
275+
}{
276+
{"number of sessions divides evenly", endSessionsBatchSize * 2, []int{endSessionsBatchSize, endSessionsBatchSize}},
277+
{"number of sessions does not divide evenly", endSessionsBatchSize + 1, []int{endSessionsBatchSize, 1}},
278+
}
279+
for _, tc := range testCases {
280+
t.Run(tc.name, func(t *testing.T) {
281+
// Setup a client and skip the test based on server version.
282+
var started []*event.CommandStartedEvent
283+
var failureReasons []string
284+
cmdMonitor := &event.CommandMonitor{
285+
Started: func(_ context.Context, evt *event.CommandStartedEvent) {
286+
if evt.CommandName == "endSessions" {
287+
started = append(started, evt)
288+
}
289+
},
290+
Failed: func(_ context.Context, evt *event.CommandFailedEvent) {
291+
if evt.CommandName == "endSessions" {
292+
failureReasons = append(failureReasons, evt.Failure)
293+
}
294+
},
295+
}
296+
clientOpts := options.Client().ApplyURI(cs.Original).SetReadPreference(readpref.Primary()).
297+
SetWriteConcern(writeconcern.New(writeconcern.WMajority())).SetMonitor(cmdMonitor)
298+
client, err := Connect(bgCtx, clientOpts)
299+
assert.Nil(t, err, "Connect error: %v", err)
300+
defer func() {
301+
_ = client.Disconnect(bgCtx)
302+
}()
303+
304+
serverVersion, err := getServerVersion(client.Database("admin"))
305+
assert.Nil(t, err, "getServerVersion error: %v", err)
306+
if compareVersions(t, serverVersion, "3.6.0") < 1 {
307+
t.Skip("skipping server version < 3.6")
308+
}
309+
310+
coll := client.Database("foo").Collection("bar")
311+
defer func() {
312+
_ = coll.Drop(bgCtx)
313+
}()
314+
315+
// Do an application operation and create the number of sessions specified by the test.
316+
_, err = coll.CountDocuments(bgCtx, bson.D{})
317+
assert.Nil(t, err, "CountDocuments error: %v", err)
318+
var sessions []Session
319+
for i := 0; i < tc.numSessions; i++ {
320+
sess, err := client.StartSession()
321+
assert.Nil(t, err, "StartSession error at index %d: %v", i, err)
322+
sessions = append(sessions, sess)
323+
}
324+
for _, sess := range sessions {
325+
sess.EndSession(bgCtx)
326+
}
327+
328+
client.endSessions(bgCtx)
329+
divisionResult := float64(tc.numSessions) / float64(endSessionsBatchSize)
330+
numEventsExpected := int(math.Ceil(divisionResult))
331+
assert.Equal(t, len(started), numEventsExpected, "expected %d started events, got %d", numEventsExpected,
332+
len(started))
333+
assert.Equal(t, len(failureReasons), 0, "endSessions errors: %v", failureReasons)
334+
335+
for i := 0; i < numEventsExpected; i++ {
336+
sentArray := started[i].Command.Lookup("endSessions").Array()
337+
values, _ := sentArray.Values()
338+
expectedNumValues := tc.eventBatchSizes[i]
339+
assert.Equal(t, len(values), expectedNumValues,
340+
"batch size mismatch at index %d; expected %d sessions in batch, got %d", i, expectedNumValues,
341+
len(values))
342+
}
343+
})
344+
}
345+
})
260346
}

0 commit comments

Comments
 (0)