Skip to content

Commit d080bd0

Browse files
author
Divjot Arora
authored
GODRIVER-672 Change session IDs to be stored as bson.Raw (#339)
1 parent 4a201aa commit d080bd0

File tree

9 files changed

+49
-62
lines changed

9 files changed

+49
-62
lines changed

mongo/client.go

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -291,8 +291,7 @@ func (c *Client) endSessions(ctx context.Context) {
291291
ids := c.sessionPool.IDSlice()
292292
idx, idArray := bsoncore.AppendArrayStart(nil)
293293
for i, id := range ids {
294-
idDoc, _ := id.MarshalBSON()
295-
idArray = bsoncore.AppendDocumentElement(idArray, strconv.Itoa(i), idDoc)
294+
idArray = bsoncore.AppendDocumentElement(idArray, strconv.Itoa(i), id)
296295
}
297296
idArray, _ = bsoncore.AppendArrayEnd(idArray, idx)
298297

@@ -303,8 +302,7 @@ func (c *Client) endSessions(ctx context.Context) {
303302
idx, idArray = bsoncore.AppendArrayStart(nil)
304303
totalNumIDs := len(ids)
305304
for i := 0; i < totalNumIDs; i++ {
306-
idDoc, _ := ids[i].MarshalBSON()
307-
idArray = bsoncore.AppendDocumentElement(idArray, strconv.Itoa(i), idDoc)
305+
idArray = bsoncore.AppendDocumentElement(idArray, strconv.Itoa(i), ids[i])
308306
if ((i+1)%batchSize) == 0 || i == totalNumIDs-1 {
309307
idArray, _ = bsoncore.AppendArrayEnd(idArray, idx)
310308
_ = op.SessionIDs(idArray).Execute(ctx)

mongo/integration/cmd_monitoring_helpers_test.go

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ import (
1515
"go.mongodb.org/mongo-driver/bson/bsontype"
1616
"go.mongodb.org/mongo-driver/internal/testutil/assert"
1717
"go.mongodb.org/mongo-driver/mongo/integration/mtest"
18-
"go.mongodb.org/mongo-driver/x/bsonx"
1918
)
2019

2120
// Helper functions to compare BSON values and command monitoring expectations.
@@ -186,7 +185,7 @@ func compareDocs(mt *mtest.T, expected, actual bson.Raw) error {
186185
return compareDocsHelper(mt, expected, actual, "")
187186
}
188187

189-
func checkExpectations(mt *mtest.T, expectations []*expectation, id0, id1 bsonx.Doc) {
188+
func checkExpectations(mt *mtest.T, expectations []*expectation, id0, id1 bson.Raw) {
190189
mt.Helper()
191190

192191
for idx, expectation := range expectations {
@@ -206,7 +205,7 @@ func checkExpectations(mt *mtest.T, expectations []*expectation, id0, id1 bsonx.
206205
}
207206
}
208207

209-
func compareStartedEvent(mt *mtest.T, expectation *expectation, id0, id1 bsonx.Doc) error {
208+
func compareStartedEvent(mt *mtest.T, expectation *expectation, id0, id1 bson.Raw) error {
210209
mt.Helper()
211210

212211
expected := expectation.CommandStartedEvent
@@ -261,16 +260,13 @@ func compareStartedEvent(mt *mtest.T, expectation *expectation, id0, id1 bsonx.D
261260

262261
switch sessName {
263262
case "session0":
264-
expectedID, err = id0.MarshalBSON()
263+
expectedID = id0
265264
case "session1":
266-
expectedID, err = id1.MarshalBSON()
265+
expectedID = id1
267266
default:
268267
return fmt.Errorf("unrecognized session identifier in command document: %s", sessName)
269268
}
270269

271-
if err != nil {
272-
return fmt.Errorf("error getting expected session ID bytes for session name %s: %s", sessName, err)
273-
}
274270
if !bytes.Equal(expectedID, actualID) {
275271
return fmt.Errorf("session ID mismatch for session %s; expected %s, got %s", sessName, expectedID,
276272
actualID)

mongo/integration/sessions_test.go

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ import (
1818
"go.mongodb.org/mongo-driver/mongo/integration/mtest"
1919
"go.mongodb.org/mongo-driver/mongo/options"
2020
"go.mongodb.org/mongo-driver/mongo/readpref"
21-
"go.mongodb.org/mongo-driver/x/bsonx"
2221
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
2322
)
2423

@@ -39,15 +38,15 @@ func TestSessionPool(t *testing.T) {
3938
firstSess, err := mt.Client.StartSession()
4039
assert.Nil(mt, err, "StartSession error: %v", err)
4140
defer firstSess.EndSession(mtest.Background)
42-
want := getSessionID(mt, bSess)
43-
got := getSessionID(mt, firstSess)
41+
want := bSess.ID()
42+
got := firstSess.ID()
4443
assert.True(mt, sessionIDsEqual(mt, want, got), "expected session ID %v, got %v", want, got)
4544

4645
secondSess, err := mt.Client.StartSession()
4746
assert.Nil(mt, err, "StartSession error: %v", err)
4847
defer secondSess.EndSession(mtest.Background)
49-
want = getSessionID(mt, aSess)
50-
got = getSessionID(mt, secondSess)
48+
want = aSess.ID()
49+
got = secondSess.ID()
5150
assert.True(mt, sessionIDsEqual(mt, want, got), "expected session ID %v, got %v", want, got)
5251
})
5352
mt.Run("last use time updated", func(mt *mtest.T) {
@@ -127,7 +126,7 @@ func TestSessions(t *testing.T) {
127126
mt.ClearEvents()
128127

129128
_ = sf.execute(mt, sess) // don't check error because we only care about lsid
130-
_, wantID := getSessionID(mt, sess).Lookup("id").Binary()
129+
_, wantID := sess.ID().Lookup("id").Binary()
131130
gotID := extractSentSessionID(mt)
132131
assert.True(mt, bytes.Equal(wantID, gotID), "expected session ID %v, got %v", wantID, gotID)
133132

@@ -338,7 +337,7 @@ func createFunctionsSlice() []sessionFunction {
338337
}
339338
}
340339

341-
func sessionIDsEqual(mt *mtest.T, id1, id2 bsonx.Doc) bool {
340+
func sessionIDsEqual(mt *mtest.T, id1, id2 bson.Raw) bool {
342341
first, err := id1.LookupErr("id")
343342
assert.Nil(mt, err, "id not found in document %v", id1)
344343
second, err := id2.LookupErr("id")

mongo/integration/unified_spec_test.go

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ import (
2424
"go.mongodb.org/mongo-driver/mongo/options"
2525
"go.mongodb.org/mongo-driver/mongo/readconcern"
2626
"go.mongodb.org/mongo-driver/mongo/readpref"
27-
"go.mongodb.org/mongo-driver/x/bsonx"
2827
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
2928
)
3029

@@ -184,13 +183,6 @@ func runSpecTestFile(t *testing.T, specDir, fileName string) {
184183
}
185184
}
186185

187-
func getSessionID(mt *mtest.T, sess mongo.Session) bsonx.Doc {
188-
mt.Helper()
189-
xsess, ok := sess.(mongo.XSession)
190-
assert.True(mt, ok, "expected %T to implement mongo.XSession", sess)
191-
return xsess.ID()
192-
}
193-
194186
func runSpecTestCase(mt *mtest.T, test *testCase, testFile testFile) {
195187
testClientOpts := createClientOptions(mt, test.ClientOptions)
196188
testClientOpts.SetHeartbeatInterval(50 * time.Millisecond)
@@ -271,7 +263,7 @@ func runSpecTestCase(mt *mtest.T, test *testCase, testFile testFile) {
271263
sess1.EndSession(mtest.Background)
272264
mt.ClearFailPoints()
273265

274-
checkExpectations(mt, test.Expectations, getSessionID(mt, sess0), getSessionID(mt, sess1))
266+
checkExpectations(mt, test.Expectations, sess0.ID(), sess1.ID())
275267

276268
if test.Outcome != nil {
277269
verifyTestOutcome(mt, test.Outcome.Collection)

mongo/session.go

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ import (
1414
"go.mongodb.org/mongo-driver/bson"
1515
"go.mongodb.org/mongo-driver/bson/primitive"
1616
"go.mongodb.org/mongo-driver/mongo/options"
17-
"go.mongodb.org/mongo-driver/x/bsonx"
1817
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
1918
"go.mongodb.org/mongo-driver/x/mongo/driver"
2019
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
@@ -69,23 +68,29 @@ type sessionKey struct {
6968
// errors or any operation errors that occur after the timeout expires will be returned without retrying. For a usage
7069
// example, see the Client.StartSession method documentation.
7170
//
72-
// ClusterTime, OperationTime, and Client return the session's current operation time, the session's current cluster
73-
// time, and the Client associated with the session, respectively.
71+
// ClusterTime, OperationTime, Client, and ID return the session's current operation time, the session's current cluster
72+
// time, the Client associated with the session, and the ID document associated with the session, respectively. The ID
73+
// document for a session is in the form {"id": <BSON binary value>}.
7474
//
7575
// EndSession method should abort any existing transactions and close the session.
7676
//
7777
// AdvanceClusterTime and AdvanceOperationTime are for internal use only and must not be called.
7878
type Session interface {
79+
// Functions to modify session state.
7980
StartTransaction(...*options.TransactionOptions) error
8081
AbortTransaction(context.Context) error
8182
CommitTransaction(context.Context) error
8283
WithTransaction(ctx context.Context, fn func(sessCtx SessionContext) (interface{}, error),
8384
opts ...*options.TransactionOptions) (interface{}, error)
85+
EndSession(context.Context)
86+
87+
// Functions to retrieve session properties.
8488
ClusterTime() bson.Raw
8589
OperationTime() *primitive.Timestamp
8690
Client() *Client
87-
EndSession(context.Context)
91+
ID() bson.Raw
8892

93+
// Functions to modify mutable session properties.
8994
AdvanceClusterTime(bson.Raw) error
9095
AdvanceOperationTime(*primitive.Timestamp) error
9196

@@ -96,7 +101,6 @@ type Session interface {
96101
// stability guarantee. It may be removed at any time.
97102
type XSession interface {
98103
ClientSession() *session.Client
99-
ID() bsonx.Doc
100104
}
101105

102106
// sessionImpl represents a set of sequential operations executed by an application that are related in some way.
@@ -115,9 +119,9 @@ func (s *sessionImpl) ClientSession() *session.Client {
115119
return s.clientSession
116120
}
117121

118-
// ID implements the XSession interface.
119-
func (s *sessionImpl) ID() bsonx.Doc {
120-
return s.clientSession.SessionID
122+
// ID implements the Session interface.
123+
func (s *sessionImpl) ID() bson.Raw {
124+
return bson.Raw(s.clientSession.SessionID)
121125
}
122126

123127
// EndSession implements the Session interface.

x/mongo/driver/operation.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -930,8 +930,7 @@ func (op Operation) addSession(dst []byte, desc description.SelectedServer) ([]b
930930
if err := client.UpdateUseTime(); err != nil {
931931
return dst, err
932932
}
933-
lsid, _ := client.SessionID.MarshalBSON()
934-
dst = bsoncore.AppendDocumentElement(dst, "lsid", lsid)
933+
dst = bsoncore.AppendDocumentElement(dst, "lsid", client.SessionID)
935934

936935
var addedTxnNumber bool
937936
if op.Type == Write && client.RetryWrite {

x/mongo/driver/session/server_session.go

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,15 @@ import (
1111

1212
"crypto/rand"
1313

14-
"go.mongodb.org/mongo-driver/x/bsonx"
14+
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
1515
"go.mongodb.org/mongo-driver/x/mongo/driver/uuid"
1616
)
1717

1818
var rander = rand.Reader
1919

2020
// Server is an open session with the server.
2121
type Server struct {
22-
SessionID bsonx.Doc
22+
SessionID bsoncore.Document
2323
TxnNumber int64
2424
LastUsed time.Time
2525
Dirty bool
@@ -47,7 +47,9 @@ func newServerSession() (*Server, error) {
4747
return nil, err
4848
}
4949

50-
idDoc := bsonx.Doc{{"id", bsonx.Binary(UUIDSubtype, id[:])}}
50+
idx, idDoc := bsoncore.AppendDocumentStart(nil)
51+
idDoc = bsoncore.AppendBinaryElement(idDoc, "id", UUIDSubtype, id[:])
52+
idDoc, _ = bsoncore.AppendDocumentEnd(idDoc, idx)
5153

5254
return &Server{
5355
SessionID: idDoc,

x/mongo/driver/session/session_pool.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ package session
99
import (
1010
"sync"
1111

12-
"go.mongodb.org/mongo-driver/x/bsonx"
12+
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
1313
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
1414
)
1515

@@ -149,11 +149,11 @@ func (p *Pool) ReturnSession(ss *Server) {
149149
}
150150

151151
// IDSlice returns a slice of session IDs for each session in the pool
152-
func (p *Pool) IDSlice() []bsonx.Doc {
152+
func (p *Pool) IDSlice() []bsoncore.Document {
153153
p.mutex.Lock()
154154
defer p.mutex.Unlock()
155155

156-
ids := []bsonx.Doc{}
156+
var ids []bsoncore.Document
157157
for node := p.head; node != nil; node = node.next {
158158
ids = append(ids, node.SessionID)
159159
}

x/mongo/driver/session/session_pool_test.go

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,10 @@
77
package session
88

99
import (
10+
"bytes"
1011
"testing"
1112

12-
"go.mongodb.org/mongo-driver/internal/testutil/helpers"
13+
"go.mongodb.org/mongo-driver/internal/testutil/assert"
1314
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
1415
)
1516

@@ -20,28 +21,25 @@ func TestSessionPool(t *testing.T) {
2021
p.timeout = 30 // Set to some arbitrarily high number greater than 1 minute.
2122

2223
first, err := p.GetSession()
23-
testhelpers.RequireNil(t, err, "error getting session %s", err)
24+
assert.Nil(t, err, "GetSession error: %v", err)
2425
firstID := first.SessionID
2526

2627
second, err := p.GetSession()
27-
testhelpers.RequireNil(t, err, "error getting session %s", err)
28+
assert.Nil(t, err, "GetSession error: %v", err)
2829
secondID := second.SessionID
2930

3031
p.ReturnSession(first)
3132
p.ReturnSession(second)
3233

3334
sess, err := p.GetSession()
34-
testhelpers.RequireNil(t, err, "error getting session %s", err)
35+
assert.Nil(t, err, "GetSession error: %v", err)
3536
nextSess, err := p.GetSession()
36-
testhelpers.RequireNil(t, err, "error getting session %s", err)
37+
assert.Nil(t, err, "GetSession error: %v", err)
3738

38-
if !sess.SessionID.Equal(secondID) {
39-
t.Errorf("first sesssion ID mismatch. got %s expected %s", sess.SessionID, secondID)
40-
}
41-
42-
if !nextSess.SessionID.Equal(firstID) {
43-
t.Errorf("second sesssion ID mismatch. got %s expected %s", nextSess.SessionID, firstID)
44-
}
39+
assert.True(t, bytes.Equal(sess.SessionID, secondID),
40+
"first session ID mismatch; expected %s, got %s", secondID, sess.SessionID)
41+
assert.True(t, bytes.Equal(nextSess.SessionID, firstID),
42+
"second session ID mismatch; expected %s, got %s", firstID, nextSess.SessionID)
4543
})
4644

4745
t.Run("TestExpiredRemoved", func(t *testing.T) {
@@ -51,21 +49,20 @@ func TestSessionPool(t *testing.T) {
5149
p.timeout = 0
5250

5351
first, err := p.GetSession()
54-
testhelpers.RequireNil(t, err, "error getting session %s", err)
52+
assert.Nil(t, err, "GetSession error: %v", err)
5553
firstID := first.SessionID
5654

5755
second, err := p.GetSession()
58-
testhelpers.RequireNil(t, err, "error getting session %s", err)
56+
assert.Nil(t, err, "GetSession error: %v", err)
5957
secondID := second.SessionID
6058

6159
p.ReturnSession(first)
6260
p.ReturnSession(second)
6361

6462
sess, err := p.GetSession()
65-
testhelpers.RequireNil(t, err, "error getting session %s", err)
63+
assert.Nil(t, err, "GetSession error: %v", err)
6664

67-
if sess.SessionID.Equal(firstID) || sess.SessionID.Equal(secondID) {
68-
t.Errorf("Expired sessions not removed!")
69-
}
65+
assert.False(t, bytes.Equal(sess.SessionID, firstID), "first expired session was not removed")
66+
assert.False(t, bytes.Equal(sess.SessionID, secondID), "second expired session was not removed")
7067
})
7168
}

0 commit comments

Comments
 (0)