Skip to content

Commit a10006a

Browse files
author
Divjot Arora
committed
GODRIVER-1636 Ensure SNI is always enabled (#418)
1 parent a2601d2 commit a10006a

File tree

2 files changed

+43
-32
lines changed

2 files changed

+43
-32
lines changed

mongo/testatlas/main.go

Lines changed: 35 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"context"
1111
"flag"
1212
"fmt"
13+
"time"
1314

1415
"go.mongodb.org/mongo-driver/bson"
1516
"go.mongodb.org/mongo-driver/mongo"
@@ -22,34 +23,45 @@ func main() {
2223
ctx := context.Background()
2324

2425
for idx, uri := range uris {
25-
client, err := mongo.Connect(ctx, options.Client().ApplyURI(uri))
26-
if err != nil {
27-
panic(createErrorMessage(idx, "Connect error: %v", err))
28-
}
26+
// Set a low server selection timeout so we fail fast if there are errors.
27+
clientOpts := options.Client().
28+
ApplyURI(uri).
29+
SetServerSelectionTimeout(1 * time.Second)
2930

30-
defer func() {
31-
if err = client.Disconnect(ctx); err != nil {
32-
panic(createErrorMessage(idx, "Disconnect error: %v", err))
33-
}
34-
}()
35-
36-
db := client.Database("test")
37-
err = db.RunCommand(
38-
ctx,
39-
bson.D{{"isMaster", 1}},
40-
).Err()
41-
if err != nil {
42-
panic(createErrorMessage(idx, "isMaster error: %v", err))
31+
// Run basic connectivity test.
32+
if err := runTest(ctx, clientOpts); err != nil {
33+
panic(fmt.Sprintf("error running test with TLS at index %d: %v", idx, err))
4334
}
4435

45-
coll := db.Collection("test")
46-
if err = coll.FindOne(ctx, bson.D{{"x", 1}}).Err(); err != nil && err != mongo.ErrNoDocuments {
47-
panic(createErrorMessage(idx, "FindOne error: %v", err))
36+
// Run the connectivity test with InsecureSkipVerify to ensure SNI is done correctly even if verification is
37+
// disabled.
38+
clientOpts.TLSConfig.InsecureSkipVerify = true
39+
if err := runTest(ctx, clientOpts); err != nil {
40+
panic(fmt.Sprintf("error running test with tlsInsecure at index %d: %v", idx, err))
4841
}
4942
}
5043
}
5144

52-
func createErrorMessage(idx int, msg string, args ...interface{}) string {
53-
msg = fmt.Sprintf(msg, args...)
54-
return fmt.Sprintf("error for URI at index %d: %s", idx, msg)
45+
func runTest(ctx context.Context, clientOpts *options.ClientOptions) error {
46+
client, err := mongo.Connect(ctx, clientOpts)
47+
if err != nil {
48+
return fmt.Errorf("Connect error: %v", err)
49+
}
50+
51+
defer func() {
52+
_ = client.Disconnect(ctx)
53+
}()
54+
55+
db := client.Database("test")
56+
cmd := bson.D{{"isMaster", 1}}
57+
err = db.RunCommand(ctx, cmd).Err()
58+
if err != nil {
59+
return fmt.Errorf("isMaster error: %v", err)
60+
}
61+
62+
coll := db.Collection("test")
63+
if err = coll.FindOne(ctx, bson.D{{"x", 1}}).Err(); err != nil && err != mongo.ErrNoDocuments {
64+
return fmt.Errorf("FindOne error: %v", err)
65+
}
66+
return nil
5567
}

x/mongo/driver/topology/connection.go

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -496,17 +496,16 @@ var notMasterCodes = []int32{10107, 13435}
496496
var recoveringCodes = []int32{11600, 11602, 13436, 189, 91}
497497

498498
func configureTLS(ctx context.Context, nc net.Conn, addr address.Address, config *tls.Config) (net.Conn, error) {
499-
if !config.InsecureSkipVerify {
500-
hostname := addr.String()
501-
colonPos := strings.LastIndex(hostname, ":")
502-
if colonPos == -1 {
503-
colonPos = len(hostname)
504-
}
505-
506-
hostname = hostname[:colonPos]
507-
config.ServerName = hostname
499+
// Ensure config.ServerName is always set for SNI.
500+
hostname := addr.String()
501+
colonPos := strings.LastIndex(hostname, ":")
502+
if colonPos == -1 {
503+
colonPos = len(hostname)
508504
}
509505

506+
hostname = hostname[:colonPos]
507+
config.ServerName = hostname
508+
510509
client := tls.Client(nc, config)
511510

512511
errChan := make(chan error, 1)

0 commit comments

Comments
 (0)