Skip to content

GODRIVER-672 Change session IDs to be stored as bson.Raw #339

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions mongo/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -291,8 +291,7 @@ func (c *Client) endSessions(ctx context.Context) {
ids := c.sessionPool.IDSlice()
idx, idArray := bsoncore.AppendArrayStart(nil)
for i, id := range ids {
idDoc, _ := id.MarshalBSON()
idArray = bsoncore.AppendDocumentElement(idArray, strconv.Itoa(i), idDoc)
idArray = bsoncore.AppendDocumentElement(idArray, strconv.Itoa(i), id)
}
idArray, _ = bsoncore.AppendArrayEnd(idArray, idx)

Expand All @@ -303,8 +302,7 @@ func (c *Client) endSessions(ctx context.Context) {
idx, idArray = bsoncore.AppendArrayStart(nil)
totalNumIDs := len(ids)
for i := 0; i < totalNumIDs; i++ {
idDoc, _ := ids[i].MarshalBSON()
idArray = bsoncore.AppendDocumentElement(idArray, strconv.Itoa(i), idDoc)
idArray = bsoncore.AppendDocumentElement(idArray, strconv.Itoa(i), ids[i])
if ((i+1)%batchSize) == 0 || i == totalNumIDs-1 {
idArray, _ = bsoncore.AppendArrayEnd(idArray, idx)
_ = op.SessionIDs(idArray).Execute(ctx)
Expand Down
12 changes: 4 additions & 8 deletions mongo/integration/cmd_monitoring_helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ import (
"go.mongodb.org/mongo-driver/bson/bsontype"
"go.mongodb.org/mongo-driver/internal/testutil/assert"
"go.mongodb.org/mongo-driver/mongo/integration/mtest"
"go.mongodb.org/mongo-driver/x/bsonx"
)

// Helper functions to compare BSON values and command monitoring expectations.
Expand Down Expand Up @@ -186,7 +185,7 @@ func compareDocs(mt *mtest.T, expected, actual bson.Raw) error {
return compareDocsHelper(mt, expected, actual, "")
}

func checkExpectations(mt *mtest.T, expectations []*expectation, id0, id1 bsonx.Doc) {
func checkExpectations(mt *mtest.T, expectations []*expectation, id0, id1 bson.Raw) {
mt.Helper()

for idx, expectation := range expectations {
Expand All @@ -206,7 +205,7 @@ func checkExpectations(mt *mtest.T, expectations []*expectation, id0, id1 bsonx.
}
}

func compareStartedEvent(mt *mtest.T, expectation *expectation, id0, id1 bsonx.Doc) error {
func compareStartedEvent(mt *mtest.T, expectation *expectation, id0, id1 bson.Raw) error {
mt.Helper()

expected := expectation.CommandStartedEvent
Expand Down Expand Up @@ -261,16 +260,13 @@ func compareStartedEvent(mt *mtest.T, expectation *expectation, id0, id1 bsonx.D

switch sessName {
case "session0":
expectedID, err = id0.MarshalBSON()
expectedID = id0
case "session1":
expectedID, err = id1.MarshalBSON()
expectedID = id1
default:
return fmt.Errorf("unrecognized session identifier in command document: %s", sessName)
}

if err != nil {
return fmt.Errorf("error getting expected session ID bytes for session name %s: %s", sessName, err)
}
if !bytes.Equal(expectedID, actualID) {
return fmt.Errorf("session ID mismatch for session %s; expected %s, got %s", sessName, expectedID,
actualID)
Expand Down
13 changes: 6 additions & 7 deletions mongo/integration/sessions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ import (
"go.mongodb.org/mongo-driver/mongo/integration/mtest"
"go.mongodb.org/mongo-driver/mongo/options"
"go.mongodb.org/mongo-driver/mongo/readpref"
"go.mongodb.org/mongo-driver/x/bsonx"
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
)

Expand All @@ -39,15 +38,15 @@ func TestSessionPool(t *testing.T) {
firstSess, err := mt.Client.StartSession()
assert.Nil(mt, err, "StartSession error: %v", err)
defer firstSess.EndSession(mtest.Background)
want := getSessionID(mt, bSess)
got := getSessionID(mt, firstSess)
want := bSess.ID()
got := firstSess.ID()
assert.True(mt, sessionIDsEqual(mt, want, got), "expected session ID %v, got %v", want, got)

secondSess, err := mt.Client.StartSession()
assert.Nil(mt, err, "StartSession error: %v", err)
defer secondSess.EndSession(mtest.Background)
want = getSessionID(mt, aSess)
got = getSessionID(mt, secondSess)
want = aSess.ID()
got = secondSess.ID()
assert.True(mt, sessionIDsEqual(mt, want, got), "expected session ID %v, got %v", want, got)
})
mt.Run("last use time updated", func(mt *mtest.T) {
Expand Down Expand Up @@ -127,7 +126,7 @@ func TestSessions(t *testing.T) {
mt.ClearEvents()

_ = sf.execute(mt, sess) // don't check error because we only care about lsid
_, wantID := getSessionID(mt, sess).Lookup("id").Binary()
_, wantID := sess.ID().Lookup("id").Binary()
gotID := extractSentSessionID(mt)
assert.True(mt, bytes.Equal(wantID, gotID), "expected session ID %v, got %v", wantID, gotID)

Expand Down Expand Up @@ -338,7 +337,7 @@ func createFunctionsSlice() []sessionFunction {
}
}

func sessionIDsEqual(mt *mtest.T, id1, id2 bsonx.Doc) bool {
func sessionIDsEqual(mt *mtest.T, id1, id2 bson.Raw) bool {
first, err := id1.LookupErr("id")
assert.Nil(mt, err, "id not found in document %v", id1)
second, err := id2.LookupErr("id")
Expand Down
10 changes: 1 addition & 9 deletions mongo/integration/unified_spec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ import (
"go.mongodb.org/mongo-driver/mongo/options"
"go.mongodb.org/mongo-driver/mongo/readconcern"
"go.mongodb.org/mongo-driver/mongo/readpref"
"go.mongodb.org/mongo-driver/x/bsonx"
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
)

Expand Down Expand Up @@ -184,13 +183,6 @@ func runSpecTestFile(t *testing.T, specDir, fileName string) {
}
}

func getSessionID(mt *mtest.T, sess mongo.Session) bsonx.Doc {
mt.Helper()
xsess, ok := sess.(mongo.XSession)
assert.True(mt, ok, "expected %T to implement mongo.XSession", sess)
return xsess.ID()
}

func runSpecTestCase(mt *mtest.T, test *testCase, testFile testFile) {
testClientOpts := createClientOptions(mt, test.ClientOptions)
testClientOpts.SetHeartbeatInterval(50 * time.Millisecond)
Expand Down Expand Up @@ -271,7 +263,7 @@ func runSpecTestCase(mt *mtest.T, test *testCase, testFile testFile) {
sess1.EndSession(mtest.Background)
mt.ClearFailPoints()

checkExpectations(mt, test.Expectations, getSessionID(mt, sess0), getSessionID(mt, sess1))
checkExpectations(mt, test.Expectations, sess0.ID(), sess1.ID())

if test.Outcome != nil {
verifyTestOutcome(mt, test.Outcome.Collection)
Expand Down
20 changes: 12 additions & 8 deletions mongo/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ import (
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/mongo/options"
"go.mongodb.org/mongo-driver/x/bsonx"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/x/mongo/driver"
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
Expand Down Expand Up @@ -69,23 +68,29 @@ type sessionKey struct {
// errors or any operation errors that occur after the timeout expires will be returned without retrying. For a usage
// example, see the Client.StartSession method documentation.
//
// ClusterTime, OperationTime, and Client return the session's current operation time, the session's current cluster
// time, and the Client associated with the session, respectively.
// ClusterTime, OperationTime, Client, and ID return the session's current operation time, the session's current cluster
// time, the Client associated with the session, and the ID document associated with the session, respectively. The ID
// document for a session is in the form {"id": <BSON binary value>}.
//
// EndSession method should abort any existing transactions and close the session.
//
// AdvanceClusterTime and AdvanceOperationTime are for internal use only and must not be called.
type Session interface {
// Functions to modify session state.
StartTransaction(...*options.TransactionOptions) error
AbortTransaction(context.Context) error
CommitTransaction(context.Context) error
WithTransaction(ctx context.Context, fn func(sessCtx SessionContext) (interface{}, error),
opts ...*options.TransactionOptions) (interface{}, error)
EndSession(context.Context)

// Functions to retrieve session properties.
ClusterTime() bson.Raw
OperationTime() *primitive.Timestamp
Client() *Client
EndSession(context.Context)
ID() bson.Raw

// Functions to modify mutable session properties.
AdvanceClusterTime(bson.Raw) error
AdvanceOperationTime(*primitive.Timestamp) error

Expand All @@ -96,7 +101,6 @@ type Session interface {
// stability guarantee. It may be removed at any time.
type XSession interface {
ClientSession() *session.Client
ID() bsonx.Doc
}

// sessionImpl represents a set of sequential operations executed by an application that are related in some way.
Expand All @@ -115,9 +119,9 @@ func (s *sessionImpl) ClientSession() *session.Client {
return s.clientSession
}

// ID implements the XSession interface.
func (s *sessionImpl) ID() bsonx.Doc {
return s.clientSession.SessionID
// ID implements the Session interface.
func (s *sessionImpl) ID() bson.Raw {
return bson.Raw(s.clientSession.SessionID)
}

// EndSession implements the Session interface.
Expand Down
3 changes: 1 addition & 2 deletions x/mongo/driver/operation.go
Original file line number Diff line number Diff line change
Expand Up @@ -926,8 +926,7 @@ func (op Operation) addSession(dst []byte, desc description.SelectedServer) ([]b
if err := client.UpdateUseTime(); err != nil {
return dst, err
}
lsid, _ := client.SessionID.MarshalBSON()
dst = bsoncore.AppendDocumentElement(dst, "lsid", lsid)
dst = bsoncore.AppendDocumentElement(dst, "lsid", client.SessionID)

var addedTxnNumber bool
if op.Type == Write && client.RetryWrite {
Expand Down
8 changes: 5 additions & 3 deletions x/mongo/driver/session/server_session.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,15 @@ import (

"crypto/rand"

"go.mongodb.org/mongo-driver/x/bsonx"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/x/mongo/driver/uuid"
)

var rander = rand.Reader

// Server is an open session with the server.
type Server struct {
SessionID bsonx.Doc
SessionID bsoncore.Document
TxnNumber int64
LastUsed time.Time
Dirty bool
Expand Down Expand Up @@ -47,7 +47,9 @@ func newServerSession() (*Server, error) {
return nil, err
}

idDoc := bsonx.Doc{{"id", bsonx.Binary(UUIDSubtype, id[:])}}
idx, idDoc := bsoncore.AppendDocumentStart(nil)
idDoc = bsoncore.AppendBinaryElement(idDoc, "id", UUIDSubtype, id[:])
idDoc, _ = bsoncore.AppendDocumentEnd(idDoc, idx)

return &Server{
SessionID: idDoc,
Expand Down
6 changes: 3 additions & 3 deletions x/mongo/driver/session/session_pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ package session
import (
"sync"

"go.mongodb.org/mongo-driver/x/bsonx"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
)

Expand Down Expand Up @@ -149,11 +149,11 @@ func (p *Pool) ReturnSession(ss *Server) {
}

// IDSlice returns a slice of session IDs for each session in the pool
func (p *Pool) IDSlice() []bsonx.Doc {
func (p *Pool) IDSlice() []bsoncore.Document {
p.mutex.Lock()
defer p.mutex.Unlock()

ids := []bsonx.Doc{}
var ids []bsoncore.Document
for node := p.head; node != nil; node = node.next {
ids = append(ids, node.SessionID)
}
Expand Down
33 changes: 15 additions & 18 deletions x/mongo/driver/session/session_pool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@
package session

import (
"bytes"
"testing"

"go.mongodb.org/mongo-driver/internal/testutil/helpers"
"go.mongodb.org/mongo-driver/internal/testutil/assert"
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
)

Expand All @@ -20,28 +21,25 @@ func TestSessionPool(t *testing.T) {
p.timeout = 30 // Set to some arbitrarily high number greater than 1 minute.

first, err := p.GetSession()
testhelpers.RequireNil(t, err, "error getting session %s", err)
assert.Nil(t, err, "GetSession error: %v", err)
firstID := first.SessionID

second, err := p.GetSession()
testhelpers.RequireNil(t, err, "error getting session %s", err)
assert.Nil(t, err, "GetSession error: %v", err)
secondID := second.SessionID

p.ReturnSession(first)
p.ReturnSession(second)

sess, err := p.GetSession()
testhelpers.RequireNil(t, err, "error getting session %s", err)
assert.Nil(t, err, "GetSession error: %v", err)
nextSess, err := p.GetSession()
testhelpers.RequireNil(t, err, "error getting session %s", err)
assert.Nil(t, err, "GetSession error: %v", err)

if !sess.SessionID.Equal(secondID) {
t.Errorf("first sesssion ID mismatch. got %s expected %s", sess.SessionID, secondID)
}

if !nextSess.SessionID.Equal(firstID) {
t.Errorf("second sesssion ID mismatch. got %s expected %s", nextSess.SessionID, firstID)
}
assert.True(t, bytes.Equal(sess.SessionID, secondID),
"first session ID mismatch; expected %s, got %s", secondID, sess.SessionID)
assert.True(t, bytes.Equal(nextSess.SessionID, firstID),
"second session ID mismatch; expected %s, got %s", firstID, nextSess.SessionID)
})

t.Run("TestExpiredRemoved", func(t *testing.T) {
Expand All @@ -51,21 +49,20 @@ func TestSessionPool(t *testing.T) {
p.timeout = 0

first, err := p.GetSession()
testhelpers.RequireNil(t, err, "error getting session %s", err)
assert.Nil(t, err, "GetSession error: %v", err)
firstID := first.SessionID

second, err := p.GetSession()
testhelpers.RequireNil(t, err, "error getting session %s", err)
assert.Nil(t, err, "GetSession error: %v", err)
secondID := second.SessionID

p.ReturnSession(first)
p.ReturnSession(second)

sess, err := p.GetSession()
testhelpers.RequireNil(t, err, "error getting session %s", err)
assert.Nil(t, err, "GetSession error: %v", err)

if sess.SessionID.Equal(firstID) || sess.SessionID.Equal(secondID) {
t.Errorf("Expired sessions not removed!")
}
assert.False(t, bytes.Equal(sess.SessionID, firstID), "first expired session was not removed")
assert.False(t, bytes.Equal(sess.SessionID, secondID), "second expired session was not removed")
})
}