Skip to content

Commit fc02e68

Browse files
author
Divjot Arora
committed
GODRIVER-1507 Correctly pass URI to topology (#320)
1 parent 38d59c3 commit fc02e68

File tree

8 files changed

+94
-13
lines changed

8 files changed

+94
-13
lines changed

mongo/client.go

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,11 @@ func (c *Client) configure(opts *options.ClientOptions) error {
325325

326326
// TODO(GODRIVER-814): Add tests for topology, server, and connection related options.
327327

328+
// Pass down URI so topology can determine whether or not SRV polling is required
329+
topologyOpts = append(topologyOpts, topology.WithURI(func(uri string) string {
330+
return opts.GetURI()
331+
}))
332+
328333
// AppName
329334
var appName string
330335
if opts.AppName != nil {
@@ -559,7 +564,9 @@ func (c *Client) configure(opts *options.ClientOptions) error {
559564

560565
// Deployment
561566
if opts.Deployment != nil {
562-
if len(serverOpts) > 2 || len(topologyOpts) > 1 {
567+
// topology options: WithSeedlist and WithURI
568+
// server options: WithClock and WithConnectionOptions
569+
if len(serverOpts) > 2 || len(topologyOpts) > 2 {
563570
return errors.New("cannot specify topology or server options with a deployment")
564571
}
565572
c.deployment = opts.Deployment

mongo/client_test.go

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,4 +239,22 @@ func TestClient(t *testing.T) {
239239
client := setupClient(options.Client().SetWriteConcern(wc))
240240
assert.Equal(t, wc, client.writeConcern, "mismatch; expected write concern %v, got %v", wc, client.writeConcern)
241241
})
242+
t.Run("GetURI", func(t *testing.T) {
243+
t.Run("ApplyURI not called", func(t *testing.T) {
244+
opts := options.Client().SetHosts([]string{"localhost:27017"})
245+
uri := opts.GetURI()
246+
assert.Equal(t, "", uri, "expected GetURI to return empty string, got %v", uri)
247+
})
248+
t.Run("ApplyURI called with empty string", func(t *testing.T) {
249+
opts := options.Client().ApplyURI("")
250+
uri := opts.GetURI()
251+
assert.Equal(t, "", uri, "expected GetURI to return empty string, got %v", uri)
252+
})
253+
t.Run("ApplyURI called with non-empty string", func(t *testing.T) {
254+
uri := "mongodb://localhost:27017/foobar"
255+
opts := options.Client().ApplyURI(uri)
256+
got := opts.GetURI()
257+
assert.Equal(t, uri, got, "expected GetURI to return %v, got %v", uri, got)
258+
})
259+
})
242260
}

mongo/options/clientoptions.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ type ClientOptions struct {
118118
AutoEncryptionOptions *AutoEncryptionOptions
119119

120120
err error
121+
uri string
121122

122123
// These options are for internal use only and should not be set. They are deprecated and are
123124
// not part of the stability guarantee. They may be removed in the future.
@@ -133,6 +134,12 @@ func Client() *ClientOptions {
133134
// Validate validates the client options. This method will return the first error found.
134135
func (c *ClientOptions) Validate() error { return c.err }
135136

137+
// GetURI returns the original URI used to configure the ClientOptions instance. If ApplyURI was not called during
138+
// construction, this returns "".
139+
func (c *ClientOptions) GetURI() string {
140+
return c.uri
141+
}
142+
136143
// ApplyURI parses the given URI and sets options accordingly. The URI can contain host names, IPv4/IPv6 literals, or
137144
// an SRV record that will be resolved when the Client is created. When using an SRV record, TLS support is
138145
// implictly enabled. Specify the "tls=false" URI option to override this.
@@ -152,6 +159,7 @@ func (c *ClientOptions) ApplyURI(uri string) *ClientOptions {
152159
return c
153160
}
154161

162+
c.uri = uri
155163
cs, err := connstring.Parse(uri)
156164
if err != nil {
157165
c.err = err

mongo/options/clientoptions_test.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -439,6 +439,7 @@ func TestClientOptions(t *testing.T) {
439439
for _, tc := range testCases {
440440
t.Run(tc.name, func(t *testing.T) {
441441
result := Client().ApplyURI(tc.uri)
442+
tc.result.uri = tc.uri // manually add URI to avoid writing it in each test
442443
if diff := cmp.Diff(
443444
tc.result, result,
444445
cmp.AllowUnexported(readconcern.ReadConcern{}, writeconcern.WriteConcern{}, readpref.ReadPref{}),

x/mongo/driver/topology/polling_srv_records_test.go

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,10 @@ func TestPollingSRVRecordsSpec(t *testing.T) {
131131
t.Run(tt.name, func(t *testing.T) {
132132
cs, err := connstring.Parse("mongodb+srv://test1.test.build.10gen.cc/?heartbeatFrequencyMS=100")
133133
require.NoError(t, err, "Problem parsing the uri: %v", err)
134-
topo, err := New(WithConnString(func(connstring.ConnString) connstring.ConnString { return cs }))
134+
topo, err := New(
135+
WithConnString(func(connstring.ConnString) connstring.ConnString { return cs }),
136+
WithURI(func(string) string { return cs.Original }),
137+
)
135138
require.NoError(t, err, "Could not create the topology: %v", err)
136139
mockRes := newMockResolver(tt.recordsToAdd, tt.recordsToRemove, tt.lookupFail, tt.lookupTimeout)
137140
topo.dnsResolver = &dns.Resolver{mockRes.LookupSRV, mockRes.LookupTXT}
@@ -167,7 +170,10 @@ func TestPollSRVRecords(t *testing.T) {
167170
t.Run("Not unknown or sharded topology", func(t *testing.T) {
168171
cs, err := connstring.Parse("mongodb+srv://test1.test.build.10gen.cc/?heartbeatFrequencyMS=100")
169172
require.NoError(t, err, "Problem parsing the uri: %v", err)
170-
topo, err := New(WithConnString(func(connstring.ConnString) connstring.ConnString { return cs }))
173+
topo, err := New(
174+
WithConnString(func(connstring.ConnString) connstring.ConnString { return cs }),
175+
WithURI(func(string) string { return cs.Original }),
176+
)
171177
require.NoError(t, err, "Could not create the topology: %v", err)
172178
mockRes := newMockResolver(nil, nil, false, false)
173179
topo.dnsResolver = &dns.Resolver{mockRes.LookupSRV, mockRes.LookupTXT}
@@ -209,7 +215,10 @@ func TestPollSRVRecords(t *testing.T) {
209215
t.Run("Failed Hostname Verification", func(t *testing.T) {
210216
cs, err := connstring.Parse("mongodb+srv://test1.test.build.10gen.cc/?heartbeatFrequencyMS=100")
211217
require.NoError(t, err, "Problem parsing the uri: %v", err)
212-
topo, err := New(WithConnString(func(connstring.ConnString) connstring.ConnString { return cs }))
218+
topo, err := New(
219+
WithConnString(func(connstring.ConnString) connstring.ConnString { return cs }),
220+
WithURI(func(string) string { return cs.Original }),
221+
)
213222
require.NoError(t, err, "Could not create the topology: %v", err)
214223
mockRes := newMockResolver([]*net.SRV{{"blah.bleh", 27019, 0, 0}, {"localhost.test.build.10gen.cc.", 27020, 0, 0}}, nil, false, false)
215224
topo.dnsResolver = &dns.Resolver{mockRes.LookupSRV, mockRes.LookupTXT}
@@ -240,7 +249,10 @@ func TestPollSRVRecords(t *testing.T) {
240249
t.Run("Return to polling time", func(t *testing.T) {
241250
cs, err := connstring.Parse("mongodb+srv://test1.test.build.10gen.cc/?heartbeatFrequencyMS=100")
242251
require.NoError(t, err, "Problem parsing the uri: %v", err)
243-
topo, err := New(WithConnString(func(connstring.ConnString) connstring.ConnString { return cs }))
252+
topo, err := New(
253+
WithConnString(func(connstring.ConnString) connstring.ConnString { return cs }),
254+
WithURI(func(string) string { return cs.Original }),
255+
)
244256
require.NoError(t, err, "Could not create the topology: %v", err)
245257
mockRes := newMockResolver(nil, nil, false, false)
246258
mockRes.fail = 1

x/mongo/driver/topology/topology.go

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ type Topology struct {
6464

6565
done chan struct{}
6666

67+
pollingRequired bool
6768
pollingDone chan struct{}
6869
pollingwg sync.WaitGroup
6970
rescanSRVInterval time.Duration
@@ -131,6 +132,10 @@ func New(opts ...Option) (*Topology, error) {
131132
t.fsm.Kind = description.Single
132133
}
133134

135+
if t.cfg.uri != "" {
136+
t.pollingRequired = strings.HasPrefix(t.cfg.uri, "mongodb+srv://")
137+
}
138+
134139
return t, nil
135140
}
136141

@@ -154,7 +159,7 @@ func (t *Topology) Connect() error {
154159
}
155160
t.serversLock.Unlock()
156161

157-
if srvPollingRequired(t.cfg.cs.Original) {
162+
if t.pollingRequired {
158163
go t.pollSRVRecords()
159164
t.pollingwg.Add(1)
160165
}
@@ -192,7 +197,7 @@ func (t *Topology) Disconnect(ctx context.Context) error {
192197
t.subscriptionsClosed = true
193198
t.subLock.Unlock()
194199

195-
if srvPollingRequired(t.cfg.cs.Original) {
200+
if t.pollingRequired {
196201
t.pollingDone <- struct{}{}
197202
t.pollingwg.Wait()
198203
}
@@ -203,10 +208,6 @@ func (t *Topology) Disconnect(ctx context.Context) error {
203208
return nil
204209
}
205210

206-
func srvPollingRequired(connstr string) bool {
207-
return strings.HasPrefix(connstr, "mongodb+srv://")
208-
}
209-
210211
// Description returns a description of the topology.
211212
func (t *Topology) Description() description.Topology {
212213
td, ok := t.desc.Load().(description.Topology)
@@ -498,7 +499,7 @@ func (t *Topology) pollSRVRecords() {
498499
}()
499500

500501
// remove the scheme
501-
uri := t.cfg.cs.Original[14:]
502+
uri := t.cfg.uri[14:]
502503
hosts := uri
503504
if idx := strings.IndexAny(uri, "/?@"); idx != -1 {
504505
hosts = uri[:idx]

x/mongo/driver/topology/topology_options.go

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@ type config struct {
3131
replicaSetName string
3232
seedList []string
3333
serverOpts []ServerOption
34-
cs connstring.ConnString
34+
cs connstring.ConnString // This must not be used for any logic in topology.Topology.
35+
uri string
3536
serverSelectionTimeout time.Duration
3637
}
3738

@@ -275,6 +276,14 @@ func WithServerSelectionTimeout(fn func(time.Duration) time.Duration) Option {
275276
}
276277
}
277278

279+
// WithURI specifies the URI that was used to create the topology.
280+
func WithURI(fn func(string) string) Option {
281+
return func(cfg *config) error {
282+
cfg.uri = fn(cfg.uri)
283+
return nil
284+
}
285+
}
286+
278287
// addCACertFromFile adds a root CA certificate to the configuration given a path
279288
// to the containing file.
280289
func addCACertFromFile(cfg *tls.Config, file string) error {

x/mongo/driver/topology/topology_test.go

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -514,3 +514,28 @@ func TestTopology_String_Race(t *testing.T) {
514514
<-ch
515515
<-ch
516516
}
517+
518+
func TestTopologyConstruction(t *testing.T) {
519+
t.Run("construct with URI", func(t *testing.T) {
520+
testCases := []struct {
521+
name string
522+
uri string
523+
pollingRequired bool
524+
}{
525+
{"normal", "mongodb://localhost:27017", false},
526+
{"srv", "mongodb+srv://localhost:27017", true},
527+
}
528+
for _, tc := range testCases {
529+
t.Run(tc.name, func(t *testing.T) {
530+
topo, err := New(
531+
WithURI(func(string) string { return tc.uri }),
532+
)
533+
assert.Nil(t, err, "topology.New error: %v", err)
534+
535+
assert.Equal(t, tc.uri, topo.cfg.uri, "expected topology URI to be %v, got %v", tc.uri, topo.cfg.uri)
536+
assert.Equal(t, tc.pollingRequired, topo.pollingRequired,
537+
"expected topo.pollingRequired to be %v, got %v", tc.pollingRequired, topo.pollingRequired)
538+
})
539+
}
540+
})
541+
}

0 commit comments

Comments
 (0)