Skip to content

Commit ee35f14

Browse files
author
Divjot Arora
authored
GODRIVER-1710 Ensure TLS ServerName is not overriden if already set (#486)
1 parent f54e631 commit ee35f14

File tree

4 files changed

+110
-14
lines changed

4 files changed

+110
-14
lines changed

x/mongo/driver/topology/connection.go

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ func (c *connection) connect(ctx context.Context) {
152152
Cache: c.config.ocspCache,
153153
DisableEndpointChecking: c.config.disableOCSPEndpointCheck,
154154
}
155-
tlsNc, err := configureTLS(ctx, c.nc, c.addr, tlsConfig, ocspOpts)
155+
tlsNc, err := configureTLS(ctx, c.config.tlsConnectionSource, c.nc, c.addr, tlsConfig, ocspOpts)
156156
if err != nil {
157157
c.processInitializationError(err)
158158
return
@@ -657,19 +657,27 @@ func (c *Connection) LocalAddress() address.Address {
657657
var notMasterCodes = []int32{10107, 13435}
658658
var recoveringCodes = []int32{11600, 11602, 13436, 189, 91}
659659

660-
func configureTLS(ctx context.Context, nc net.Conn, addr address.Address, config *tls.Config, ocspOpts *ocsp.VerifyOptions) (net.Conn, error) {
661-
// Ensure config.ServerName is always set for SNI.
662-
hostname := addr.String()
663-
colonPos := strings.LastIndex(hostname, ":")
664-
if colonPos == -1 {
665-
colonPos = len(hostname)
666-
}
660+
func configureTLS(ctx context.Context,
661+
tlsConnSource tlsConnectionSource,
662+
nc net.Conn,
663+
addr address.Address,
664+
config *tls.Config,
665+
ocspOpts *ocsp.VerifyOptions,
666+
) (net.Conn, error) {
667667

668-
hostname = hostname[:colonPos]
669-
config.ServerName = hostname
668+
// Ensure config.ServerName is always set for SNI.
669+
if config.ServerName == "" {
670+
hostname := addr.String()
671+
colonPos := strings.LastIndex(hostname, ":")
672+
if colonPos == -1 {
673+
colonPos = len(hostname)
674+
}
670675

671-
client := tls.Client(nc, config)
676+
hostname = hostname[:colonPos]
677+
config.ServerName = hostname
678+
}
672679

680+
client := tlsConnSource.Client(nc, config)
673681
errChan := make(chan error, 1)
674682
go func() {
675683
errChan <- client.Handshake()

x/mongo/driver/topology/connection_options.go

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,13 +52,15 @@ type connectionConfig struct {
5252
ocspCache ocsp.Cache
5353
disableOCSPEndpointCheck bool
5454
errorHandlingCallback func(error, uint64)
55+
tlsConnectionSource tlsConnectionSource
5556
}
5657

5758
func newConnectionConfig(opts ...ConnectionOption) (*connectionConfig, error) {
5859
cfg := &connectionConfig{
59-
connectTimeout: 30 * time.Second,
60-
dialer: nil,
61-
lifeTimeout: 30 * time.Minute,
60+
connectTimeout: 30 * time.Second,
61+
dialer: nil,
62+
lifeTimeout: 30 * time.Minute,
63+
tlsConnectionSource: defaultTLSConnectionSource,
6264
}
6365

6466
for _, opt := range opts {
@@ -78,6 +80,13 @@ func newConnectionConfig(opts ...ConnectionOption) (*connectionConfig, error) {
7880
// ConnectionOption is used to configure a connection.
7981
type ConnectionOption func(*connectionConfig) error
8082

83+
func withTLSConnectionSource(fn func(tlsConnectionSource) tlsConnectionSource) ConnectionOption {
84+
return func(c *connectionConfig) error {
85+
c.tlsConnectionSource = fn(c.tlsConnectionSource)
86+
return nil
87+
}
88+
}
89+
8190
func withErrorHandlingCallback(fn func(error, uint64)) ConnectionOption {
8291
return func(c *connectionConfig) error {
8392
c.errorHandlingCallback = fn

x/mongo/driver/topology/connection_test.go

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ package topology
88

99
import (
1010
"context"
11+
"crypto/tls"
1112
"errors"
1213
"net"
1314
"sync"
@@ -176,6 +177,58 @@ func TestConnection(t *testing.T) {
176177
wg.Wait()
177178
})
178179
})
180+
t.Run("tls", func(t *testing.T) {
181+
t.Run("connection source is set to default if unspecified", func(t *testing.T) {
182+
conn, err := newConnection(address.Address(""))
183+
assert.Nil(t, err, "newConnection error: %v", err)
184+
assert.NotNil(t, conn.config.tlsConnectionSource, "expected tlsConnectionSource to be set but was not")
185+
})
186+
t.Run("server name", func(t *testing.T) {
187+
testCases := []struct {
188+
name string
189+
addr address.Address
190+
cfg *tls.Config
191+
expectedServerName string
192+
}{
193+
{"set to connection address if empty", "localhost:27017", &tls.Config{}, "localhost"},
194+
{"left alone if non-empty", "localhost:27017", &tls.Config{ServerName: "other"}, "other"},
195+
}
196+
for _, tc := range testCases {
197+
t.Run(tc.name, func(t *testing.T) {
198+
var sentCfg *tls.Config
199+
var testTLSConnectionSource tlsConnectionSourceFn = func(nc net.Conn, cfg *tls.Config) *tls.Conn {
200+
sentCfg = cfg
201+
return tls.Client(nc, cfg)
202+
}
203+
204+
connOpts := []ConnectionOption{
205+
WithDialer(func(Dialer) Dialer {
206+
return DialerFunc(func(context.Context, string, string) (net.Conn, error) {
207+
return &net.TCPConn{}, nil
208+
})
209+
}),
210+
WithHandshaker(func(Handshaker) Handshaker {
211+
return &testHandshaker{}
212+
}),
213+
WithTLSConfig(func(*tls.Config) *tls.Config {
214+
return tc.cfg
215+
}),
216+
withTLSConnectionSource(func(tlsConnectionSource) tlsConnectionSource {
217+
return testTLSConnectionSource
218+
}),
219+
}
220+
conn, err := newConnection(tc.addr, connOpts...)
221+
assert.Nil(t, err, "newConnection error: %v", err)
222+
223+
conn.connect(context.Background())
224+
err = conn.wait()
225+
assert.NotNil(t, sentCfg, "expected TLS config to be set, but was not")
226+
assert.Equal(t, tc.expectedServerName, sentCfg.ServerName, "expected ServerName %s, got %s",
227+
tc.expectedServerName, sentCfg.ServerName)
228+
})
229+
}
230+
})
231+
})
179232
})
180233
t.Run("writeWireMessage", func(t *testing.T) {
181234
t.Run("closed connection", func(t *testing.T) {
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
// Copyright (C) MongoDB, Inc. 2017-present.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License"); you may
4+
// not use this file except in compliance with the License. You may obtain
5+
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6+
7+
package topology
8+
9+
import (
10+
"crypto/tls"
11+
"net"
12+
)
13+
14+
type tlsConnectionSource interface {
15+
Client(net.Conn, *tls.Config) *tls.Conn
16+
}
17+
18+
type tlsConnectionSourceFn func(net.Conn, *tls.Config) *tls.Conn
19+
20+
func (t tlsConnectionSourceFn) Client(nc net.Conn, cfg *tls.Config) *tls.Conn {
21+
return t(nc, cfg)
22+
}
23+
24+
var defaultTLSConnectionSource tlsConnectionSourceFn = func(nc net.Conn, cfg *tls.Config) *tls.Conn {
25+
return tls.Client(nc, cfg)
26+
}

0 commit comments

Comments
 (0)