Skip to content

Simplify unified test runner #416

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 41 additions & 3 deletions mongo/integration/crud_helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ package integration
import (
"bytes"
"context"
"fmt"
"math"
"strconv"
"time"
Expand Down Expand Up @@ -88,18 +89,55 @@ func isExpectedKillAllSessionsError(err error) bool {
func killSessions(mt *mtest.T) {
mt.Helper()

err := mt.GlobalClient().Database("admin").RunCommand(mtest.Background, bson.D{
cmd := bson.D{
{"killAllSessions", bson.A{}},
}, options.RunCmd().SetReadPreference(mtest.PrimaryRp)).Err()
}
runCmdOpts := options.RunCmd().SetReadPreference(mtest.PrimaryRp)

// killAllSessions has to be run against each mongos in a sharded cluster, so we use the runCommandOnAllServers
// helper.
err := runCommandOnAllServers(mt, func(client *mongo.Client) error {
return client.Database("admin").RunCommand(mtest.Background, cmd, runCmdOpts).Err()
})

if err == nil {
return
}

if !isExpectedKillAllSessionsError(err) {
mt.Fatalf("killAllSessions error: %v", err)
}
}

// Utility function to run a command on all servers. For standalones, the command is run against the one server. For
// replica sets, the command is run against the primary. sharded clusters, the command is run against each mongos.
func runCommandOnAllServers(mt *mtest.T, commandFn func(client *mongo.Client) error) error {
opts := options.Client().
ApplyURI(mt.ConnString())

if mt.TopologyKind() != mtest.Sharded {
client, err := mongo.Connect(mtest.Background, opts)
if err != nil {
return fmt.Errorf("error creating replica set client: %v", err)
}
defer func() { _ = client.Disconnect(mtest.Background) }()

return commandFn(client)
}

for _, host := range opts.Hosts {
shardClient, err := mongo.Connect(mtest.Background, opts.SetHosts([]string{host}))
if err != nil {
return fmt.Errorf("error creating client for mongos %v: %v", host, err)
}

err = commandFn(shardClient)
_ = shardClient.Disconnect(mtest.Background)
return err
}

return nil
}

// aggregator is an interface used to run collection and database-level aggregations
type aggregator interface {
Aggregate(context.Context, interface{}, ...*options.AggregateOptions) (*mongo.Cursor, error)
Expand Down
63 changes: 29 additions & 34 deletions mongo/integration/mtest/mongotest.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ type T struct {
runOn []RunOnBlock
mockDeployment *mockDeployment // nil if the test is not being run against a mock
mockResponses []bson.D
createdColls []*mongo.Collection // collections created in this test
createdColls []*Collection // collections created in this test
dbName, collName string
failPointNames []string
minServerVersion string
Expand Down Expand Up @@ -361,18 +361,23 @@ func (t *T) ResetClient(opts *options.ClientOptions) {
_ = t.Client.Disconnect(Background)
t.createTestClient()
t.DB = t.Client.Database(t.dbName)
t.Coll = t.DB.Collection(t.collName)
t.Coll = t.DB.Collection(t.collName, t.collOpts)

created := make([]*mongo.Collection, len(t.createdColls))
for i, coll := range t.createdColls {
if coll.Name() == t.collName {
created[i] = t.Coll
for _, coll := range t.createdColls {
// If the collection was created using a different Client, it doesn't need to be reset.
if coll.hasDifferentClient {
continue
}

// If the namespace is the same as t.Coll, we can use t.Coll.
if coll.created.Name() == t.collName && coll.created.Database().Name() == t.dbName {
coll.created = t.Coll
continue
}

created[i] = t.DB.Collection(coll.Name())
// Otherwise, reset the collection to use the new Client.
coll.created = t.Client.Database(coll.DB).Collection(coll.Name, coll.Opts)
}
t.createdColls = created
}

// Collection is used to configure a new collection created during a test.
Expand All @@ -382,35 +387,25 @@ type Collection struct {
Client *mongo.Client // defaults to mt.Client if not specified
Opts *options.CollectionOptions
CreateOpts bson.D
}

// returns database to use for creating a new collection
func (t *T) extractDatabase(coll Collection) *mongo.Database {
// default to t.DB unless coll overrides it
var createNewDb bool
dbName := t.DB.Name()
if coll.DB != "" {
createNewDb = true
dbName = coll.DB
}

// if a client is specified, a new database must be created
if coll.Client != nil {
return coll.Client.Database(dbName)
}
// if dbName is the same as t.DB.Name(), t.DB can be used
if !createNewDb {
return t.DB
}
// a new database must be created from t.Client
return t.Client.Database(dbName)
hasDifferentClient bool
created *mongo.Collection // the actual collection that was created
}

// CreateCollection creates a new collection with the given configuration. The collection will be dropped after the test
// finishes running. If createOnServer is true, the function ensures that the collection has been created server-side
// by running the create command. The create command will appear in command monitoring channels.
func (t *T) CreateCollection(coll Collection, createOnServer bool) *mongo.Collection {
db := t.extractDatabase(coll)
if coll.DB == "" {
coll.DB = t.DB.Name()
}
if coll.Client == nil {
coll.Client = t.Client
}
coll.hasDifferentClient = coll.Client != t.Client

db := coll.Client.Database(coll.DB)

if createOnServer && t.clientType != Mock {
cmd := bson.D{{"create", coll.Name}}
cmd = append(cmd, coll.CreateOpts...)
Expand All @@ -425,15 +420,15 @@ func (t *T) CreateCollection(coll Collection, createOnServer bool) *mongo.Collec
}
}

created := db.Collection(coll.Name, coll.Opts)
t.createdColls = append(t.createdColls, created)
return created
coll.created = db.Collection(coll.Name, coll.Opts)
t.createdColls = append(t.createdColls, &coll)
return coll.created
}

// ClearCollections drops all collections previously created by this test.
func (t *T) ClearCollections() {
for _, coll := range t.createdColls {
_ = coll.Drop(Background)
_ = coll.created.Drop(Background)
}
t.createdColls = t.createdColls[:0]
}
Expand Down
61 changes: 24 additions & 37 deletions mongo/integration/unified_spec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import (
"path"
"reflect"
"testing"
"time"

"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/bson/bsoncodec"
Expand Down Expand Up @@ -184,9 +183,6 @@ func runSpecTestFile(t *testing.T, specDir, fileName string) {
}

func runSpecTestCase(mt *mtest.T, test *testCase, testFile testFile) {
testClientOpts := createClientOptions(mt, test.ClientOptions)
testClientOpts.SetHeartbeatInterval(50 * time.Millisecond)

opts := mtest.NewOptions().DatabaseName(testFile.DatabaseName).CollectionName(testFile.CollectionName)
if mt.TopologyKind() == mtest.Sharded && !test.UseMultipleMongoses {
// pin to a single mongos
Expand All @@ -200,12 +196,8 @@ func runSpecTestCase(mt *mtest.T, test *testCase, testFile testFile) {
{"validator", validator},
})
}
if test.Description != cseMaxVersionTest {
// don't specify client options for the maxWireVersion CSE test because the client cannot
// be created successfully. Should be fixed by SPEC-1403.
opts.ClientOptions(testClientOpts)
}

// Start the test without setting client options so the setup will be done with a default client.
mt.RunOpts(test.Description, opts, func(mt *mtest.T) {
if len(test.SkipReason) > 0 {
mt.Skip(test.SkipReason)
Expand All @@ -219,37 +211,39 @@ func runSpecTestCase(mt *mtest.T, test *testCase, testFile testFile) {

// work around for SERVER-39704: run a non-transactional distinct against each shard in a sharded cluster
if mt.TopologyKind() == mtest.Sharded && test.Description == "distinct" {
opts := options.Client().ApplyURI(mt.ConnString())
for _, host := range opts.Hosts {
shardClient, err := mongo.Connect(mtest.Background, opts.SetHosts([]string{host}))
assert.Nil(mt, err, "Connect error for shard %v: %v", host, err)
coll := shardClient.Database(mt.DB.Name()).Collection(mt.Coll.Name())
_, err = coll.Distinct(mtest.Background, "x", bson.D{})
assert.Nil(mt, err, "Distinct error for shard %v: %v", host, err)
_ = shardClient.Disconnect(mtest.Background)
}
err := runCommandOnAllServers(mt, func(mongosClient *mongo.Client) error {
coll := mongosClient.Database(mt.DB.Name()).Collection(mt.Coll.Name())
_, err := coll.Distinct(mtest.Background, "x", bson.D{})
return err
})
assert.Nil(mt, err, "error running distinct against all mongoses: %v", err)
}

// defer killSessions to ensure it runs regardless of the state of the test because the client has already
// Defer killSessions to ensure it runs regardless of the state of the test because the client has already
// been created and the collection drop in mongotest will hang for transactions to be aborted (60 seconds)
// in error cases.
defer killSessions(mt)

// Test setup: create collections that are tracked by mtest, insert test data, and set the failpoint.
setupTest(mt, &testFile, test)
if test.FailPoint != nil {
mt.SetFailPoint(*test.FailPoint)
}

// create the GridFS bucket after resetting the client so it will be created with a connected client
createBucket(mt, testFile, test)
// Reset the client using the client options specified in the test.
testClientOpts := createClientOptions(mt, test.ClientOptions)
mt.ResetClient(testClientOpts)

// create sessions, fail points, and collection
// Create the GridFS bucket and sessions after resetting the client so it will be created with a connected
// client.
createBucket(mt, testFile, test)
sess0, sess1 := setupSessions(mt, test)
if sess0 != nil {
defer func() {
sess0.EndSession(mtest.Background)
sess1.EndSession(mtest.Background)
}()
}
if test.FailPoint != nil {
mt.SetFailPoint(*test.FailPoint)
}

// run operations
mt.ClearEvents()
Expand Down Expand Up @@ -762,24 +756,19 @@ func insertDocuments(mt *mtest.T, coll *mongo.Collection, rawDocs []bson.Raw) {
func setupTest(mt *mtest.T, testFile *testFile, testCase *testCase) {
mt.Helper()

// all setup should be done with the global client instead of the test client to prevent any errors created by
// client configurations.
setupClient := mt.GlobalClient()
// key vault data
if len(testFile.KeyVaultData) > 0 {
keyVaultColl := mt.CreateCollection(mtest.Collection{
Name: "datakeys",
DB: "keyvault",
Client: setupClient,
Name: "datakeys",
DB: "keyvault",
}, false)

insertDocuments(mt, keyVaultColl, testFile.KeyVaultData)
}

// regular documents
if testFile.Data.Documents != nil {
insertColl := setupClient.Database(mt.DB.Name()).Collection(mt.Coll.Name())
insertDocuments(mt, insertColl, testFile.Data.Documents)
insertDocuments(mt, mt.Coll, testFile.Data.Documents)
return
}

Expand All @@ -788,15 +777,13 @@ func setupTest(mt *mtest.T, testFile *testFile, testCase *testCase) {

if gfsData.Chunks != nil {
chunks := mt.CreateCollection(mtest.Collection{
Name: gridFSChunks,
Client: setupClient,
Name: gridFSChunks,
}, false)
insertDocuments(mt, chunks, gfsData.Chunks)
}
if gfsData.Files != nil {
files := mt.CreateCollection(mtest.Collection{
Name: gridFSFiles,
Client: setupClient,
Name: gridFSFiles,
}, false)
insertDocuments(mt, files, gfsData.Files)

Expand Down