Skip to content

Commit 001e550

Browse files
authored
GODRIVER-2181 Use different state constants for each topology type. (#870)
1 parent 82b7dd1 commit 001e550

12 files changed

+121
-106
lines changed

x/mongo/driver/topology/CMAP_spec_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,7 @@ func runCMAPTest(t *testing.T, testFileName string) {
211211

212212
s, err := NewServer(address.Address(l.Addr().String()), primitive.NewObjectID(), sOpts...)
213213
testHelpers.RequireNil(t, err, "error creating server: %v", err)
214-
s.connectionstate = connected
214+
s.state = serverConnected
215215
testHelpers.RequireNil(t, err, "error connecting connection pool: %v", err)
216216
defer s.pool.close(context.Background())
217217

x/mongo/driver/topology/connection.go

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,13 @@ import (
2727
"go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage"
2828
)
2929

30+
// Connection state constants.
31+
const (
32+
connDisconnected int64 = iota
33+
connConnected
34+
connInitialized
35+
)
36+
3037
var globalConnectionID uint64 = 1
3138

3239
var (
@@ -38,10 +45,10 @@ var (
3845
func nextConnectionID() uint64 { return atomic.AddUint64(&globalConnectionID, 1) }
3946

4047
type connection struct {
41-
// connected must be accessed using the atomic package and should be at the beginning of the struct.
48+
// state must be accessed using the atomic package and should be at the beginning of the struct.
4249
// - atomic bug: https://pkg.go.dev/sync/atomic#pkg-note-BUG
4350
// - suggested layout: https://go101.org/article/memory-layout.html
44-
connected int64
51+
state int64
4552

4653
id string
4754
nc net.Conn // When nil, the connection is closed.
@@ -93,7 +100,7 @@ func newConnection(addr address.Address, opts ...ConnectionOption) *connection {
93100
if !c.config.loadBalanced {
94101
c.setGenerationNumber()
95102
}
96-
atomic.StoreInt64(&c.connected, initialized)
103+
atomic.StoreInt64(&c.state, connInitialized)
97104

98105
return c
99106
}
@@ -123,7 +130,7 @@ func (c *connection) hasGenerationNumber() bool {
123130
// handshakes. All errors returned by connect are considered "before the handshake completes" and
124131
// must be handled by calling the appropriate SDAM handshake error handler.
125132
func (c *connection) connect(ctx context.Context) (err error) {
126-
if !atomic.CompareAndSwapInt64(&c.connected, initialized, connected) {
133+
if !atomic.CompareAndSwapInt64(&c.state, connInitialized, connConnected) {
127134
return nil
128135
}
129136

@@ -133,7 +140,7 @@ func (c *connection) connect(ctx context.Context) (err error) {
133140
// underlying net.Conn if it was created.
134141
defer func() {
135142
if err != nil {
136-
atomic.StoreInt64(&c.connected, disconnected)
143+
atomic.StoreInt64(&c.state, connDisconnected)
137144

138145
if c.nc != nil {
139146
_ = c.nc.Close()
@@ -323,7 +330,7 @@ func (c *connection) cancellationListenerCallback() {
323330

324331
func (c *connection) writeWireMessage(ctx context.Context, wm []byte) error {
325332
var err error
326-
if atomic.LoadInt64(&c.connected) != connected {
333+
if atomic.LoadInt64(&c.state) != connConnected {
327334
return ConnectionError{ConnectionID: c.id, message: "connection is closed"}
328335
}
329336
select {
@@ -380,7 +387,7 @@ func (c *connection) write(ctx context.Context, wm []byte) (err error) {
380387

381388
// readWireMessage reads a wiremessage from the connection. The dst parameter will be overwritten.
382389
func (c *connection) readWireMessage(ctx context.Context, dst []byte) ([]byte, error) {
383-
if atomic.LoadInt64(&c.connected) != connected {
390+
if atomic.LoadInt64(&c.state) != connConnected {
384391
return dst, ConnectionError{ConnectionID: c.id, message: "connection is closed"}
385392
}
386393

@@ -483,7 +490,7 @@ func (c *connection) read(ctx context.Context, dst []byte) (bytesRead []byte, er
483490

484491
func (c *connection) close() error {
485492
// Overwrite the connection state as the first step so only the first close call will execute.
486-
if !atomic.CompareAndSwapInt64(&c.connected, connected, disconnected) {
493+
if !atomic.CompareAndSwapInt64(&c.state, connConnected, connDisconnected) {
487494
return nil
488495
}
489496

@@ -496,7 +503,7 @@ func (c *connection) close() error {
496503
}
497504

498505
func (c *connection) closed() bool {
499-
return atomic.LoadInt64(&c.connected) == disconnected
506+
return atomic.LoadInt64(&c.state) == connDisconnected
500507
}
501508

502509
func (c *connection) idleTimeoutExpired() bool {

x/mongo/driver/topology/connection_errors_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,14 +53,14 @@ func TestConnectionErrors(t *testing.T) {
5353
t.Run("write error", func(t *testing.T) {
5454
ctx, cancel := context.WithCancel(context.Background())
5555
cancel()
56-
conn := &connection{id: "foobar", nc: &net.TCPConn{}, connected: connected}
56+
conn := &connection{id: "foobar", nc: &net.TCPConn{}, state: connConnected}
5757
err := conn.writeWireMessage(ctx, []byte{})
5858
assert.True(t, errors.Is(err, context.Canceled), "expected error %v, got %v", context.Canceled, err)
5959
})
6060
t.Run("read error", func(t *testing.T) {
6161
ctx, cancel := context.WithCancel(context.Background())
6262
cancel()
63-
conn := &connection{id: "foobar", nc: &net.TCPConn{}, connected: connected}
63+
conn := &connection{id: "foobar", nc: &net.TCPConn{}, state: connConnected}
6464
_, err := conn.readWireMessage(ctx, []byte{})
6565
assert.True(t, errors.Is(err, context.Canceled), "expected error %v, got %v", context.Canceled, err)
6666
})

x/mongo/driver/topology/connection_test.go

Lines changed: 31 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,8 @@ func TestConnection(t *testing.T) {
6868
if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) {
6969
t.Errorf("errors do not match. got %v; want %v", got, want)
7070
}
71-
connState := atomic.LoadInt64(&conn.connected)
72-
assert.Equal(t, disconnected, connState, "expected connection state %v, got %v", disconnected, connState)
71+
connState := atomic.LoadInt64(&conn.state)
72+
assert.Equal(t, connDisconnected, connState, "expected connection state %v, got %v", connDisconnected, connState)
7373
})
7474
t.Run("handshaker error", func(t *testing.T) {
7575
err := errors.New("handshaker error")
@@ -92,8 +92,8 @@ func TestConnection(t *testing.T) {
9292
if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) {
9393
t.Errorf("errors do not match. got %v; want %v", got, want)
9494
}
95-
connState := atomic.LoadInt64(&conn.connected)
96-
assert.Equal(t, disconnected, connState, "expected connection state %v, got %v", disconnected, connState)
95+
connState := atomic.LoadInt64(&conn.state)
96+
assert.Equal(t, connDisconnected, connState, "expected connection state %v, got %v", connDisconnected, connState)
9797
})
9898
t.Run("context is not pinned by connect", func(t *testing.T) {
9999
// connect creates a cancel-able version of the context passed to it and stores the CancelFunc on the
@@ -345,7 +345,7 @@ func TestConnection(t *testing.T) {
345345
t.Run("completed context", func(t *testing.T) {
346346
ctx, cancel := context.WithCancel(context.Background())
347347
cancel()
348-
conn := &connection{id: "foobar", nc: &net.TCPConn{}, connected: connected}
348+
conn := &connection{id: "foobar", nc: &net.TCPConn{}, state: connConnected}
349349
want := ConnectionError{ConnectionID: "foobar", Wrapped: ctx.Err(), message: "failed to write"}
350350
got := conn.writeWireMessage(ctx, []byte{})
351351
if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) {
@@ -380,7 +380,7 @@ func TestConnection(t *testing.T) {
380380
message: "failed to set write deadline",
381381
}
382382
tnc := &testNetConn{deadlineerr: errors.New("set writeDeadline error")}
383-
conn := &connection{id: "foobar", nc: tnc, writeTimeout: tc.timeout, connected: connected}
383+
conn := &connection{id: "foobar", nc: tnc, writeTimeout: tc.timeout, state: connConnected}
384384
got := conn.writeWireMessage(ctx, []byte{})
385385
if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) {
386386
t.Errorf("errors do not match. got %v; want %v", got, want)
@@ -397,7 +397,7 @@ func TestConnection(t *testing.T) {
397397
t.Run("error", func(t *testing.T) {
398398
err := errors.New("Write error")
399399
tnc := &testNetConn{writeerr: err}
400-
conn := &connection{id: "foobar", nc: tnc, connected: connected}
400+
conn := &connection{id: "foobar", nc: tnc, state: connConnected}
401401
listener := newTestCancellationListener(false)
402402
conn.cancellationListener = listener
403403

@@ -413,7 +413,7 @@ func TestConnection(t *testing.T) {
413413
})
414414
t.Run("success", func(t *testing.T) {
415415
tnc := &testNetConn{}
416-
conn := &connection{id: "foobar", nc: tnc, connected: connected}
416+
conn := &connection{id: "foobar", nc: tnc, state: connConnected}
417417
listener := newTestCancellationListener(false)
418418
conn.cancellationListener = listener
419419

@@ -430,7 +430,7 @@ func TestConnection(t *testing.T) {
430430
// Simulate context cancellation during a network write.
431431

432432
nc := newCancellationWriteConn(&testNetConn{}, 0)
433-
conn := &connection{id: "foobar", nc: nc, connected: connected}
433+
conn := &connection{id: "foobar", nc: nc, state: connConnected}
434434
listener := newTestCancellationListener(false)
435435
conn.cancellationListener = listener
436436

@@ -451,24 +451,24 @@ func TestConnection(t *testing.T) {
451451
wg.Wait()
452452
want := ConnectionError{ConnectionID: conn.id, Wrapped: context.Canceled, message: writeErrMsg}
453453
assert.Equal(t, want, err, "expected error %v, got %v", want, err)
454-
assert.Equal(t, disconnected, conn.connected, "expected connection state %v, got %v", disconnected,
455-
conn.connected)
454+
assert.Equal(t, connDisconnected, conn.state, "expected connection state %v, got %v", connDisconnected,
455+
conn.state)
456456
})
457457
t.Run("connection is closed if context is cancelled even if network write succeeds", func(t *testing.T) {
458458
// Test the race condition between Write and the cancellation listener. The socket write will
459459
// succeed, but we set the abortedForCancellation flag to true to simulate the context being
460460
// cancelled immediately after the Write finishes.
461461

462462
tnc := &testNetConn{}
463-
conn := &connection{id: "foobar", nc: tnc, connected: connected}
463+
conn := &connection{id: "foobar", nc: tnc, state: connConnected}
464464
listener := newTestCancellationListener(true)
465465
conn.cancellationListener = listener
466466

467467
want := ConnectionError{ConnectionID: conn.id, Wrapped: context.Canceled, message: writeErrMsg}
468468
err := conn.writeWireMessage(context.Background(), []byte("foobar"))
469469
assert.Equal(t, want, err, "expected error %v, got %v", want, err)
470-
assert.Equal(t, conn.connected, disconnected, "expected connection state %v, got %v", disconnected,
471-
conn.connected)
470+
assert.Equal(t, conn.state, connDisconnected, "expected connection state %v, got %v", connDisconnected,
471+
conn.state)
472472
})
473473
})
474474
})
@@ -484,7 +484,7 @@ func TestConnection(t *testing.T) {
484484
t.Run("completed context", func(t *testing.T) {
485485
ctx, cancel := context.WithCancel(context.Background())
486486
cancel()
487-
conn := &connection{id: "foobar", nc: &net.TCPConn{}, connected: connected}
487+
conn := &connection{id: "foobar", nc: &net.TCPConn{}, state: connConnected}
488488
want := ConnectionError{ConnectionID: "foobar", Wrapped: ctx.Err(), message: "failed to read"}
489489
_, got := conn.readWireMessage(ctx, []byte{})
490490
if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) {
@@ -519,7 +519,7 @@ func TestConnection(t *testing.T) {
519519
message: "failed to set read deadline",
520520
}
521521
tnc := &testNetConn{deadlineerr: errors.New("set readDeadline error")}
522-
conn := &connection{id: "foobar", nc: tnc, readTimeout: tc.timeout, connected: connected}
522+
conn := &connection{id: "foobar", nc: tnc, readTimeout: tc.timeout, state: connConnected}
523523
_, got := conn.readWireMessage(ctx, []byte{})
524524
if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) {
525525
t.Errorf("errors do not match. got %v; want %v", got, want)
@@ -534,7 +534,7 @@ func TestConnection(t *testing.T) {
534534
t.Run("size read errors", func(t *testing.T) {
535535
err := errors.New("Read error")
536536
tnc := &testNetConn{readerr: err}
537-
conn := &connection{id: "foobar", nc: tnc, connected: connected}
537+
conn := &connection{id: "foobar", nc: tnc, state: connConnected}
538538
listener := newTestCancellationListener(false)
539539
conn.cancellationListener = listener
540540

@@ -551,7 +551,7 @@ func TestConnection(t *testing.T) {
551551
t.Run("full message read errors", func(t *testing.T) {
552552
err := errors.New("Read error")
553553
tnc := &testNetConn{readerr: err, buf: []byte{0x11, 0x00, 0x00, 0x00}}
554-
conn := &connection{id: "foobar", nc: tnc, connected: connected}
554+
conn := &connection{id: "foobar", nc: tnc, state: connConnected}
555555
listener := newTestCancellationListener(false)
556556
conn.cancellationListener = listener
557557

@@ -587,7 +587,7 @@ func TestConnection(t *testing.T) {
587587
err := errors.New("length of read message too large")
588588
tnc := &testNetConn{buf: make([]byte, len(tc.buffer))}
589589
copy(tnc.buf, tc.buffer)
590-
conn := &connection{id: "foobar", nc: tnc, connected: connected, desc: tc.desc}
590+
conn := &connection{id: "foobar", nc: tnc, state: connConnected, desc: tc.desc}
591591
listener := newTestCancellationListener(false)
592592
conn.cancellationListener = listener
593593

@@ -604,7 +604,7 @@ func TestConnection(t *testing.T) {
604604
want := []byte{0x0A, 0x00, 0x00, 0x00, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A}
605605
tnc := &testNetConn{buf: make([]byte, len(want))}
606606
copy(tnc.buf, want)
607-
conn := &connection{id: "foobar", nc: tnc, connected: connected}
607+
conn := &connection{id: "foobar", nc: tnc, state: connConnected}
608608
listener := newTestCancellationListener(false)
609609
conn.cancellationListener = listener
610610

@@ -634,7 +634,7 @@ func TestConnection(t *testing.T) {
634634
readBuf := []byte{10, 0, 0, 0}
635635
nc := newCancellationReadConn(&testNetConn{}, tc.skip, readBuf)
636636

637-
conn := &connection{id: "foobar", nc: nc, connected: connected}
637+
conn := &connection{id: "foobar", nc: nc, state: connConnected}
638638
listener := newTestCancellationListener(false)
639639
conn.cancellationListener = listener
640640

@@ -655,22 +655,22 @@ func TestConnection(t *testing.T) {
655655
wg.Wait()
656656
want := ConnectionError{ConnectionID: conn.id, Wrapped: context.Canceled, message: tc.errmsg}
657657
assert.Equal(t, want, err, "expected error %v, got %v", want, err)
658-
assert.Equal(t, disconnected, conn.connected, "expected connection state %v, got %v", disconnected,
659-
conn.connected)
658+
assert.Equal(t, connDisconnected, conn.state, "expected connection state %v, got %v", connDisconnected,
659+
conn.state)
660660
})
661661
}
662662
})
663663
t.Run("closes connection if context is cancelled even if the socket read succeeds", func(t *testing.T) {
664664
tnc := &testNetConn{buf: []byte{0x0A, 0x00, 0x00, 0x00, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A}}
665-
conn := &connection{id: "foobar", nc: tnc, connected: connected}
665+
conn := &connection{id: "foobar", nc: tnc, state: connConnected}
666666
listener := newTestCancellationListener(true)
667667
conn.cancellationListener = listener
668668

669669
want := ConnectionError{ConnectionID: conn.id, Wrapped: context.Canceled, message: "unable to read server response"}
670670
_, err := conn.readWireMessage(context.Background(), nil)
671671
assert.Equal(t, want, err, "expected error %v, got %v", want, err)
672-
assert.Equal(t, disconnected, conn.connected, "expected connection state %v, got %v", disconnected,
673-
conn.connected)
672+
assert.Equal(t, connDisconnected, conn.state, "expected connection state %v, got %v", connDisconnected,
673+
conn.state)
674674
})
675675
})
676676
})
@@ -693,8 +693,8 @@ func TestConnection(t *testing.T) {
693693

694694
err := conn.connect(context.Background())
695695
assert.NotNil(t, err, "expected handshake error from connect, got nil")
696-
connState := atomic.LoadInt64(&conn.connected)
697-
assert.Equal(t, disconnected, connState, "expected connection state %v, got %v", disconnected, connState)
696+
connState := atomic.LoadInt64(&conn.state)
697+
assert.Equal(t, connDisconnected, connState, "expected connection state %v, got %v", connDisconnected, connState)
698698

699699
err = conn.close()
700700
assert.Nil(t, err, "close error: %v", err)
@@ -703,11 +703,11 @@ func TestConnection(t *testing.T) {
703703
t.Run("cancellation listener callback", func(t *testing.T) {
704704
t.Run("closes connection", func(t *testing.T) {
705705
tnc := &testNetConn{}
706-
conn := &connection{connected: connected, nc: tnc}
706+
conn := &connection{state: connConnected, nc: tnc}
707707

708708
conn.cancellationListenerCallback()
709-
assert.True(t, conn.connected == disconnected, "expected connection state %v, got %v", disconnected,
710-
conn.connected)
709+
assert.True(t, conn.state == connDisconnected, "expected connection state %v, got %v", connDisconnected,
710+
conn.state)
711711
assert.True(t, tnc.closed, "expected net.Conn to be closed but was not")
712712
})
713713
})

x/mongo/driver/topology/pool.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -505,7 +505,7 @@ func (p *pool) closeConnection(conn *connection) error {
505505
return ErrWrongPool
506506
}
507507

508-
if atomic.LoadInt64(&conn.connected) == connected {
508+
if atomic.LoadInt64(&conn.state) == connConnected {
509509
conn.closeConnectContext()
510510
conn.wait() // Make sure that the connection has finished connecting.
511511
}

x/mongo/driver/topology/pool_generation_counter.go

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,12 @@ import (
1313
"go.mongodb.org/mongo-driver/bson/primitive"
1414
)
1515

16+
// Pool generation state constants.
17+
const (
18+
generationDisconnected int64 = iota
19+
generationConnected
20+
)
21+
1622
// generationStats represents the version of a pool. It tracks the generation number as well as the number of
1723
// connections that have been created in the generation.
1824
type generationStats struct {
@@ -42,11 +48,11 @@ func newPoolGenerationMap() *poolGenerationMap {
4248
}
4349

4450
func (p *poolGenerationMap) connect() {
45-
atomic.StoreInt64(&p.state, connected)
51+
atomic.StoreInt64(&p.state, generationConnected)
4652
}
4753

4854
func (p *poolGenerationMap) disconnect() {
49-
atomic.StoreInt64(&p.state, disconnected)
55+
atomic.StoreInt64(&p.state, generationDisconnected)
5056
}
5157

5258
// addConnection increments the connection count for the generation associated with the given service ID and returns the
@@ -102,7 +108,7 @@ func (p *poolGenerationMap) clear(serviceIDPtr *primitive.ObjectID) {
102108

103109
func (p *poolGenerationMap) stale(serviceIDPtr *primitive.ObjectID, knownGeneration uint64) bool {
104110
// If the map has been disconnected, all connections should be considered stale to ensure that they're closed.
105-
if atomic.LoadInt64(&p.state) == disconnected {
111+
if atomic.LoadInt64(&p.state) == generationDisconnected {
106112
return true
107113
}
108114

x/mongo/driver/topology/pool_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,7 @@ func TestPool(t *testing.T) {
232232
time.Sleep(time.Millisecond)
233233
}
234234
for _, c := range conns {
235-
assert.Equalf(t, connected, c.connected, "expected conn to still be connected")
235+
assert.Equalf(t, connConnected, c.state, "expected conn to still be connected")
236236

237237
err := p.checkIn(c)
238238
noerr(t, err)

0 commit comments

Comments
 (0)