Skip to content

Commit e18bcf0

Browse files
committed
Add Session Client Methods
GODRIVER-52 Change-Id: I4db2bbee2cb5105482c0a5d2182bd909af51c857
1 parent bf02b96 commit e18bcf0

File tree

8 files changed

+573
-102
lines changed

8 files changed

+573
-102
lines changed

mongo/change_stream.go

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import (
1414
"github.com/mongodb/mongo-go-driver/bson"
1515
"github.com/mongodb/mongo-go-driver/core/command"
1616
"github.com/mongodb/mongo-go-driver/core/option"
17+
"github.com/mongodb/mongo-go-driver/core/session"
1718
"github.com/mongodb/mongo-go-driver/mongo/changestreamopt"
1819
)
1920

@@ -26,6 +27,8 @@ type changeStream struct {
2627
options []option.ChangeStreamOptioner
2728
coll *Collection
2829
cursor Cursor
30+
session *session.Client
31+
clock *session.ClusterClock
2932
resumeToken *bson.Document
3033
err error
3134
}
@@ -41,7 +44,12 @@ func newChangeStream(ctx context.Context, coll *Collection, pipeline interface{}
4144
return nil, err
4245
}
4346

44-
csOpts, err := changestreamopt.BundleChangeStream(opts...).Unbundle(true)
47+
csOpts, sess, err := changestreamopt.BundleChangeStream(opts...).Unbundle(true)
48+
if err != nil {
49+
return nil, err
50+
}
51+
52+
err = coll.client.ValidSession(sess)
4553
if err != nil {
4654
return nil, err
4755
}
@@ -70,6 +78,8 @@ func newChangeStream(ctx context.Context, coll *Collection, pipeline interface{}
7078
options: csOpts,
7179
coll: coll,
7280
cursor: cursor,
81+
session: sess,
82+
clock: coll.client.clock,
7383
}
7484

7585
return cs, nil
@@ -152,8 +162,13 @@ func (cs *changeStream) Next(ctx context.Context) bool {
152162
aggCmd := command.Aggregate{
153163
NS: command.Namespace{DB: oldns.DB, Collection: oldns.Collection},
154164
Pipeline: cs.pipeline,
165+
Session: cs.session,
166+
Clock: cs.coll.client.clock,
155167
}
156-
cs.cursor, cs.err = aggCmd.RoundTrip(ctx, ss.Description(), ss, conn)
168+
169+
cur, err := aggCmd.RoundTrip(ctx, ss.Description(), ss, conn)
170+
cs.cursor = cur
171+
cs.err = err
157172

158173
if cs.err != nil {
159174
return false

mongo/client.go

Lines changed: 86 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -14,24 +14,28 @@ import (
1414
"github.com/mongodb/mongo-go-driver/core/connstring"
1515
"github.com/mongodb/mongo-go-driver/core/description"
1616
"github.com/mongodb/mongo-go-driver/core/dispatch"
17-
"github.com/mongodb/mongo-go-driver/core/option"
1817
"github.com/mongodb/mongo-go-driver/core/readconcern"
1918
"github.com/mongodb/mongo-go-driver/core/readpref"
19+
"github.com/mongodb/mongo-go-driver/core/session"
2020
"github.com/mongodb/mongo-go-driver/core/tag"
2121
"github.com/mongodb/mongo-go-driver/core/topology"
22+
"github.com/mongodb/mongo-go-driver/core/uuid"
2223
"github.com/mongodb/mongo-go-driver/core/writeconcern"
2324
"github.com/mongodb/mongo-go-driver/mongo/clientopt"
2425
"github.com/mongodb/mongo-go-driver/mongo/dbopt"
26+
"github.com/mongodb/mongo-go-driver/mongo/listdbopt"
2527
)
2628

2729
const defaultLocalThreshold = 15 * time.Millisecond
2830

2931
// Client performs operations on a given topology.
3032
type Client struct {
33+
id uuid.UUID
3134
topologyOptions []topology.Option
3235
topology *topology.Topology
3336
connString connstring.ConnString
3437
localThreshold time.Duration
38+
clock *session.ClusterClock
3539
readPreference *readpref.ReadPref
3640
readConcern *readconcern.ReadConcern
3741
writeConcern *writeconcern.WriteConcern
@@ -81,7 +85,13 @@ func NewClientFromConnString(cs connstring.ConnString) (*Client, error) {
8185
// Connect initializes the Client by starting background monitoring goroutines.
8286
// This method must be called before a Client can be used.
8387
func (c *Client) Connect(ctx context.Context) error {
84-
return c.topology.Connect(ctx)
88+
err := c.topology.Connect(ctx)
89+
if err != nil {
90+
return err
91+
}
92+
93+
return nil
94+
8595
}
8696

8797
// Disconnect closes sockets to the topology referenced by this Client. It will
@@ -93,9 +103,33 @@ func (c *Client) Connect(ctx context.Context) error {
93103
// or write operations. If this method returns with no errors, all connections
94104
// associated with this Client have been closed.
95105
func (c *Client) Disconnect(ctx context.Context) error {
106+
c.endSessions(ctx)
96107
return c.topology.Disconnect(ctx)
97108
}
98109

110+
// StartSession starts a new session.
111+
func (c *Client) StartSession() (*Session, error) {
112+
if c.topology.SessionPool == nil {
113+
return nil, topology.ErrTopologyClosed
114+
}
115+
116+
sess, err := session.NewClientSession(c.topology.SessionPool, c.id, session.Explicit)
117+
if err != nil {
118+
return nil, err
119+
}
120+
121+
return &Session{Client: sess}, nil
122+
}
123+
124+
func (c *Client) endSessions(ctx context.Context) {
125+
cmd := command.EndSessions{
126+
Clock: c.clock,
127+
SessionIDs: c.topology.SessionPool.IDSlice(),
128+
}
129+
130+
_, _ = dispatch.EndSessions(ctx, cmd, c.topology, description.ReadPrefSelector(readpref.PrimaryPreferred()))
131+
}
132+
99133
func newClient(cs connstring.ConnString, opts ...clientopt.Option) (*Client, error) {
100134
clientOpt, err := clientopt.BundleClient(opts...).Unbundle(cs)
101135
if err != nil {
@@ -108,15 +142,27 @@ func newClient(cs connstring.ConnString, opts ...clientopt.Option) (*Client, err
108142
localThreshold: defaultLocalThreshold,
109143
}
110144

145+
uuid, err := uuid.New()
146+
if err != nil {
147+
return nil, err
148+
}
149+
client.id = uuid
150+
111151
topts := append(
112152
client.topologyOptions,
113153
topology.WithConnString(func(connstring.ConnString) connstring.ConnString { return client.connString }),
154+
topology.WithServerOptions(func(opts ...topology.ServerOption) []topology.ServerOption {
155+
return append(opts, topology.WithClock(func(clock *session.ClusterClock) *session.ClusterClock {
156+
return client.clock
157+
}))
158+
}),
114159
)
115160
topo, err := topology.New(topts...)
116161
if err != nil {
117162
return nil, err
118163
}
119164
client.topology = topo
165+
client.clock = &session.ClusterClock{}
120166

121167
if client.readConcern == nil {
122168
client.readConcern = readConcernFromConnString(&client.connString)
@@ -212,6 +258,14 @@ func readPreferenceFromConnString(cs *connstring.ConnString) (*readpref.ReadPref
212258
return rp, nil
213259
}
214260

261+
// ValidSession returns an error if the session doesn't belong to the client
262+
func (c *Client) ValidSession(sess *session.Client) error {
263+
if sess != nil && !uuid.Equal(sess.ClientID, c.id) {
264+
return ErrWrongClient
265+
}
266+
return nil
267+
}
268+
215269
// Database returns a handle for a given database.
216270
func (c *Client) Database(name string, opts ...dbopt.Option) *Database {
217271
return newDatabase(c, name, opts...)
@@ -222,39 +276,51 @@ func (c *Client) ConnectionString() string {
222276
return c.connString.Original
223277
}
224278

225-
func (c *Client) listDatabasesHelper(ctx context.Context, filter interface{},
226-
nameOnly bool) (ListDatabasesResult, error) {
227-
228-
f, err := TransformDocument(filter)
279+
// ListDatabases returns a ListDatabasesResult.
280+
func (c *Client) ListDatabases(ctx context.Context, filter interface{}, opts ...listdbopt.ListDatabases) (ListDatabasesResult, error) {
281+
if ctx == nil {
282+
ctx = context.Background()
283+
}
284+
listDbOpts, sess, err := listdbopt.BundleListDatabases(opts...).Unbundle(true)
229285
if err != nil {
230286
return ListDatabasesResult{}, err
231287
}
232288

233-
opts := []option.ListDatabasesOptioner{}
289+
err = c.ValidSession(sess)
290+
if err != nil {
291+
return ListDatabasesResult{}, err
292+
}
234293

235-
if nameOnly {
236-
opts = append(opts, option.OptNameOnly(nameOnly))
294+
f, err := TransformDocument(filter)
295+
if err != nil {
296+
return ListDatabasesResult{}, err
237297
}
238298

239-
cmd := command.ListDatabases{Filter: f, Opts: opts}
299+
cmd := command.ListDatabases{
300+
Filter: f,
301+
Opts: listDbOpts,
302+
Session: sess,
303+
Clock: c.clock,
304+
}
240305

241-
// The spec indicates that we should not run the listDatabase command on a secondary in a
242-
// replica set.
243-
res, err := dispatch.ListDatabases(ctx, cmd, c.topology, description.ReadPrefSelector(readpref.Primary()))
306+
res, err := dispatch.ListDatabases(
307+
ctx, cmd,
308+
c.topology,
309+
description.ReadPrefSelector(readpref.Primary()),
310+
c.id,
311+
c.topology.SessionPool,
312+
)
244313
if err != nil {
245314
return ListDatabasesResult{}, err
246315
}
247-
return (ListDatabasesResult{}).fromResult(res), nil
248-
}
249316

250-
// ListDatabases returns a ListDatabasesResult.
251-
func (c *Client) ListDatabases(ctx context.Context, filter interface{}) (ListDatabasesResult, error) {
252-
return c.listDatabasesHelper(ctx, filter, false)
317+
return (ListDatabasesResult{}).fromResult(res), nil
253318
}
254319

255320
// ListDatabaseNames returns a slice containing the names of all of the databases on the server.
256-
func (c *Client) ListDatabaseNames(ctx context.Context, filter interface{}) ([]string, error) {
257-
res, err := c.listDatabasesHelper(ctx, filter, true)
321+
func (c *Client) ListDatabaseNames(ctx context.Context, filter interface{}, opts ...listdbopt.ListDatabases) ([]string, error) {
322+
opts = append(opts, listdbopt.NameOnly(true))
323+
res, err := c.ListDatabases(ctx, filter, opts...)
258324
if err != nil {
259325
return nil, err
260326
}

mongo/client_internal_test.go

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,20 @@ import (
2222

2323
"time"
2424

25+
"github.com/mongodb/mongo-go-driver/core/session"
26+
"github.com/mongodb/mongo-go-driver/core/uuid"
27+
"github.com/mongodb/mongo-go-driver/core/writeconcern"
2528
"github.com/mongodb/mongo-go-driver/mongo/clientopt"
2629
)
2730

2831
func createTestClient(t *testing.T) *Client {
32+
id, _ := uuid.New()
2933
return &Client{
34+
id: id,
3035
topology: testutil.Topology(t),
3136
connString: testutil.ConnString(t),
3237
readPreference: readpref.Primary(),
38+
clock: &session.ClusterClock{},
3339
}
3440
}
3541

@@ -197,10 +203,11 @@ func TestClient_ListDatabases_noFilter(t *testing.T) {
197203
}
198204

199205
dbName := "listDatabases_noFilter"
200-
201206
c := createTestClient(t)
202207
db := c.Database(dbName)
203-
_, err := db.Collection("test").InsertOne(
208+
coll := db.Collection("test")
209+
coll.writeConcern = writeconcern.New(writeconcern.WMajority())
210+
_, err := coll.InsertOne(
204211
context.Background(),
205212
bson.NewDocument(
206213
bson.EC.Int32("x", 1),
@@ -235,7 +242,9 @@ func TestClient_ListDatabases_filter(t *testing.T) {
235242

236243
c := createTestClient(t)
237244
db := c.Database(dbName)
238-
_, err := db.Collection("test").InsertOne(
245+
coll := db.Collection("test")
246+
coll.writeConcern = writeconcern.New(writeconcern.WMajority())
247+
_, err := coll.InsertOne(
239248
context.Background(),
240249
bson.NewDocument(
241250
bson.EC.Int32("x", 1),
@@ -265,7 +274,10 @@ func TestClient_ListDatabaseNames_noFilter(t *testing.T) {
265274

266275
c := createTestClient(t)
267276
db := c.Database(dbName)
268-
_, err := db.Collection("test").InsertOne(
277+
coll := db.Collection("test")
278+
279+
coll.writeConcern = writeconcern.New(writeconcern.WMajority())
280+
_, err := coll.InsertOne(
269281
context.Background(),
270282
bson.NewDocument(
271283
bson.EC.Int32("x", 1),
@@ -298,7 +310,9 @@ func TestClient_ListDatabaseNames_filter(t *testing.T) {
298310

299311
c := createTestClient(t)
300312
db := c.Database(dbName)
301-
_, err := db.Collection("test").InsertOne(
313+
coll := db.Collection("test")
314+
coll.writeConcern = writeconcern.New(writeconcern.WMajority())
315+
_, err := coll.InsertOne(
302316
context.Background(),
303317
bson.NewDocument(
304318
bson.EC.Int32("x", 1),

0 commit comments

Comments
 (0)