Skip to content

Commit 8b57925

Browse files
author
Divjot Arora
committed
GODRIVER-1710 Ensure TLS ServerName is not overriden if already set (#486)
1 parent c596b8e commit 8b57925

File tree

4 files changed

+110
-14
lines changed

4 files changed

+110
-14
lines changed

x/mongo/driver/topology/connection.go

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ func (c *connection) connect(ctx context.Context) {
149149
Cache: c.config.ocspCache,
150150
DisableEndpointChecking: c.config.disableOCSPEndpointCheck,
151151
}
152-
tlsNc, err := configureTLS(ctx, c.nc, c.addr, tlsConfig, ocspOpts)
152+
tlsNc, err := configureTLS(ctx, c.config.tlsConnectionSource, c.nc, c.addr, tlsConfig, ocspOpts)
153153
if err != nil {
154154
c.processInitializationError(err)
155155
return
@@ -604,19 +604,27 @@ func (c *Connection) LocalAddress() address.Address {
604604
var notMasterCodes = []int32{10107, 13435}
605605
var recoveringCodes = []int32{11600, 11602, 13436, 189, 91}
606606

607-
func configureTLS(ctx context.Context, nc net.Conn, addr address.Address, config *tls.Config, ocspOpts *ocsp.VerifyOptions) (net.Conn, error) {
608-
// Ensure config.ServerName is always set for SNI.
609-
hostname := addr.String()
610-
colonPos := strings.LastIndex(hostname, ":")
611-
if colonPos == -1 {
612-
colonPos = len(hostname)
613-
}
607+
func configureTLS(ctx context.Context,
608+
tlsConnSource tlsConnectionSource,
609+
nc net.Conn,
610+
addr address.Address,
611+
config *tls.Config,
612+
ocspOpts *ocsp.VerifyOptions,
613+
) (net.Conn, error) {
614614

615-
hostname = hostname[:colonPos]
616-
config.ServerName = hostname
615+
// Ensure config.ServerName is always set for SNI.
616+
if config.ServerName == "" {
617+
hostname := addr.String()
618+
colonPos := strings.LastIndex(hostname, ":")
619+
if colonPos == -1 {
620+
colonPos = len(hostname)
621+
}
617622

618-
client := tls.Client(nc, config)
623+
hostname = hostname[:colonPos]
624+
config.ServerName = hostname
625+
}
619626

627+
client := tlsConnSource.Client(nc, config)
620628
errChan := make(chan error, 1)
621629
go func() {
622630
errChan <- client.Handshake()

x/mongo/driver/topology/connection_options.go

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,13 +52,15 @@ type connectionConfig struct {
5252
ocspCache ocsp.Cache
5353
disableOCSPEndpointCheck bool
5454
errorHandlingCallback func(error, uint64)
55+
tlsConnectionSource tlsConnectionSource
5556
}
5657

5758
func newConnectionConfig(opts ...ConnectionOption) (*connectionConfig, error) {
5859
cfg := &connectionConfig{
59-
connectTimeout: 30 * time.Second,
60-
dialer: nil,
61-
lifeTimeout: 30 * time.Minute,
60+
connectTimeout: 30 * time.Second,
61+
dialer: nil,
62+
lifeTimeout: 30 * time.Minute,
63+
tlsConnectionSource: defaultTLSConnectionSource,
6264
}
6365

6466
for _, opt := range opts {
@@ -78,6 +80,13 @@ func newConnectionConfig(opts ...ConnectionOption) (*connectionConfig, error) {
7880
// ConnectionOption is used to configure a connection.
7981
type ConnectionOption func(*connectionConfig) error
8082

83+
func withTLSConnectionSource(fn func(tlsConnectionSource) tlsConnectionSource) ConnectionOption {
84+
return func(c *connectionConfig) error {
85+
c.tlsConnectionSource = fn(c.tlsConnectionSource)
86+
return nil
87+
}
88+
}
89+
8190
func withErrorHandlingCallback(fn func(error, uint64)) ConnectionOption {
8291
return func(c *connectionConfig) error {
8392
c.errorHandlingCallback = fn

x/mongo/driver/topology/connection_test.go

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ package topology
88

99
import (
1010
"context"
11+
"crypto/tls"
1112
"errors"
1213
"net"
1314
"sync"
@@ -175,6 +176,58 @@ func TestConnection(t *testing.T) {
175176
wg.Wait()
176177
})
177178
})
179+
t.Run("tls", func(t *testing.T) {
180+
t.Run("connection source is set to default if unspecified", func(t *testing.T) {
181+
conn, err := newConnection(address.Address(""))
182+
assert.Nil(t, err, "newConnection error: %v", err)
183+
assert.NotNil(t, conn.config.tlsConnectionSource, "expected tlsConnectionSource to be set but was not")
184+
})
185+
t.Run("server name", func(t *testing.T) {
186+
testCases := []struct {
187+
name string
188+
addr address.Address
189+
cfg *tls.Config
190+
expectedServerName string
191+
}{
192+
{"set to connection address if empty", "localhost:27017", &tls.Config{}, "localhost"},
193+
{"left alone if non-empty", "localhost:27017", &tls.Config{ServerName: "other"}, "other"},
194+
}
195+
for _, tc := range testCases {
196+
t.Run(tc.name, func(t *testing.T) {
197+
var sentCfg *tls.Config
198+
var testTLSConnectionSource tlsConnectionSourceFn = func(nc net.Conn, cfg *tls.Config) *tls.Conn {
199+
sentCfg = cfg
200+
return tls.Client(nc, cfg)
201+
}
202+
203+
connOpts := []ConnectionOption{
204+
WithDialer(func(Dialer) Dialer {
205+
return DialerFunc(func(context.Context, string, string) (net.Conn, error) {
206+
return &net.TCPConn{}, nil
207+
})
208+
}),
209+
WithHandshaker(func(Handshaker) Handshaker {
210+
return &testHandshaker{}
211+
}),
212+
WithTLSConfig(func(*tls.Config) *tls.Config {
213+
return tc.cfg
214+
}),
215+
withTLSConnectionSource(func(tlsConnectionSource) tlsConnectionSource {
216+
return testTLSConnectionSource
217+
}),
218+
}
219+
conn, err := newConnection(tc.addr, connOpts...)
220+
assert.Nil(t, err, "newConnection error: %v", err)
221+
222+
conn.connect(context.Background())
223+
err = conn.wait()
224+
assert.NotNil(t, sentCfg, "expected TLS config to be set, but was not")
225+
assert.Equal(t, tc.expectedServerName, sentCfg.ServerName, "expected ServerName %s, got %s",
226+
tc.expectedServerName, sentCfg.ServerName)
227+
})
228+
}
229+
})
230+
})
178231
})
179232
t.Run("writeWireMessage", func(t *testing.T) {
180233
t.Run("closed connection", func(t *testing.T) {
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
// Copyright (C) MongoDB, Inc. 2017-present.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License"); you may
4+
// not use this file except in compliance with the License. You may obtain
5+
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6+
7+
package topology
8+
9+
import (
10+
"crypto/tls"
11+
"net"
12+
)
13+
14+
type tlsConnectionSource interface {
15+
Client(net.Conn, *tls.Config) *tls.Conn
16+
}
17+
18+
type tlsConnectionSourceFn func(net.Conn, *tls.Config) *tls.Conn
19+
20+
func (t tlsConnectionSourceFn) Client(nc net.Conn, cfg *tls.Config) *tls.Conn {
21+
return t(nc, cfg)
22+
}
23+
24+
var defaultTLSConnectionSource tlsConnectionSourceFn = func(nc net.Conn, cfg *tls.Config) *tls.Conn {
25+
return tls.Client(nc, cfg)
26+
}

0 commit comments

Comments
 (0)