Skip to content

Commit 6101f15

Browse files
author
Divjot Arora
authored
Simplify unified test runner (#416)
This commit changes the test runner to use the default ClientOptions when starting a new test and then using that Client for test setup. After setup, the Client is reset to use the ClientOptions specified in the test.
1 parent ed2bb99 commit 6101f15

File tree

3 files changed

+94
-74
lines changed

3 files changed

+94
-74
lines changed

mongo/integration/crud_helpers_test.go

Lines changed: 41 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ package integration
99
import (
1010
"bytes"
1111
"context"
12+
"fmt"
1213
"math"
1314
"strconv"
1415
"time"
@@ -88,18 +89,55 @@ func isExpectedKillAllSessionsError(err error) bool {
8889
func killSessions(mt *mtest.T) {
8990
mt.Helper()
9091

91-
err := mt.GlobalClient().Database("admin").RunCommand(mtest.Background, bson.D{
92+
cmd := bson.D{
9293
{"killAllSessions", bson.A{}},
93-
}, options.RunCmd().SetReadPreference(mtest.PrimaryRp)).Err()
94+
}
95+
runCmdOpts := options.RunCmd().SetReadPreference(mtest.PrimaryRp)
96+
97+
// killAllSessions has to be run against each mongos in a sharded cluster, so we use the runCommandOnAllServers
98+
// helper.
99+
err := runCommandOnAllServers(mt, func(client *mongo.Client) error {
100+
return client.Database("admin").RunCommand(mtest.Background, cmd, runCmdOpts).Err()
101+
})
102+
94103
if err == nil {
95104
return
96105
}
97-
98106
if !isExpectedKillAllSessionsError(err) {
99107
mt.Fatalf("killAllSessions error: %v", err)
100108
}
101109
}
102110

111+
// Utility function to run a command on all servers. For standalones, the command is run against the one server. For
112+
// replica sets, the command is run against the primary. sharded clusters, the command is run against each mongos.
113+
func runCommandOnAllServers(mt *mtest.T, commandFn func(client *mongo.Client) error) error {
114+
opts := options.Client().
115+
ApplyURI(mt.ConnString())
116+
117+
if mt.TopologyKind() != mtest.Sharded {
118+
client, err := mongo.Connect(mtest.Background, opts)
119+
if err != nil {
120+
return fmt.Errorf("error creating replica set client: %v", err)
121+
}
122+
defer func() { _ = client.Disconnect(mtest.Background) }()
123+
124+
return commandFn(client)
125+
}
126+
127+
for _, host := range opts.Hosts {
128+
shardClient, err := mongo.Connect(mtest.Background, opts.SetHosts([]string{host}))
129+
if err != nil {
130+
return fmt.Errorf("error creating client for mongos %v: %v", host, err)
131+
}
132+
133+
err = commandFn(shardClient)
134+
_ = shardClient.Disconnect(mtest.Background)
135+
return err
136+
}
137+
138+
return nil
139+
}
140+
103141
// aggregator is an interface used to run collection and database-level aggregations
104142
type aggregator interface {
105143
Aggregate(context.Context, interface{}, ...*options.AggregateOptions) (*mongo.Cursor, error)

mongo/integration/mtest/mongotest.go

Lines changed: 29 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ type T struct {
8585
runOn []RunOnBlock
8686
mockDeployment *mockDeployment // nil if the test is not being run against a mock
8787
mockResponses []bson.D
88-
createdColls []*mongo.Collection // collections created in this test
88+
createdColls []*Collection // collections created in this test
8989
dbName, collName string
9090
failPointNames []string
9191
minServerVersion string
@@ -361,18 +361,23 @@ func (t *T) ResetClient(opts *options.ClientOptions) {
361361
_ = t.Client.Disconnect(Background)
362362
t.createTestClient()
363363
t.DB = t.Client.Database(t.dbName)
364-
t.Coll = t.DB.Collection(t.collName)
364+
t.Coll = t.DB.Collection(t.collName, t.collOpts)
365365

366-
created := make([]*mongo.Collection, len(t.createdColls))
367-
for i, coll := range t.createdColls {
368-
if coll.Name() == t.collName {
369-
created[i] = t.Coll
366+
for _, coll := range t.createdColls {
367+
// If the collection was created using a different Client, it doesn't need to be reset.
368+
if coll.hasDifferentClient {
369+
continue
370+
}
371+
372+
// If the namespace is the same as t.Coll, we can use t.Coll.
373+
if coll.created.Name() == t.collName && coll.created.Database().Name() == t.dbName {
374+
coll.created = t.Coll
370375
continue
371376
}
372377

373-
created[i] = t.DB.Collection(coll.Name())
378+
// Otherwise, reset the collection to use the new Client.
379+
coll.created = t.Client.Database(coll.DB).Collection(coll.Name, coll.Opts)
374380
}
375-
t.createdColls = created
376381
}
377382

378383
// Collection is used to configure a new collection created during a test.
@@ -382,35 +387,25 @@ type Collection struct {
382387
Client *mongo.Client // defaults to mt.Client if not specified
383388
Opts *options.CollectionOptions
384389
CreateOpts bson.D
385-
}
386-
387-
// returns database to use for creating a new collection
388-
func (t *T) extractDatabase(coll Collection) *mongo.Database {
389-
// default to t.DB unless coll overrides it
390-
var createNewDb bool
391-
dbName := t.DB.Name()
392-
if coll.DB != "" {
393-
createNewDb = true
394-
dbName = coll.DB
395-
}
396390

397-
// if a client is specified, a new database must be created
398-
if coll.Client != nil {
399-
return coll.Client.Database(dbName)
400-
}
401-
// if dbName is the same as t.DB.Name(), t.DB can be used
402-
if !createNewDb {
403-
return t.DB
404-
}
405-
// a new database must be created from t.Client
406-
return t.Client.Database(dbName)
391+
hasDifferentClient bool
392+
created *mongo.Collection // the actual collection that was created
407393
}
408394

409395
// CreateCollection creates a new collection with the given configuration. The collection will be dropped after the test
410396
// finishes running. If createOnServer is true, the function ensures that the collection has been created server-side
411397
// by running the create command. The create command will appear in command monitoring channels.
412398
func (t *T) CreateCollection(coll Collection, createOnServer bool) *mongo.Collection {
413-
db := t.extractDatabase(coll)
399+
if coll.DB == "" {
400+
coll.DB = t.DB.Name()
401+
}
402+
if coll.Client == nil {
403+
coll.Client = t.Client
404+
}
405+
coll.hasDifferentClient = coll.Client != t.Client
406+
407+
db := coll.Client.Database(coll.DB)
408+
414409
if createOnServer && t.clientType != Mock {
415410
cmd := bson.D{{"create", coll.Name}}
416411
cmd = append(cmd, coll.CreateOpts...)
@@ -425,15 +420,15 @@ func (t *T) CreateCollection(coll Collection, createOnServer bool) *mongo.Collec
425420
}
426421
}
427422

428-
created := db.Collection(coll.Name, coll.Opts)
429-
t.createdColls = append(t.createdColls, created)
430-
return created
423+
coll.created = db.Collection(coll.Name, coll.Opts)
424+
t.createdColls = append(t.createdColls, &coll)
425+
return coll.created
431426
}
432427

433428
// ClearCollections drops all collections previously created by this test.
434429
func (t *T) ClearCollections() {
435430
for _, coll := range t.createdColls {
436-
_ = coll.Drop(Background)
431+
_ = coll.created.Drop(Background)
437432
}
438433
t.createdColls = t.createdColls[:0]
439434
}

mongo/integration/unified_spec_test.go

Lines changed: 24 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ import (
1111
"path"
1212
"reflect"
1313
"testing"
14-
"time"
1514

1615
"go.mongodb.org/mongo-driver/bson"
1716
"go.mongodb.org/mongo-driver/bson/bsoncodec"
@@ -184,9 +183,6 @@ func runSpecTestFile(t *testing.T, specDir, fileName string) {
184183
}
185184

186185
func runSpecTestCase(mt *mtest.T, test *testCase, testFile testFile) {
187-
testClientOpts := createClientOptions(mt, test.ClientOptions)
188-
testClientOpts.SetHeartbeatInterval(50 * time.Millisecond)
189-
190186
opts := mtest.NewOptions().DatabaseName(testFile.DatabaseName).CollectionName(testFile.CollectionName)
191187
if mt.TopologyKind() == mtest.Sharded && !test.UseMultipleMongoses {
192188
// pin to a single mongos
@@ -200,12 +196,8 @@ func runSpecTestCase(mt *mtest.T, test *testCase, testFile testFile) {
200196
{"validator", validator},
201197
})
202198
}
203-
if test.Description != cseMaxVersionTest {
204-
// don't specify client options for the maxWireVersion CSE test because the client cannot
205-
// be created successfully. Should be fixed by SPEC-1403.
206-
opts.ClientOptions(testClientOpts)
207-
}
208199

200+
// Start the test without setting client options so the setup will be done with a default client.
209201
mt.RunOpts(test.Description, opts, func(mt *mtest.T) {
210202
if len(test.SkipReason) > 0 {
211203
mt.Skip(test.SkipReason)
@@ -219,37 +211,39 @@ func runSpecTestCase(mt *mtest.T, test *testCase, testFile testFile) {
219211

220212
// work around for SERVER-39704: run a non-transactional distinct against each shard in a sharded cluster
221213
if mt.TopologyKind() == mtest.Sharded && test.Description == "distinct" {
222-
opts := options.Client().ApplyURI(mt.ConnString())
223-
for _, host := range opts.Hosts {
224-
shardClient, err := mongo.Connect(mtest.Background, opts.SetHosts([]string{host}))
225-
assert.Nil(mt, err, "Connect error for shard %v: %v", host, err)
226-
coll := shardClient.Database(mt.DB.Name()).Collection(mt.Coll.Name())
227-
_, err = coll.Distinct(mtest.Background, "x", bson.D{})
228-
assert.Nil(mt, err, "Distinct error for shard %v: %v", host, err)
229-
_ = shardClient.Disconnect(mtest.Background)
230-
}
214+
err := runCommandOnAllServers(mt, func(mongosClient *mongo.Client) error {
215+
coll := mongosClient.Database(mt.DB.Name()).Collection(mt.Coll.Name())
216+
_, err := coll.Distinct(mtest.Background, "x", bson.D{})
217+
return err
218+
})
219+
assert.Nil(mt, err, "error running distinct against all mongoses: %v", err)
231220
}
232221

233-
// defer killSessions to ensure it runs regardless of the state of the test because the client has already
222+
// Defer killSessions to ensure it runs regardless of the state of the test because the client has already
234223
// been created and the collection drop in mongotest will hang for transactions to be aborted (60 seconds)
235224
// in error cases.
236225
defer killSessions(mt)
226+
227+
// Test setup: create collections that are tracked by mtest, insert test data, and set the failpoint.
237228
setupTest(mt, &testFile, test)
229+
if test.FailPoint != nil {
230+
mt.SetFailPoint(*test.FailPoint)
231+
}
238232

239-
// create the GridFS bucket after resetting the client so it will be created with a connected client
240-
createBucket(mt, testFile, test)
233+
// Reset the client using the client options specified in the test.
234+
testClientOpts := createClientOptions(mt, test.ClientOptions)
235+
mt.ResetClient(testClientOpts)
241236

242-
// create sessions, fail points, and collection
237+
// Create the GridFS bucket and sessions after resetting the client so it will be created with a connected
238+
// client.
239+
createBucket(mt, testFile, test)
243240
sess0, sess1 := setupSessions(mt, test)
244241
if sess0 != nil {
245242
defer func() {
246243
sess0.EndSession(mtest.Background)
247244
sess1.EndSession(mtest.Background)
248245
}()
249246
}
250-
if test.FailPoint != nil {
251-
mt.SetFailPoint(*test.FailPoint)
252-
}
253247

254248
// run operations
255249
mt.ClearEvents()
@@ -762,24 +756,19 @@ func insertDocuments(mt *mtest.T, coll *mongo.Collection, rawDocs []bson.Raw) {
762756
func setupTest(mt *mtest.T, testFile *testFile, testCase *testCase) {
763757
mt.Helper()
764758

765-
// all setup should be done with the global client instead of the test client to prevent any errors created by
766-
// client configurations.
767-
setupClient := mt.GlobalClient()
768759
// key vault data
769760
if len(testFile.KeyVaultData) > 0 {
770761
keyVaultColl := mt.CreateCollection(mtest.Collection{
771-
Name: "datakeys",
772-
DB: "keyvault",
773-
Client: setupClient,
762+
Name: "datakeys",
763+
DB: "keyvault",
774764
}, false)
775765

776766
insertDocuments(mt, keyVaultColl, testFile.KeyVaultData)
777767
}
778768

779769
// regular documents
780770
if testFile.Data.Documents != nil {
781-
insertColl := setupClient.Database(mt.DB.Name()).Collection(mt.Coll.Name())
782-
insertDocuments(mt, insertColl, testFile.Data.Documents)
771+
insertDocuments(mt, mt.Coll, testFile.Data.Documents)
783772
return
784773
}
785774

@@ -788,15 +777,13 @@ func setupTest(mt *mtest.T, testFile *testFile, testCase *testCase) {
788777

789778
if gfsData.Chunks != nil {
790779
chunks := mt.CreateCollection(mtest.Collection{
791-
Name: gridFSChunks,
792-
Client: setupClient,
780+
Name: gridFSChunks,
793781
}, false)
794782
insertDocuments(mt, chunks, gfsData.Chunks)
795783
}
796784
if gfsData.Files != nil {
797785
files := mt.CreateCollection(mtest.Collection{
798-
Name: gridFSFiles,
799-
Client: setupClient,
786+
Name: gridFSFiles,
800787
}, false)
801788
insertDocuments(mt, files, gfsData.Files)
802789

0 commit comments

Comments
 (0)