Skip to content

GODRIVER-2181 Use dedicated state constants for each topology type. #870

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion x/mongo/driver/topology/CMAP_spec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ func runCMAPTest(t *testing.T, testFileName string) {
return &event.PoolMonitor{func(event *event.PoolEvent) { testInfo.originalEventChan <- event }}
}))
testHelpers.RequireNil(t, err, "error creating server: %v", err)
s.connectionstate = connected
s.state = serverConnected
testHelpers.RequireNil(t, err, "error connecting connection pool: %v", err)
defer s.pool.close(context.Background())

Expand Down
25 changes: 16 additions & 9 deletions x/mongo/driver/topology/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,13 @@ import (
"go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage"
)

// Connection state constants.
const (
connDisconnected int64 = iota
connConnected
connInitialized
)

var globalConnectionID uint64 = 1

var (
Expand All @@ -38,10 +45,10 @@ var (
func nextConnectionID() uint64 { return atomic.AddUint64(&globalConnectionID, 1) }

type connection struct {
// connected must be accessed using the atomic package and should be at the beginning of the struct.
// state must be accessed using the atomic package and should be at the beginning of the struct.
// - atomic bug: https://pkg.go.dev/sync/atomic#pkg-note-BUG
// - suggested layout: https://go101.org/article/memory-layout.html
connected int64
state int64

id string
nc net.Conn // When nil, the connection is closed.
Expand Down Expand Up @@ -93,7 +100,7 @@ func newConnection(addr address.Address, opts ...ConnectionOption) *connection {
if !c.config.loadBalanced {
c.setGenerationNumber()
}
atomic.StoreInt64(&c.connected, initialized)
atomic.StoreInt64(&c.state, connInitialized)

return c
}
Expand Down Expand Up @@ -123,7 +130,7 @@ func (c *connection) hasGenerationNumber() bool {
// handshakes. All errors returned by connect are considered "before the handshake completes" and
// must be handled by calling the appropriate SDAM handshake error handler.
func (c *connection) connect(ctx context.Context) (err error) {
if !atomic.CompareAndSwapInt64(&c.connected, initialized, connected) {
if !atomic.CompareAndSwapInt64(&c.state, connInitialized, connConnected) {
return nil
}

Expand All @@ -133,7 +140,7 @@ func (c *connection) connect(ctx context.Context) (err error) {
// underlying net.Conn if it was created.
defer func() {
if err != nil {
atomic.StoreInt64(&c.connected, disconnected)
atomic.StoreInt64(&c.state, connDisconnected)

if c.nc != nil {
_ = c.nc.Close()
Expand Down Expand Up @@ -323,7 +330,7 @@ func (c *connection) cancellationListenerCallback() {

func (c *connection) writeWireMessage(ctx context.Context, wm []byte) error {
var err error
if atomic.LoadInt64(&c.connected) != connected {
if atomic.LoadInt64(&c.state) != connConnected {
return ConnectionError{ConnectionID: c.id, message: "connection is closed"}
}
select {
Expand Down Expand Up @@ -380,7 +387,7 @@ func (c *connection) write(ctx context.Context, wm []byte) (err error) {

// readWireMessage reads a wiremessage from the connection. The dst parameter will be overwritten.
func (c *connection) readWireMessage(ctx context.Context, dst []byte) ([]byte, error) {
if atomic.LoadInt64(&c.connected) != connected {
if atomic.LoadInt64(&c.state) != connConnected {
return dst, ConnectionError{ConnectionID: c.id, message: "connection is closed"}
}

Expand Down Expand Up @@ -483,7 +490,7 @@ func (c *connection) read(ctx context.Context, dst []byte) (bytesRead []byte, er

func (c *connection) close() error {
// Overwrite the connection state as the first step so only the first close call will execute.
if !atomic.CompareAndSwapInt64(&c.connected, connected, disconnected) {
if !atomic.CompareAndSwapInt64(&c.state, connConnected, connDisconnected) {
return nil
}

Expand All @@ -496,7 +503,7 @@ func (c *connection) close() error {
}

func (c *connection) closed() bool {
return atomic.LoadInt64(&c.connected) == disconnected
return atomic.LoadInt64(&c.state) == connDisconnected
}

func (c *connection) idleTimeoutExpired() bool {
Expand Down
4 changes: 2 additions & 2 deletions x/mongo/driver/topology/connection_errors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,14 @@ func TestConnectionErrors(t *testing.T) {
t.Run("write error", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel()
conn := &connection{id: "foobar", nc: &net.TCPConn{}, connected: connected}
conn := &connection{id: "foobar", nc: &net.TCPConn{}, state: connConnected}
err := conn.writeWireMessage(ctx, []byte{})
assert.True(t, errors.Is(err, context.Canceled), "expected error %v, got %v", context.Canceled, err)
})
t.Run("read error", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel()
conn := &connection{id: "foobar", nc: &net.TCPConn{}, connected: connected}
conn := &connection{id: "foobar", nc: &net.TCPConn{}, state: connConnected}
_, err := conn.readWireMessage(ctx, []byte{})
assert.True(t, errors.Is(err, context.Canceled), "expected error %v, got %v", context.Canceled, err)
})
Expand Down
62 changes: 31 additions & 31 deletions x/mongo/driver/topology/connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,8 @@ func TestConnection(t *testing.T) {
if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) {
t.Errorf("errors do not match. got %v; want %v", got, want)
}
connState := atomic.LoadInt64(&conn.connected)
assert.Equal(t, disconnected, connState, "expected connection state %v, got %v", disconnected, connState)
connState := atomic.LoadInt64(&conn.state)
assert.Equal(t, connDisconnected, connState, "expected connection state %v, got %v", connDisconnected, connState)
})
t.Run("handshaker error", func(t *testing.T) {
err := errors.New("handshaker error")
Expand All @@ -92,8 +92,8 @@ func TestConnection(t *testing.T) {
if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) {
t.Errorf("errors do not match. got %v; want %v", got, want)
}
connState := atomic.LoadInt64(&conn.connected)
assert.Equal(t, disconnected, connState, "expected connection state %v, got %v", disconnected, connState)
connState := atomic.LoadInt64(&conn.state)
assert.Equal(t, connDisconnected, connState, "expected connection state %v, got %v", connDisconnected, connState)
})
t.Run("context is not pinned by connect", func(t *testing.T) {
// connect creates a cancel-able version of the context passed to it and stores the CancelFunc on the
Expand Down Expand Up @@ -345,7 +345,7 @@ func TestConnection(t *testing.T) {
t.Run("completed context", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel()
conn := &connection{id: "foobar", nc: &net.TCPConn{}, connected: connected}
conn := &connection{id: "foobar", nc: &net.TCPConn{}, state: connConnected}
want := ConnectionError{ConnectionID: "foobar", Wrapped: ctx.Err(), message: "failed to write"}
got := conn.writeWireMessage(ctx, []byte{})
if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) {
Expand Down Expand Up @@ -380,7 +380,7 @@ func TestConnection(t *testing.T) {
message: "failed to set write deadline",
}
tnc := &testNetConn{deadlineerr: errors.New("set writeDeadline error")}
conn := &connection{id: "foobar", nc: tnc, writeTimeout: tc.timeout, connected: connected}
conn := &connection{id: "foobar", nc: tnc, writeTimeout: tc.timeout, state: connConnected}
got := conn.writeWireMessage(ctx, []byte{})
if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) {
t.Errorf("errors do not match. got %v; want %v", got, want)
Expand All @@ -397,7 +397,7 @@ func TestConnection(t *testing.T) {
t.Run("error", func(t *testing.T) {
err := errors.New("Write error")
tnc := &testNetConn{writeerr: err}
conn := &connection{id: "foobar", nc: tnc, connected: connected}
conn := &connection{id: "foobar", nc: tnc, state: connConnected}
listener := newTestCancellationListener(false)
conn.cancellationListener = listener

Expand All @@ -413,7 +413,7 @@ func TestConnection(t *testing.T) {
})
t.Run("success", func(t *testing.T) {
tnc := &testNetConn{}
conn := &connection{id: "foobar", nc: tnc, connected: connected}
conn := &connection{id: "foobar", nc: tnc, state: connConnected}
listener := newTestCancellationListener(false)
conn.cancellationListener = listener

Expand All @@ -430,7 +430,7 @@ func TestConnection(t *testing.T) {
// Simulate context cancellation during a network write.

nc := newCancellationWriteConn(&testNetConn{}, 0)
conn := &connection{id: "foobar", nc: nc, connected: connected}
conn := &connection{id: "foobar", nc: nc, state: connConnected}
listener := newTestCancellationListener(false)
conn.cancellationListener = listener

Expand All @@ -451,24 +451,24 @@ func TestConnection(t *testing.T) {
wg.Wait()
want := ConnectionError{ConnectionID: conn.id, Wrapped: context.Canceled, message: writeErrMsg}
assert.Equal(t, want, err, "expected error %v, got %v", want, err)
assert.Equal(t, disconnected, conn.connected, "expected connection state %v, got %v", disconnected,
conn.connected)
assert.Equal(t, connDisconnected, conn.state, "expected connection state %v, got %v", connDisconnected,
conn.state)
})
t.Run("connection is closed if context is cancelled even if network write succeeds", func(t *testing.T) {
// Test the race condition between Write and the cancellation listener. The socket write will
// succeed, but we set the abortedForCancellation flag to true to simulate the context being
// cancelled immediately after the Write finishes.

tnc := &testNetConn{}
conn := &connection{id: "foobar", nc: tnc, connected: connected}
conn := &connection{id: "foobar", nc: tnc, state: connConnected}
listener := newTestCancellationListener(true)
conn.cancellationListener = listener

want := ConnectionError{ConnectionID: conn.id, Wrapped: context.Canceled, message: writeErrMsg}
err := conn.writeWireMessage(context.Background(), []byte("foobar"))
assert.Equal(t, want, err, "expected error %v, got %v", want, err)
assert.Equal(t, conn.connected, disconnected, "expected connection state %v, got %v", disconnected,
conn.connected)
assert.Equal(t, conn.state, connDisconnected, "expected connection state %v, got %v", connDisconnected,
conn.state)
})
})
})
Expand All @@ -484,7 +484,7 @@ func TestConnection(t *testing.T) {
t.Run("completed context", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel()
conn := &connection{id: "foobar", nc: &net.TCPConn{}, connected: connected}
conn := &connection{id: "foobar", nc: &net.TCPConn{}, state: connConnected}
want := ConnectionError{ConnectionID: "foobar", Wrapped: ctx.Err(), message: "failed to read"}
_, got := conn.readWireMessage(ctx, []byte{})
if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) {
Expand Down Expand Up @@ -519,7 +519,7 @@ func TestConnection(t *testing.T) {
message: "failed to set read deadline",
}
tnc := &testNetConn{deadlineerr: errors.New("set readDeadline error")}
conn := &connection{id: "foobar", nc: tnc, readTimeout: tc.timeout, connected: connected}
conn := &connection{id: "foobar", nc: tnc, readTimeout: tc.timeout, state: connConnected}
_, got := conn.readWireMessage(ctx, []byte{})
if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) {
t.Errorf("errors do not match. got %v; want %v", got, want)
Expand All @@ -534,7 +534,7 @@ func TestConnection(t *testing.T) {
t.Run("size read errors", func(t *testing.T) {
err := errors.New("Read error")
tnc := &testNetConn{readerr: err}
conn := &connection{id: "foobar", nc: tnc, connected: connected}
conn := &connection{id: "foobar", nc: tnc, state: connConnected}
listener := newTestCancellationListener(false)
conn.cancellationListener = listener

Expand All @@ -551,7 +551,7 @@ func TestConnection(t *testing.T) {
t.Run("full message read errors", func(t *testing.T) {
err := errors.New("Read error")
tnc := &testNetConn{readerr: err, buf: []byte{0x11, 0x00, 0x00, 0x00}}
conn := &connection{id: "foobar", nc: tnc, connected: connected}
conn := &connection{id: "foobar", nc: tnc, state: connConnected}
listener := newTestCancellationListener(false)
conn.cancellationListener = listener

Expand Down Expand Up @@ -587,7 +587,7 @@ func TestConnection(t *testing.T) {
err := errors.New("length of read message too large")
tnc := &testNetConn{buf: make([]byte, len(tc.buffer))}
copy(tnc.buf, tc.buffer)
conn := &connection{id: "foobar", nc: tnc, connected: connected, desc: tc.desc}
conn := &connection{id: "foobar", nc: tnc, state: connConnected, desc: tc.desc}
listener := newTestCancellationListener(false)
conn.cancellationListener = listener

Expand All @@ -604,7 +604,7 @@ func TestConnection(t *testing.T) {
want := []byte{0x0A, 0x00, 0x00, 0x00, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A}
tnc := &testNetConn{buf: make([]byte, len(want))}
copy(tnc.buf, want)
conn := &connection{id: "foobar", nc: tnc, connected: connected}
conn := &connection{id: "foobar", nc: tnc, state: connConnected}
listener := newTestCancellationListener(false)
conn.cancellationListener = listener

Expand Down Expand Up @@ -634,7 +634,7 @@ func TestConnection(t *testing.T) {
readBuf := []byte{10, 0, 0, 0}
nc := newCancellationReadConn(&testNetConn{}, tc.skip, readBuf)

conn := &connection{id: "foobar", nc: nc, connected: connected}
conn := &connection{id: "foobar", nc: nc, state: connConnected}
listener := newTestCancellationListener(false)
conn.cancellationListener = listener

Expand All @@ -655,22 +655,22 @@ func TestConnection(t *testing.T) {
wg.Wait()
want := ConnectionError{ConnectionID: conn.id, Wrapped: context.Canceled, message: tc.errmsg}
assert.Equal(t, want, err, "expected error %v, got %v", want, err)
assert.Equal(t, disconnected, conn.connected, "expected connection state %v, got %v", disconnected,
conn.connected)
assert.Equal(t, connDisconnected, conn.state, "expected connection state %v, got %v", connDisconnected,
conn.state)
})
}
})
t.Run("closes connection if context is cancelled even if the socket read succeeds", func(t *testing.T) {
tnc := &testNetConn{buf: []byte{0x0A, 0x00, 0x00, 0x00, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A}}
conn := &connection{id: "foobar", nc: tnc, connected: connected}
conn := &connection{id: "foobar", nc: tnc, state: connConnected}
listener := newTestCancellationListener(true)
conn.cancellationListener = listener

want := ConnectionError{ConnectionID: conn.id, Wrapped: context.Canceled, message: "unable to read server response"}
_, err := conn.readWireMessage(context.Background(), nil)
assert.Equal(t, want, err, "expected error %v, got %v", want, err)
assert.Equal(t, disconnected, conn.connected, "expected connection state %v, got %v", disconnected,
conn.connected)
assert.Equal(t, connDisconnected, conn.state, "expected connection state %v, got %v", connDisconnected,
conn.state)
})
})
})
Expand All @@ -693,8 +693,8 @@ func TestConnection(t *testing.T) {

err := conn.connect(context.Background())
assert.NotNil(t, err, "expected handshake error from connect, got nil")
connState := atomic.LoadInt64(&conn.connected)
assert.Equal(t, disconnected, connState, "expected connection state %v, got %v", disconnected, connState)
connState := atomic.LoadInt64(&conn.state)
assert.Equal(t, connDisconnected, connState, "expected connection state %v, got %v", connDisconnected, connState)

err = conn.close()
assert.Nil(t, err, "close error: %v", err)
Expand All @@ -703,11 +703,11 @@ func TestConnection(t *testing.T) {
t.Run("cancellation listener callback", func(t *testing.T) {
t.Run("closes connection", func(t *testing.T) {
tnc := &testNetConn{}
conn := &connection{connected: connected, nc: tnc}
conn := &connection{state: connConnected, nc: tnc}

conn.cancellationListenerCallback()
assert.True(t, conn.connected == disconnected, "expected connection state %v, got %v", disconnected,
conn.connected)
assert.True(t, conn.state == connDisconnected, "expected connection state %v, got %v", connDisconnected,
conn.state)
assert.True(t, tnc.closed, "expected net.Conn to be closed but was not")
})
})
Expand Down
2 changes: 1 addition & 1 deletion x/mongo/driver/topology/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -505,7 +505,7 @@ func (p *pool) closeConnection(conn *connection) error {
return ErrWrongPool
}

if atomic.LoadInt64(&conn.connected) == connected {
if atomic.LoadInt64(&conn.state) == connConnected {
conn.closeConnectContext()
conn.wait() // Make sure that the connection has finished connecting.
}
Expand Down
12 changes: 9 additions & 3 deletions x/mongo/driver/topology/pool_generation_counter.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,12 @@ import (
"go.mongodb.org/mongo-driver/bson/primitive"
)

// Pool generation state constants.
const (
generationDisconnected int64 = iota
generationConnected
)

// generationStats represents the version of a pool. It tracks the generation number as well as the number of
// connections that have been created in the generation.
type generationStats struct {
Expand Down Expand Up @@ -42,11 +48,11 @@ func newPoolGenerationMap() *poolGenerationMap {
}

func (p *poolGenerationMap) connect() {
atomic.StoreInt64(&p.state, connected)
atomic.StoreInt64(&p.state, generationConnected)
}

func (p *poolGenerationMap) disconnect() {
atomic.StoreInt64(&p.state, disconnected)
atomic.StoreInt64(&p.state, generationDisconnected)
}

// addConnection increments the connection count for the generation associated with the given service ID and returns the
Expand Down Expand Up @@ -102,7 +108,7 @@ func (p *poolGenerationMap) clear(serviceIDPtr *primitive.ObjectID) {

func (p *poolGenerationMap) stale(serviceIDPtr *primitive.ObjectID, knownGeneration uint64) bool {
// If the map has been disconnected, all connections should be considered stale to ensure that they're closed.
if atomic.LoadInt64(&p.state) == disconnected {
if atomic.LoadInt64(&p.state) == generationDisconnected {
return true
}

Expand Down
2 changes: 1 addition & 1 deletion x/mongo/driver/topology/pool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ func TestPool(t *testing.T) {
time.Sleep(time.Millisecond)
}
for _, c := range conns {
assert.Equalf(t, connected, c.connected, "expected conn to still be connected")
assert.Equalf(t, connConnected, c.state, "expected conn to still be connected")

err := p.checkIn(c)
noerr(t, err)
Expand Down
Loading