Skip to content

Commit bf02b96

Browse files
committed
Add Mongo Session passable as an option
GODRIVER-52 Change-Id: I7e93694aea2d8407036b24e019f01ae59ef2e716
1 parent 3d8db41 commit bf02b96

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

47 files changed

+2268
-321
lines changed

core/option/options.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,13 @@ type ChangeStreamOptioner interface {
129129
changeStreamOption()
130130
}
131131

132+
// DropCollectionsOptioner is the interface implemented by types that can be used as
133+
// Options for DropCollections operations.
134+
type DropCollectionsOptioner interface {
135+
Optioner
136+
dropCollectionsOption()
137+
}
138+
132139
// ListCollectionsOptioner is the interface implemented by types that can be used as
133140
// Options for ListCollections operations.
134141
type ListCollectionsOptioner interface {

mongo/aggregateopt/aggregateopt.go

Lines changed: 58 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
1+
// Copyright (C) MongoDB, Inc. 2017-present.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License"); you may
4+
// not use this file except in compliance with the License. You may obtain
5+
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6+
17
package aggregateopt
28

39
import (
@@ -6,17 +12,29 @@ import (
612
"reflect"
713

814
"github.com/mongodb/mongo-go-driver/core/option"
15+
"github.com/mongodb/mongo-go-driver/core/session"
916
"github.com/mongodb/mongo-go-driver/mongo/mongoopt"
1017
)
1118

1219
var aggregateBundle = new(AggregateBundle)
1320

14-
// Aggregate is options for the aggregate() function
21+
// Aggregate represents all passable params for the aggregate() function.
1522
type Aggregate interface {
1623
aggregate()
24+
}
25+
26+
// AggregateOption represents the options for the aggregate() function.
27+
type AggregateOption interface {
28+
Aggregate
1729
ConvertAggregateOption() option.AggregateOptioner
1830
}
1931

32+
// AggregateSession is the session for the aggregate() function
33+
type AggregateSession interface {
34+
Aggregate
35+
ConvertAggregateSession() *session.Client
36+
}
37+
2038
// AggregateBundle is a bundle of Aggregate options
2139
type AggregateBundle struct {
2240
option Aggregate
@@ -129,21 +147,23 @@ func (ab *AggregateBundle) bundleLength() int {
129147
continue
130148
}
131149

132-
bundleLen++
150+
if _, ok := ab.option.(AggregateSessionOpt); !ok {
151+
bundleLen++
152+
}
133153
}
134154

135155
return bundleLen
136156
}
137157

138158
// Unbundle transforms a bundle into a slice of options, optionally deduplicating
139-
func (ab *AggregateBundle) Unbundle(deduplicate bool) ([]option.AggregateOptioner, error) {
140-
options, err := ab.unbundle()
159+
func (ab *AggregateBundle) Unbundle(deduplicate bool) ([]option.AggregateOptioner, *session.Client, error) {
160+
options, sess, err := ab.unbundle()
141161
if err != nil {
142-
return nil, err
162+
return nil, nil, err
143163
}
144164

145165
if !deduplicate {
146-
return options, nil
166+
return options, sess, nil
147167
}
148168

149169
// iterate backwards and make dedup slice
@@ -162,15 +182,16 @@ func (ab *AggregateBundle) Unbundle(deduplicate bool) ([]option.AggregateOptione
162182
optionsSet[optionType] = struct{}{}
163183
}
164184

165-
return options, nil
185+
return options, sess, nil
166186
}
167187

168188
// Helper that recursively unwraps bundle into slice of options
169-
func (ab *AggregateBundle) unbundle() ([]option.AggregateOptioner, error) {
189+
func (ab *AggregateBundle) unbundle() ([]option.AggregateOptioner, *session.Client, error) {
170190
if ab == nil {
171-
return nil, nil
191+
return nil, nil, nil
172192
}
173193

194+
var sess *session.Client
174195
listLen := ab.bundleLength()
175196

176197
options := make([]option.AggregateOptioner, listLen)
@@ -179,9 +200,12 @@ func (ab *AggregateBundle) unbundle() ([]option.AggregateOptioner, error) {
179200
for listHead := ab; listHead != nil && listHead.option != nil; listHead = listHead.next {
180201
// if the current option is a nested bundle, Unbundle it and add its options to the current array
181202
if converted, ok := listHead.option.(*AggregateBundle); ok {
182-
nestedOptions, err := converted.unbundle()
203+
nestedOptions, s, err := converted.unbundle()
183204
if err != nil {
184-
return nil, err
205+
return nil, nil, err
206+
}
207+
if s != nil && sess == nil {
208+
sess = s
185209
}
186210

187211
// where to start inserting nested options
@@ -196,11 +220,18 @@ func (ab *AggregateBundle) unbundle() ([]option.AggregateOptioner, error) {
196220
continue
197221
}
198222

199-
options[index] = listHead.option.ConvertAggregateOption()
200-
index--
223+
switch t := listHead.option.(type) {
224+
case AggregateOption:
225+
options[index] = t.ConvertAggregateOption()
226+
index--
227+
case AggregateSession:
228+
if sess == nil {
229+
sess = t.ConvertAggregateSession()
230+
}
231+
}
201232
}
202233

203-
return options, nil
234+
return options, sess, nil
204235
}
205236

206237
// String implements the Stringer interface
@@ -216,7 +247,9 @@ func (ab *AggregateBundle) String() string {
216247
continue
217248
}
218249

219-
str += head.option.ConvertAggregateOption().String() + "\n"
250+
if conv, ok := head.option.(AggregateOption); !ok {
251+
str += conv.ConvertAggregateOption().String() + "\n"
252+
}
220253
}
221254

222255
return str
@@ -328,3 +361,13 @@ func (OptHint) aggregate() {}
328361
func (opt OptHint) ConvertAggregateOption() option.AggregateOptioner {
329362
return option.OptHint(opt)
330363
}
364+
365+
// AggregateSessionOpt is an aggregate session option.
366+
type AggregateSessionOpt struct{}
367+
368+
func (AggregateSessionOpt) aggregate() {}
369+
370+
// ConvertAggregateSession implements the AggregateSession interface.
371+
func (AggregateSessionOpt) ConvertAggregateSession() *session.Client {
372+
return nil
373+
}

mongo/aggregateopt/aggregateopt_test.go

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ func TestAggregateOpt(t *testing.T) {
156156
Locale: "string locale",
157157
}
158158

159-
opts := []Aggregate{
159+
opts := []AggregateOption{
160160
AllowDiskUse(true),
161161
BatchSize(5),
162162
BypassDocumentValidation(false),
@@ -165,9 +165,13 @@ func TestAggregateOpt(t *testing.T) {
165165
Hint("hint for find"),
166166
MaxTime(5000),
167167
}
168-
bundle := BundleAggregate(opts...)
168+
params := make([]Aggregate, len(opts))
169+
for i := range opts {
170+
params[i] = opts[i]
171+
}
172+
bundle := BundleAggregate(params...)
169173

170-
deleteOpts, err := bundle.Unbundle(true)
174+
deleteOpts, _, err := bundle.Unbundle(true)
171175
testhelpers.RequireNil(t, err, "got non-nill error from unbundle: %s", err)
172176

173177
if len(deleteOpts) != len(opts) {
@@ -219,7 +223,7 @@ func TestAggregateOpt(t *testing.T) {
219223

220224
for _, tc := range cases {
221225
t.Run(tc.name, func(t *testing.T) {
222-
options, err := tc.bundle.Unbundle(tc.dedup)
226+
options, _, err := tc.bundle.Unbundle(tc.dedup)
223227
testhelpers.RequireNil(t, err, "got non-nill error from unbundle: %s", err)
224228

225229
if len(options) != len(tc.expectedOpts) {

mongo/changestreamopt/changestreamopt.go

Lines changed: 49 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,17 +7,29 @@ import (
77

88
"github.com/mongodb/mongo-go-driver/bson"
99
"github.com/mongodb/mongo-go-driver/core/option"
10+
"github.com/mongodb/mongo-go-driver/core/session"
1011
"github.com/mongodb/mongo-go-driver/mongo/mongoopt"
1112
)
1213

1314
var csBundle = new(ChangeStreamBundle)
1415

15-
// ChangeStream is options for the changeStream()() function.
16+
// ChangeStream represents all passable params for the changeStream() function.
1617
type ChangeStream interface {
1718
changeStream()
19+
}
20+
21+
// ChangeStreamOption represents the options for the changeStream() function.
22+
type ChangeStreamOption interface {
23+
ChangeStream
1824
ConvertChangeStreamOption() option.ChangeStreamOptioner
1925
}
2026

27+
// ChangeStreamSession is the session for the changeStream() function
28+
type ChangeStreamSession interface {
29+
ChangeStream
30+
ConvertChangeStreamSession() *session.Client
31+
}
32+
2133
// ChangeStreamBundle is a bundle of ChangeStream options
2234
type ChangeStreamBundle struct {
2335
option ChangeStream
@@ -98,15 +110,15 @@ func (csb *ChangeStreamBundle) ResumeAfter(d *bson.Document) *ChangeStreamBundle
98110
}
99111

100112
// Unbundle transforms a bundle into a slice of options, optionally deduplicating
101-
func (csb *ChangeStreamBundle) Unbundle(deduplicate bool) ([]option.ChangeStreamOptioner, error) {
113+
func (csb *ChangeStreamBundle) Unbundle(deduplicate bool) ([]option.ChangeStreamOptioner, *session.Client, error) {
102114

103-
options, err := csb.unbundle()
115+
options, sess, err := csb.unbundle()
104116
if err != nil {
105-
return nil, err
117+
return nil, nil, err
106118
}
107119

108120
if !deduplicate {
109-
return options, nil
121+
return options, sess, nil
110122
}
111123

112124
// iterate backwards and make dedup slice
@@ -125,7 +137,7 @@ func (csb *ChangeStreamBundle) Unbundle(deduplicate bool) ([]option.ChangeStream
125137
optionsSet[optionType] = struct{}{}
126138
}
127139

128-
return options, nil
140+
return options, sess, nil
129141
}
130142

131143
// Calculates the total length of a bundle, accounting for nested bundles.
@@ -149,11 +161,12 @@ func (csb *ChangeStreamBundle) bundleLength() int {
149161
}
150162

151163
// Helper that recursively unwraps bundle into slice of options
152-
func (csb *ChangeStreamBundle) unbundle() ([]option.ChangeStreamOptioner, error) {
164+
func (csb *ChangeStreamBundle) unbundle() ([]option.ChangeStreamOptioner, *session.Client, error) {
153165
if csb == nil {
154-
return nil, nil
166+
return nil, nil, nil
155167
}
156168

169+
var sess *session.Client
157170
listLen := csb.bundleLength()
158171

159172
options := make([]option.ChangeStreamOptioner, listLen)
@@ -162,9 +175,12 @@ func (csb *ChangeStreamBundle) unbundle() ([]option.ChangeStreamOptioner, error)
162175
for listHead := csb; listHead != nil && listHead.option != nil; listHead = listHead.next {
163176
// if the current option is a nested bundle, Unbundle it and add its options to the current array
164177
if converted, ok := listHead.option.(*ChangeStreamBundle); ok {
165-
nestedOptions, err := converted.unbundle()
178+
nestedOptions, s, err := converted.unbundle()
166179
if err != nil {
167-
return nil, err
180+
return nil, nil, err
181+
}
182+
if s != nil && sess == nil {
183+
sess = s
168184
}
169185

170186
// where to start inserting nested options
@@ -179,11 +195,18 @@ func (csb *ChangeStreamBundle) unbundle() ([]option.ChangeStreamOptioner, error)
179195
continue
180196
}
181197

182-
options[index] = listHead.option.ConvertChangeStreamOption()
183-
index--
198+
switch t := listHead.option.(type) {
199+
case ChangeStreamOption:
200+
options[index] = t.ConvertChangeStreamOption()
201+
index--
202+
case ChangeStreamSession:
203+
if sess == nil {
204+
sess = t.ConvertChangeStreamSession()
205+
}
206+
}
184207
}
185208

186-
return options, nil
209+
return options, sess, nil
187210
}
188211

189212
// String implements the Stringer interface
@@ -199,7 +222,9 @@ func (csb *ChangeStreamBundle) String() string {
199222
continue
200223
}
201224

202-
str += head.option.ConvertChangeStreamOption().String() + "\n"
225+
if conv, ok := head.option.(ChangeStreamOption); !ok {
226+
str += conv.ConvertChangeStreamOption().String() + "\n"
227+
}
203228
}
204229

205230
return str
@@ -281,3 +306,13 @@ func (OptResumeAfter) changeStream() {}
281306
func (opt OptResumeAfter) ConvertChangeStreamOption() option.ChangeStreamOptioner {
282307
return option.OptResumeAfter(opt)
283308
}
309+
310+
// ChangeStreamSessionOpt is an count session option.
311+
type ChangeStreamSessionOpt struct{}
312+
313+
func (ChangeStreamSessionOpt) changeStream() {}
314+
315+
// ConvertChangeStreamSession implements the ChangeStreamSession interface.
316+
func (ChangeStreamSessionOpt) ConvertChangeStreamSession() *session.Client {
317+
return nil
318+
}

mongo/changestreamopt/changestreamopt_test.go

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -168,16 +168,20 @@ func TestChangeStreamOpt(t *testing.T) {
168168
Locale: "string locale",
169169
}
170170

171-
opts := []ChangeStream{
171+
opts := []ChangeStreamOption{
172172
BatchSize(5),
173173
Collation(c),
174174
FullDocument(mongoopt.UpdateLookup),
175175
MaxAwaitTime(5000),
176176
ResumeAfter(resumeAfter2),
177177
}
178-
bundle := BundleChangeStream(opts...)
178+
params := make([]ChangeStream, len(opts))
179+
for i := range opts {
180+
params[i] = opts[i]
181+
}
182+
bundle := BundleChangeStream(params...)
179183

180-
csOpts, err := bundle.Unbundle(true)
184+
csOpts, _, err := bundle.Unbundle(true)
181185
testhelpers.RequireNil(t, err, "got non-nill error from unbundle: %s", err)
182186

183187
if len(csOpts) != len(opts) {
@@ -229,7 +233,7 @@ func TestChangeStreamOpt(t *testing.T) {
229233

230234
for _, tc := range cases {
231235
t.Run(tc.name, func(t *testing.T) {
232-
options, err := tc.bundle.Unbundle(tc.dedup)
236+
options, _, err := tc.bundle.Unbundle(tc.dedup)
233237
testhelpers.RequireNil(t, err, "got non-nill error from unbundle: %s", err)
234238

235239
if len(options) != len(tc.expectedOpts) {

0 commit comments

Comments
 (0)