Skip to content

Commit 6a555e4

Browse files
author
Divjot Arora
authored
GODRIVER-1104 Add imperative sessions API (#364)
1 parent 370e317 commit 6a555e4

File tree

4 files changed

+129
-17
lines changed

4 files changed

+129
-17
lines changed

mongo/client.go

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -766,7 +766,7 @@ func (c *Client) ListDatabaseNames(ctx context.Context, filter interface{}, opts
766766
//
767767
// Any error returned by the fn callback will be returned without any modifications.
768768
func WithSession(ctx context.Context, sess Session, fn func(SessionContext) error) error {
769-
return fn(contextWithSession(ctx, sess))
769+
return fn(NewSessionContext(ctx, sess))
770770
}
771771

772772
// UseSession creates a new Session and uses it to create a new SessionContext, which is used to call the fn callback.
@@ -789,13 +789,7 @@ func (c *Client) UseSessionWithOptions(ctx context.Context, opts *options.Sessio
789789
}
790790

791791
defer defaultSess.EndSession(ctx)
792-
793-
sessCtx := sessionContext{
794-
Context: context.WithValue(ctx, sessionKey{}, defaultSess),
795-
Session: defaultSess,
796-
}
797-
798-
return fn(sessCtx)
792+
return fn(NewSessionContext(ctx, defaultSess))
799793
}
800794

801795
// Watch returns a change stream for all changes on the deployment. See

mongo/crud_examples_test.go

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -614,6 +614,41 @@ func ExampleClient_StartSession_withTransaction() {
614614
fmt.Printf("result: %v\n", result)
615615
}
616616

617+
func ExampleNewSessionContext() {
618+
var client *mongo.Client
619+
620+
// Create a new Session and SessionContext.
621+
sess, err := client.StartSession()
622+
if err != nil {
623+
panic(err)
624+
}
625+
defer sess.EndSession(context.TODO())
626+
sessCtx := mongo.NewSessionContext(context.TODO(), sess)
627+
628+
// Start a transaction and sessCtx as the Context parameter to InsertOne and FindOne so both operations will be
629+
// run in the transaction.
630+
if err = sess.StartTransaction(); err != nil {
631+
panic(err)
632+
}
633+
634+
coll := client.Database("db").Collection("coll")
635+
res, err := coll.InsertOne(sessCtx, bson.D{{"x", 1}})
636+
if err != nil {
637+
panic(err)
638+
}
639+
640+
var result bson.M
641+
if err = coll.FindOne(sessCtx, bson.D{{"_id", res.InsertedID}}).Decode(&result); err != nil {
642+
panic(err)
643+
}
644+
fmt.Printf("result: %v\n", result)
645+
646+
// Commit the transaction so the inserted document will be stored.
647+
if err = sess.CommitTransaction(sessCtx); err != nil {
648+
panic(err)
649+
}
650+
}
651+
617652
// Cursor examples
618653

619654
func ExampleCursor_All() {

mongo/integration/sessions_test.go

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,68 @@ func TestSessions(t *testing.T) {
254254
})
255255
}
256256
})
257+
258+
mt.Run("imperative API", func(mt *mtest.T) {
259+
mt.Run("round trip Session object", func(mt *mtest.T) {
260+
// Rountrip a Session object through NewSessionContext/ContextFromSession and assert that it is correctly
261+
// stored/retrieved.
262+
263+
sess, err := mt.Client.StartSession()
264+
assert.Nil(mt, err, "StartSession error: %v", err)
265+
defer sess.EndSession(mtest.Background)
266+
267+
sessCtx := mongo.NewSessionContext(mtest.Background, sess)
268+
assert.Equal(mt, sess.ID(), sessCtx.ID(), "expected Session ID %v, got %v", sess.ID(), sessCtx.ID())
269+
270+
gotSess := mongo.SessionFromContext(sessCtx)
271+
assert.NotNil(mt, gotSess, "expected SessionFromContext to return non-nil value, got nil")
272+
assert.Equal(mt, sess.ID(), gotSess.ID(), "expected Session ID %v, got %v", sess.ID(), gotSess.ID())
273+
})
274+
275+
txnOpts := mtest.NewOptions().RunOn(
276+
mtest.RunOnBlock{Topology: []mtest.TopologyKind{mtest.ReplicaSet}, MinServerVersion: "4.0"},
277+
mtest.RunOnBlock{Topology: []mtest.TopologyKind{mtest.Sharded}, MinServerVersion: "4.2"},
278+
)
279+
mt.RunOpts("run transaction", txnOpts, func(mt *mtest.T) {
280+
// Test that the imperative sessions API can be used to run a transaction.
281+
282+
createSessionContext := func(mt *mtest.T) mongo.SessionContext {
283+
sess, err := mt.Client.StartSession()
284+
assert.Nil(mt, err, "StartSession error: %v", err)
285+
286+
return mongo.NewSessionContext(mtest.Background, sess)
287+
}
288+
289+
sessCtx := createSessionContext(mt)
290+
sess := mongo.SessionFromContext(sessCtx)
291+
assert.NotNil(mt, sess, "expected SessionFromContext to return non-nil value, got nil")
292+
defer sess.EndSession(mtest.Background)
293+
294+
err := sess.StartTransaction()
295+
assert.Nil(mt, err, "StartTransaction error: %v", err)
296+
297+
numDocs := 2
298+
for i := 0; i < numDocs; i++ {
299+
_, err = mt.Coll.InsertOne(sessCtx, bson.D{{"x", 1}})
300+
assert.Nil(mt, err, "InsertOne error at index %d: %v", i, err)
301+
}
302+
303+
// Assert that the collection count is 0 before committing and numDocs after. This tests that the InsertOne
304+
// calls were actually executed in the transaction because the pre-commit count does not include them.
305+
assertCollectionCount(mt, 0)
306+
err = sess.CommitTransaction(sessCtx)
307+
assert.Nil(mt, err, "CommitTransaction error: %v", err)
308+
assertCollectionCount(mt, int64(numDocs))
309+
})
310+
})
311+
}
312+
313+
func assertCollectionCount(mt *mtest.T, expectedCount int64) {
314+
mt.Helper()
315+
316+
count, err := mt.Coll.CountDocuments(mtest.Background, bson.D{})
317+
assert.Nil(mt, err, "CountDocuments error: %v", err)
318+
assert.Equal(mt, expectedCount, count, "expected CountDocuments result %v, got %v", expectedCount, count)
257319
}
258320

259321
type sessionFunction struct {

mongo/session.go

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@ var withTransactionTimeout = 120 * time.Second
3030
// SessionContext combines the context.Context and mongo.Session interfaces. It should be used as the Context arguments
3131
// to operations that should be executed in a session. This type is not goroutine safe and must not be used concurrently
3232
// by multiple goroutines.
33+
//
34+
// There are two ways to create a SessionContext and use it in a session/transaction. The first is to use one of the
35+
// callback-based functions such as WithSession and UseSession. These functions create a SessionContext and pass it to
36+
// the provided callback. The other is to use NewSessionContext to explicitly create a SessionContext.
3337
type SessionContext interface {
3438
context.Context
3539
Session
@@ -43,6 +47,31 @@ type sessionContext struct {
4347
type sessionKey struct {
4448
}
4549

50+
// NewSessionContext creates a new SessionContext associated with the given Context and Session parameters.
51+
func NewSessionContext(ctx context.Context, sess Session) SessionContext {
52+
return &sessionContext{
53+
Context: context.WithValue(ctx, sessionKey{}, sess),
54+
Session: sess,
55+
}
56+
}
57+
58+
// SessionFromContext extracts the mongo.Session object stored in a Context. This can be used on a SessionContext that
59+
// was created implicitly through one of the callback-based session APIs or explicitly by calling NewSessionContext. If
60+
// there is no Session stored in the provided Context, nil is returned.
61+
func SessionFromContext(ctx context.Context) Session {
62+
val := ctx.Value(sessionKey{})
63+
if val == nil {
64+
return nil
65+
}
66+
67+
sess, ok := val.(Session)
68+
if !ok {
69+
return nil
70+
}
71+
72+
return sess
73+
}
74+
4675
// Session is an interface that represents a MongoDB logical session. Sessions can be used to enable causal consistency
4776
// for a group of operations or to execute operations in an ACID transaction. A new Session can be created from a Client
4877
// instance. A Session created from a Client must only be used to execute operations using that Client or a Database or
@@ -145,7 +174,7 @@ func (s *sessionImpl) WithTransaction(ctx context.Context, fn func(sessCtx Sessi
145174
return nil, err
146175
}
147176

148-
res, err := fn(contextWithSession(ctx, s))
177+
res, err := fn(NewSessionContext(ctx, s))
149178
if err != nil {
150179
if s.clientSession.TransactionRunning() {
151180
_ = s.AbortTransaction(ctx)
@@ -322,11 +351,3 @@ func sessionFromContext(ctx context.Context) *session.Client {
322351

323352
return nil
324353
}
325-
326-
// contextWithSession creates a new SessionContext associated with the given Context and Session parameters.
327-
func contextWithSession(ctx context.Context, sess Session) SessionContext {
328-
return &sessionContext{
329-
Context: context.WithValue(ctx, sessionKey{}, sess),
330-
Session: sess,
331-
}
332-
}

0 commit comments

Comments
 (0)