Skip to content

Commit 7c7f13e

Browse files
Fix maxAwaitTimeMS implementation
GODRIVER-266 Change-Id: Ie7f9ca4d0bf5f9fa75ca5332c48888157ef3f2de
1 parent e9228a4 commit 7c7f13e

File tree

10 files changed

+382
-11
lines changed

10 files changed

+382
-11
lines changed

core/command/aggregate.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ func (a *Aggregate) encode(desc description.SelectedServer) (*Read, error) {
5959

6060
for _, opt := range a.Opts {
6161
switch t := opt.(type) {
62-
case nil:
62+
case nil, option.OptMaxAwaitTime:
6363
continue
6464
case option.OptBatchSize:
6565
if t == 0 && a.HasDollarOut() {

core/command/find.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ func (f *Find) encode(desc description.SelectedServer) (*Read, error) {
6161

6262
for _, opt := range f.Opts {
6363
switch t := opt.(type) {
64-
case nil:
64+
case nil, option.OptMaxAwaitTime:
6565
continue
6666
case option.OptLimit:
6767
limit = int64(t)

core/command/get_more.go

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ package command
88

99
import (
1010
"context"
11-
1211
"github.com/mongodb/mongo-go-driver/bson"
1312
"github.com/mongodb/mongo-go-driver/core/description"
1413
"github.com/mongodb/mongo-go-driver/core/option"
@@ -45,8 +44,16 @@ func (gm *GetMore) encode(desc description.SelectedServer) (*Read, error) {
4544
bson.EC.Int64("getMore", gm.ID),
4645
bson.EC.String("collection", gm.NS.Collection),
4746
)
47+
48+
var err error
49+
4850
for _, opt := range gm.Opts {
49-
err := opt.Option(cmd)
51+
switch t := opt.(type) {
52+
case option.OptMaxAwaitTime:
53+
err = option.OptMaxTime(t).Option(cmd)
54+
default:
55+
err = opt.Option(cmd)
56+
}
5057
if err != nil {
5158
return nil, err
5259
}

core/integration/aggregate_test.go

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ package integration
99
import (
1010
"bytes"
1111
"context"
12+
"fmt"
13+
"os"
1214
"strings"
1315
"testing"
1416
"time"
@@ -21,6 +23,7 @@ import (
2123
"github.com/mongodb/mongo-go-driver/core/topology"
2224
"github.com/mongodb/mongo-go-driver/core/writeconcern"
2325
"github.com/mongodb/mongo-go-driver/internal/testutil"
26+
"github.com/stretchr/testify/assert"
2427
)
2528

2629
func TestCommandAggregate(t *testing.T) {
@@ -152,3 +155,129 @@ func TestCommandAggregate(t *testing.T) {
152155
noerr(t, err)
153156
})
154157
}
158+
159+
func TestAggregatePassesMaxAwaitTimeMSThroughToGetMore(t *testing.T) {
160+
if os.Getenv("TOPOLOGY") != "replica_set" {
161+
t.Skip()
162+
}
163+
164+
startedChan, succeededChan, failedChan, monitor := initMonitor()
165+
166+
dbName := fmt.Sprintf("mongo-go-driver-%d-agg", os.Getpid())
167+
colName := testutil.ColName(t)
168+
169+
server, err := testutil.MonitoredTopology(t, dbName, monitor).SelectServer(context.Background(), description.WriteSelector())
170+
noerr(t, err)
171+
172+
versionCmd := bson.NewDocument(bson.EC.Int32("serverStatus", 1))
173+
serverStatus, err := testutil.RunCommand(t, server.Server, dbName, versionCmd)
174+
version, err := serverStatus.Lookup("version")
175+
176+
if compareVersions(t, version.Value().StringValue(), "3.6") < 0 {
177+
t.Skip()
178+
}
179+
180+
// create capped collection
181+
createCmd := bson.NewDocument(
182+
bson.EC.String("create", colName),
183+
bson.EC.Boolean("capped", true),
184+
bson.EC.Int32("size", 1000))
185+
_, err = testutil.RunCommand(t, server.Server, dbName, createCmd)
186+
noerr(t, err)
187+
188+
conn, err := server.Connection(context.Background())
189+
noerr(t, err)
190+
191+
// create an aggregate command that results with a TAILABLEAWAIT cursor
192+
cursor, err := (&command.Aggregate{
193+
NS: command.Namespace{DB: dbName, Collection: testutil.ColName(t)},
194+
Pipeline: bson.NewArray(
195+
bson.VC.Document(bson.NewDocument(
196+
bson.EC.SubDocument("$changeStream", bson.NewDocument()))),
197+
bson.VC.Document(bson.NewDocument(
198+
bson.EC.SubDocument("$match", bson.NewDocument(
199+
bson.EC.SubDocument("fullDocument._id", bson.NewDocument(bson.EC.Int32("$gte", 1))),
200+
))))),
201+
Opts: []option.AggregateOptioner{option.OptBatchSize(2), option.OptMaxAwaitTime(time.Millisecond * 50)},
202+
}).RoundTrip(context.Background(), server.SelectedDescription(), server, conn)
203+
noerr(t, err)
204+
205+
// insert some documents
206+
insertCmd := bson.NewDocument(
207+
bson.EC.String("insert", colName),
208+
bson.EC.ArrayFromElements("documents",
209+
bson.VC.Document(bson.NewDocument(bson.EC.Int32("_id", 1))),
210+
bson.VC.Document(bson.NewDocument(bson.EC.Int32("_id", 2))),
211+
bson.VC.Document(bson.NewDocument(bson.EC.Int32("_id", 3)))))
212+
_, err = testutil.RunCommand(t, server.Server, dbName, insertCmd)
213+
214+
// wait a bit between insert and getMore commands
215+
time.Sleep(time.Millisecond * 100)
216+
217+
ctx, cancel := context.WithCancel(context.Background())
218+
time.AfterFunc(time.Millisecond*900, cancel)
219+
for cursor.Next(ctx) {
220+
}
221+
222+
// allow for iteration over range chan
223+
close(startedChan)
224+
close(succeededChan)
225+
close(failedChan)
226+
227+
// no commands should have failed
228+
if len(failedChan) != 0 {
229+
t.Errorf("%d command(s) failed", len(failedChan))
230+
}
231+
232+
// check the expected commands were started
233+
for started := range startedChan {
234+
switch started.CommandName {
235+
case "aggregate":
236+
assert.Equal(t, 2, int(started.Command.Lookup("cursor", "batchSize").Int32()))
237+
assert.Nil(t, started.Command.Lookup("maxAwaitTimeMS"),
238+
"Should not have sent maxAwaitTimeMS in find command")
239+
case "getMore":
240+
assert.Equal(t, 2, int(started.Command.Lookup("batchSize").Int32()))
241+
assert.Equal(t, 50, int(started.Command.Lookup("maxTimeMS").Int64()),
242+
"Should have sent maxTimeMS in getMore command")
243+
default:
244+
continue
245+
}
246+
}
247+
248+
// to keep track of seen documents
249+
id := 1
250+
251+
// check expected commands succeeded
252+
for succeeded := range succeededChan {
253+
switch succeeded.CommandName {
254+
case "aggregate":
255+
assert.Equal(t, 1, int(succeeded.Reply.Lookup("ok").Double()))
256+
257+
actual := succeeded.Reply.Lookup("cursor", "firstBatch").MutableArray()
258+
259+
for i := 0; i < actual.Len(); i++ {
260+
v, _ := actual.Lookup(uint(i))
261+
assert.Equal(t, id, int(v.MutableDocument().Lookup("fullDocument", "_id").Int32()))
262+
id++
263+
}
264+
case "getMore":
265+
assert.Equal(t, "getMore", succeeded.CommandName)
266+
assert.Equal(t, 1, int(succeeded.Reply.Lookup("ok").Double()))
267+
268+
actual := succeeded.Reply.Lookup("cursor", "nextBatch").MutableArray()
269+
270+
for i := 0; i < actual.Len(); i++ {
271+
v, _ := actual.Lookup(uint(i))
272+
assert.Equal(t, id, int(v.MutableDocument().Lookup("fullDocument", "_id").Int32()))
273+
id++
274+
}
275+
default:
276+
continue
277+
}
278+
}
279+
280+
if id <= 3 {
281+
t.Errorf("not all documents returned; last seen id = %d", id-1)
282+
}
283+
}

core/integration/find_test.go

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
// Copyright (C) MongoDB, Inc. 2017-present.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License"); you may
4+
// not use this file except in compliance with the License. You may obtain
5+
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6+
7+
package integration
8+
9+
import (
10+
"context"
11+
"fmt"
12+
"os"
13+
"testing"
14+
"time"
15+
16+
"github.com/mongodb/mongo-go-driver/bson"
17+
"github.com/mongodb/mongo-go-driver/core/command"
18+
"github.com/mongodb/mongo-go-driver/core/description"
19+
"github.com/mongodb/mongo-go-driver/core/event"
20+
"github.com/mongodb/mongo-go-driver/core/option"
21+
"github.com/mongodb/mongo-go-driver/internal/testutil"
22+
"github.com/stretchr/testify/assert"
23+
)
24+
25+
func initMonitor() (chan *event.CommandStartedEvent, chan *event.CommandSucceededEvent, chan *event.CommandFailedEvent, *event.CommandMonitor) {
26+
startedChan := make(chan *event.CommandStartedEvent, 100)
27+
succeededChan := make(chan *event.CommandSucceededEvent, 100)
28+
failedChan := make(chan *event.CommandFailedEvent, 100)
29+
monitor := &event.CommandMonitor{
30+
Started: func(ctx context.Context, cse *event.CommandStartedEvent) {
31+
startedChan <- cse
32+
},
33+
Succeeded: func(ctx context.Context, cse *event.CommandSucceededEvent) {
34+
succeededChan <- cse
35+
},
36+
Failed: func(ctx context.Context, cfe *event.CommandFailedEvent) {
37+
failedChan <- cfe
38+
},
39+
}
40+
41+
return startedChan, succeededChan, failedChan, monitor
42+
}
43+
44+
func TestFindPassesMaxAwaitTimeMSThroughToGetMore(t *testing.T) {
45+
startedChan, succeededChan, failedChan, monitor := initMonitor()
46+
47+
dbName := fmt.Sprintf("mongo-go-driver-%d-find", os.Getpid())
48+
colName := testutil.ColName(t)
49+
50+
server, err := testutil.MonitoredTopology(t, dbName, monitor).SelectServer(context.Background(), description.WriteSelector())
51+
noerr(t, err)
52+
53+
// create capped collection
54+
createCmd := bson.NewDocument(
55+
bson.EC.String("create", colName),
56+
bson.EC.Boolean("capped", true),
57+
bson.EC.Int32("size", 1000))
58+
_, err = testutil.RunCommand(t, server.Server, dbName, createCmd)
59+
noerr(t, err)
60+
61+
// insert some documents
62+
insertCmd := bson.NewDocument(
63+
bson.EC.String("insert", colName),
64+
bson.EC.ArrayFromElements("documents",
65+
bson.VC.Document(bson.NewDocument(bson.EC.Int32("_id", 1))),
66+
bson.VC.Document(bson.NewDocument(bson.EC.Int32("_id", 2))),
67+
bson.VC.Document(bson.NewDocument(bson.EC.Int32("_id", 3))),
68+
bson.VC.Document(bson.NewDocument(bson.EC.Int32("_id", 4))),
69+
bson.VC.Document(bson.NewDocument(bson.EC.Int32("_id", 5)))))
70+
_, err = testutil.RunCommand(t, server.Server, dbName, insertCmd)
71+
72+
conn, err := server.Connection(context.Background())
73+
noerr(t, err)
74+
75+
// find those documents, setting cursor type to TAILABLEAWAIT
76+
cursor, err := (&command.Find{
77+
NS: command.Namespace{DB: dbName, Collection: colName},
78+
Filter: bson.NewDocument(bson.EC.SubDocument("_id", bson.NewDocument(bson.EC.Int32("$gte", 1)))),
79+
Opts: []option.FindOptioner{
80+
option.OptBatchSize(3),
81+
option.OptMaxAwaitTime(time.Millisecond * 250),
82+
option.OptCursorType(option.TailableAwait)},
83+
}).RoundTrip(context.Background(), server.SelectedDescription(), server, conn)
84+
noerr(t, err)
85+
86+
// exhaust the cursor, triggering getMore commands
87+
for i := 0; i < 4; i++ {
88+
cursor.Next(context.Background())
89+
}
90+
91+
// allow for iteration over range chan
92+
close(startedChan)
93+
close(succeededChan)
94+
close(failedChan)
95+
96+
// no commands should have failed
97+
if len(failedChan) != 0 {
98+
t.Errorf("%d command(s) failed", len(failedChan))
99+
}
100+
101+
// check that the expected commands were started
102+
for started := range startedChan {
103+
switch started.CommandName {
104+
case "find":
105+
assert.Equal(t, 3, int(started.Command.Lookup("batchSize").Int32()))
106+
assert.True(t, started.Command.Lookup("tailable").Boolean())
107+
assert.True(t, started.Command.Lookup("awaitData").Boolean())
108+
assert.Nil(t, started.Command.Lookup("maxAwaitTimeMS"),
109+
"Should not have sent maxAwaitTimeMS in find command")
110+
case "getMore":
111+
assert.Equal(t, 3, int(started.Command.Lookup("batchSize").Int32()))
112+
assert.Equal(t, 250, int(started.Command.Lookup("maxTimeMS").Int64()),
113+
"Should have sent maxTimeMS in getMore command")
114+
default:
115+
continue
116+
}
117+
}
118+
119+
// to keep track of seen documents
120+
id := 1
121+
122+
// check expected commands succeeded
123+
for succeeded := range succeededChan {
124+
switch succeeded.CommandName {
125+
case "find":
126+
assert.Equal(t, 1, int(succeeded.Reply.Lookup("ok").Double()))
127+
128+
actual := succeeded.Reply.Lookup("cursor", "firstBatch").MutableArray()
129+
130+
for i := 0; i < actual.Len(); i++ {
131+
v, _ := actual.Lookup(uint(i))
132+
assert.Equal(t, id, int(v.MutableDocument().Lookup("_id").Int32()))
133+
id++
134+
}
135+
case "getMore":
136+
assert.Equal(t, "getMore", succeeded.CommandName)
137+
assert.Equal(t, 1, int(succeeded.Reply.Lookup("ok").Double()))
138+
139+
actual := succeeded.Reply.Lookup("cursor", "nextBatch").MutableArray()
140+
141+
for i := 0; i < actual.Len(); i++ {
142+
v, _ := actual.Lookup(uint(i))
143+
assert.Equal(t, id, int(v.MutableDocument().Lookup("_id").Int32()))
144+
id++
145+
}
146+
default:
147+
continue
148+
}
149+
}
150+
151+
if id <= 5 {
152+
t.Errorf("not all documents returned; last seen id = %d", id-1)
153+
}
154+
}

core/option/options.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,13 +185,15 @@ var (
185185
_ AggregateOptioner = (*OptCollation)(nil)
186186
_ AggregateOptioner = (*OptComment)(nil)
187187
_ AggregateOptioner = (*OptMaxTime)(nil)
188+
_ AggregateOptioner = (*OptMaxAwaitTime)(nil)
188189
_ CountOptioner = (*OptCollation)(nil)
189190
_ CountOptioner = (*OptHint)(nil)
190191
_ CountOptioner = (*OptLimit)(nil)
191192
_ CountOptioner = (*OptMaxTime)(nil)
192193
_ CountOptioner = (*OptSkip)(nil)
193194
_ CreateIndexesOptioner = (*OptMaxTime)(nil)
194195
_ CursorOptioner = OptBatchSize(0)
196+
_ CursorOptioner = (*OptMaxAwaitTime)(nil)
195197
_ DeleteOptioner = (*OptCollation)(nil)
196198
_ DistinctOptioner = (*OptCollation)(nil)
197199
_ DistinctOptioner = (*OptMaxTime)(nil)
@@ -585,7 +587,9 @@ func (opt OptMaxAwaitTime) Option(d *bson.Document) error {
585587
return nil
586588
}
587589

590+
func (OptMaxAwaitTime) aggregateOption() {}
588591
func (OptMaxAwaitTime) changeStreamOption() {}
592+
func (OptMaxAwaitTime) cursorOption() {}
589593
func (OptMaxAwaitTime) findOption() {}
590594
func (OptMaxAwaitTime) findOneOption() {}
591595

0 commit comments

Comments
 (0)