Skip to content

Commit d648c5b

Browse files
authored
Make backend key data optional (#296)
1 parent 2825829 commit d648c5b

File tree

7 files changed

+62
-32
lines changed

7 files changed

+62
-32
lines changed

Sources/PostgresNIO/Connection/PostgresConnection.swift

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,12 @@ public final class PostgresConnection {
7878
/// - Default: 5432
7979
public var port: Int
8080

81+
/// Require connection to provide `BackendKeyData`.
82+
/// For use with Amazon RDS Proxy, this must be set to false.
83+
///
84+
/// - Default: true
85+
public var requireBackendKeyData: Bool = true
86+
8187
/// Specifies a timeout to apply to a connection attempt.
8288
///
8389
/// - Default: 10 seconds
@@ -401,7 +407,8 @@ extension PostgresConnection {
401407
connection: .resolved(address: socketAddress, serverName: serverHostname),
402408
connectTimeout: .seconds(10),
403409
authentication: nil,
404-
tls: tls
410+
tls: tls,
411+
requireBackendKeyData: true
405412
)
406413

407414
return PostgresConnection.connect(
@@ -764,6 +771,8 @@ extension PostgresConnection {
764771
var authentication: Configuration.Authentication?
765772

766773
var tls: Configuration.TLS
774+
775+
var requireBackendKeyData: Bool
767776
}
768777
}
769778

@@ -773,6 +782,7 @@ extension PostgresConnection.InternalConfiguration {
773782
self.connection = .unresolved(host: config.connection.host, port: config.connection.port)
774783
self.connectTimeout = config.connection.connectTimeout
775784
self.tls = config.tls
785+
self.requireBackendKeyData = config.connection.requireBackendKeyData
776786
}
777787
}
778788

Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,7 @@ struct ConnectionStateMachine {
55
typealias TransactionState = PostgresBackendMessage.TransactionState
66

77
struct ConnectionContext {
8-
let processID: Int32
9-
let secretKey: Int32
10-
8+
let backendKeyData: Optional<BackendKeyData>
119
var parameters: [String: String]
1210
var transactionState: TransactionState
1311
}
@@ -113,17 +111,20 @@ struct ConnectionStateMachine {
113111
}
114112

115113
private var state: State
114+
private let requireBackendKeyData: Bool
116115
private var taskQueue = CircularBuffer<PSQLTask>()
117116
private var quiescingState: QuiescingState = .notQuiescing
118117

119-
init() {
118+
init(requireBackendKeyData: Bool) {
120119
self.state = .initialized
120+
self.requireBackendKeyData = requireBackendKeyData
121121
}
122122

123123
#if DEBUG
124124
/// for testing purposes only
125-
init(_ state: State) {
125+
init(_ state: State, requireBackendKeyData: Bool = true) {
126126
self.state = state
127+
self.requireBackendKeyData = requireBackendKeyData
127128
}
128129
#endif
129130

@@ -543,14 +544,12 @@ struct ConnectionStateMachine {
543544
mutating func readyForQueryReceived(_ transactionState: PostgresBackendMessage.TransactionState) -> ConnectionAction {
544545
switch self.state {
545546
case .authenticated(let backendKeyData, let parameters):
546-
guard let keyData = backendKeyData else {
547-
// `backendKeyData` must have been received, before receiving the first `readyForQuery`
547+
if self.requireBackendKeyData && backendKeyData == nil {
548548
return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.readyForQuery(transactionState)))
549549
}
550550

551551
let connectionContext = ConnectionContext(
552-
processID: keyData.processID,
553-
secretKey: keyData.secretKey,
552+
backendKeyData: backendKeyData,
554553
parameters: parameters,
555554
transactionState: transactionState)
556555

@@ -1314,8 +1313,8 @@ extension ConnectionStateMachine.State: CustomDebugStringConvertible {
13141313
extension ConnectionStateMachine.ConnectionContext: CustomDebugStringConvertible {
13151314
var debugDescription: String {
13161315
"""
1317-
(processID: \(self.processID), \
1318-
secretKey: \(self.secretKey), \
1316+
(processID: \(self.backendKeyData?.processID != nil ? String(self.backendKeyData!.processID) : "nil")), \
1317+
secretKey: \(self.backendKeyData?.secretKey != nil ? String(self.backendKeyData!.secretKey) : "nil")), \
13191318
parameters: \(String(reflecting: self.parameters)))
13201319
"""
13211320
}

Sources/PostgresNIO/New/PostgresChannelHandler.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ final class PostgresChannelHandler: ChannelDuplexHandler {
3232
logger: Logger,
3333
configureSSLCallback: ((Channel) throws -> Void)?)
3434
{
35-
self.state = ConnectionStateMachine()
35+
self.state = ConnectionStateMachine(requireBackendKeyData: configuration.requireBackendKeyData)
3636
self.configuration = configuration
3737
self.configureSSLCallback = configureSSLCallback
3838
self.logger = logger

Tests/PostgresNIOTests/New/Connection State Machine/AuthenticationStateMachineTests.swift

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ class AuthenticationStateMachineTests: XCTestCase {
77
func testAuthenticatePlaintext() {
88
let authContext = AuthContext(username: "test", password: "abc123", database: "test")
99

10-
var state = ConnectionStateMachine()
10+
var state = ConnectionStateMachine(requireBackendKeyData: true)
1111
XCTAssertEqual(state.connected(tls: .disable), .provideAuthenticationContext)
1212

1313
XCTAssertEqual(state.provideAuthenticationContext(authContext), .sendStartupMessage(authContext))
@@ -17,7 +17,7 @@ class AuthenticationStateMachineTests: XCTestCase {
1717

1818
func testAuthenticateMD5() {
1919
let authContext = AuthContext(username: "test", password: "abc123", database: "test")
20-
var state = ConnectionStateMachine()
20+
var state = ConnectionStateMachine(requireBackendKeyData: true)
2121
XCTAssertEqual(state.connected(tls: .disable), .provideAuthenticationContext)
2222
let salt: (UInt8, UInt8, UInt8, UInt8) = (0, 1, 2, 3)
2323

@@ -28,7 +28,7 @@ class AuthenticationStateMachineTests: XCTestCase {
2828

2929
func testAuthenticateMD5WithoutPassword() {
3030
let authContext = AuthContext(username: "test", password: nil, database: "test")
31-
var state = ConnectionStateMachine()
31+
var state = ConnectionStateMachine(requireBackendKeyData: true)
3232
XCTAssertEqual(state.connected(tls: .disable), .provideAuthenticationContext)
3333
let salt: (UInt8, UInt8, UInt8, UInt8) = (0, 1, 2, 3)
3434

@@ -39,15 +39,15 @@ class AuthenticationStateMachineTests: XCTestCase {
3939

4040
func testAuthenticateOkAfterStartUpWithoutAuthChallenge() {
4141
let authContext = AuthContext(username: "test", password: "abc123", database: "test")
42-
var state = ConnectionStateMachine()
42+
var state = ConnectionStateMachine(requireBackendKeyData: true)
4343
XCTAssertEqual(state.connected(tls: .disable), .provideAuthenticationContext)
4444
XCTAssertEqual(state.provideAuthenticationContext(authContext), .sendStartupMessage(authContext))
4545
XCTAssertEqual(state.authenticationMessageReceived(.ok), .wait)
4646
}
4747

4848
func testAuthenticationFailure() {
4949
let authContext = AuthContext(username: "test", password: "abc123", database: "test")
50-
var state = ConnectionStateMachine()
50+
var state = ConnectionStateMachine(requireBackendKeyData: true)
5151
XCTAssertEqual(state.connected(tls: .disable), .provideAuthenticationContext)
5252
let salt: (UInt8, UInt8, UInt8, UInt8) = (0, 1, 2, 3)
5353

@@ -79,7 +79,7 @@ class AuthenticationStateMachineTests: XCTestCase {
7979

8080
for (message, mechanism) in unsupported {
8181
let authContext = AuthContext(username: "test", password: "abc123", database: "test")
82-
var state = ConnectionStateMachine()
82+
var state = ConnectionStateMachine(requireBackendKeyData: true)
8383
XCTAssertEqual(state.connected(tls: .disable), .provideAuthenticationContext)
8484
XCTAssertEqual(state.provideAuthenticationContext(authContext), .sendStartupMessage(authContext))
8585
XCTAssertEqual(state.authenticationMessageReceived(message),
@@ -98,7 +98,7 @@ class AuthenticationStateMachineTests: XCTestCase {
9898

9999
for message in unexpected {
100100
let authContext = AuthContext(username: "test", password: "abc123", database: "test")
101-
var state = ConnectionStateMachine()
101+
var state = ConnectionStateMachine(requireBackendKeyData: true)
102102
XCTAssertEqual(state.connected(tls: .disable), .provideAuthenticationContext)
103103
XCTAssertEqual(state.provideAuthenticationContext(authContext), .sendStartupMessage(authContext))
104104
XCTAssertEqual(state.authenticationMessageReceived(message),
@@ -125,7 +125,7 @@ class AuthenticationStateMachineTests: XCTestCase {
125125

126126
for message in unexpected {
127127
let authContext = AuthContext(username: "test", password: "abc123", database: "test")
128-
var state = ConnectionStateMachine()
128+
var state = ConnectionStateMachine(requireBackendKeyData: true)
129129
XCTAssertEqual(state.connected(tls: .disable), .provideAuthenticationContext)
130130
XCTAssertEqual(state.provideAuthenticationContext(authContext), .sendStartupMessage(authContext))
131131
XCTAssertEqual(state.authenticationMessageReceived(.md5(salt: salt)), .sendPasswordMessage(.md5(salt: salt), authContext))

Tests/PostgresNIOTests/New/Connection State Machine/ConnectionStateMachineTests.swift

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ class ConnectionStateMachineTests: XCTestCase {
88

99
func testStartup() {
1010
let authContext = AuthContext(username: "test", password: "abc123", database: "test")
11-
var state = ConnectionStateMachine()
11+
var state = ConnectionStateMachine(requireBackendKeyData: true)
1212
XCTAssertEqual(state.connected(tls: .disable), .provideAuthenticationContext)
1313
XCTAssertEqual(state.provideAuthenticationContext(authContext), .sendStartupMessage(authContext))
1414
XCTAssertEqual(state.authenticationMessageReceived(.plaintext), .sendPasswordMessage(.cleartext, authContext))
@@ -17,7 +17,7 @@ class ConnectionStateMachineTests: XCTestCase {
1717

1818
func testSSLStartupSuccess() {
1919
let authContext = AuthContext(username: "test", password: "abc123", database: "test")
20-
var state = ConnectionStateMachine()
20+
var state = ConnectionStateMachine(requireBackendKeyData: true)
2121
XCTAssertEqual(state.connected(tls: .require), .sendSSLRequest)
2222
XCTAssertEqual(state.sslSupportedReceived(), .establishSSLConnection)
2323
XCTAssertEqual(state.sslHandlerAdded(), .wait)
@@ -30,23 +30,23 @@ class ConnectionStateMachineTests: XCTestCase {
3030
func testSSLStartupFailHandler() {
3131
struct SSLHandlerAddError: Error, Equatable {}
3232

33-
var state = ConnectionStateMachine()
33+
var state = ConnectionStateMachine(requireBackendKeyData: true)
3434
XCTAssertEqual(state.connected(tls: .require), .sendSSLRequest)
3535
XCTAssertEqual(state.sslSupportedReceived(), .establishSSLConnection)
3636
let failError = PSQLError.failedToAddSSLHandler(underlying: SSLHandlerAddError())
3737
XCTAssertEqual(state.errorHappened(failError), .closeConnectionAndCleanup(.init(action: .close, tasks: [], error: failError, closePromise: nil)))
3838
}
3939

4040
func testTLSRequiredStartupSSLUnsupported() {
41-
var state = ConnectionStateMachine()
41+
var state = ConnectionStateMachine(requireBackendKeyData: true)
4242

4343
XCTAssertEqual(state.connected(tls: .require), .sendSSLRequest)
4444
XCTAssertEqual(state.sslUnsupportedReceived(),
4545
.closeConnectionAndCleanup(.init(action: .close, tasks: [], error: PSQLError.sslUnsupported, closePromise: nil)))
4646
}
4747

4848
func testTLSPreferredStartupSSLUnsupported() {
49-
var state = ConnectionStateMachine()
49+
var state = ConnectionStateMachine(requireBackendKeyData: true)
5050

5151
XCTAssertEqual(state.connected(tls: .prefer), .sendSSLRequest)
5252
XCTAssertEqual(state.sslUnsupportedReceived(), .provideAuthenticationContext)
@@ -92,7 +92,7 @@ class ConnectionStateMachineTests: XCTestCase {
9292
}
9393

9494
func testReadyForQueryReceivedWithoutBackendKeyAfterAuthenticated() {
95-
var state = ConnectionStateMachine(.authenticated(nil, [:]))
95+
var state = ConnectionStateMachine(.authenticated(nil, [:]), requireBackendKeyData: true)
9696

9797
XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "DateStyle", value: "ISO, MDY")), .wait)
9898
XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "application_name", value: "")), .wait)
@@ -110,6 +110,24 @@ class ConnectionStateMachineTests: XCTestCase {
110110
.closeConnectionAndCleanup(.init(action: .close, tasks: [], error: PSQLError.unexpectedBackendMessage(.readyForQuery(.idle)), closePromise: nil)))
111111
}
112112

113+
func testReadyForQueryReceivedWithoutUnneededBackendKeyAfterAuthenticated() {
114+
var state = ConnectionStateMachine(.authenticated(nil, [:]), requireBackendKeyData: false)
115+
116+
XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "DateStyle", value: "ISO, MDY")), .wait)
117+
XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "application_name", value: "")), .wait)
118+
XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "server_encoding", value: "UTF8")), .wait)
119+
XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "integer_datetimes", value: "on")), .wait)
120+
XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "client_encoding", value: "UTF8")), .wait)
121+
XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "TimeZone", value: "Etc/UTC")), .wait)
122+
XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "is_superuser", value: "on")), .wait)
123+
XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "server_version", value: "13.1 (Debian 13.1-1.pgdg100+1)")), .wait)
124+
XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "session_authorization", value: "postgres")), .wait)
125+
XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "IntervalStyle", value: "postgres")), .wait)
126+
XCTAssertEqual(state.parameterStatusReceived(.init(parameter: "standard_conforming_strings", value: "on")), .wait)
127+
128+
XCTAssertEqual(state.readyForQueryReceived(.idle), .fireEventReadyForQuery)
129+
}
130+
113131
func testErrorIsIgnoredWhenClosingConnection() {
114132
// test ignore unclean shutdown when closing connection
115133
var stateIgnoreChannelError = ConnectionStateMachine(.closing)
@@ -133,7 +151,7 @@ class ConnectionStateMachineTests: XCTestCase {
133151

134152
let queryPromise = eventLoopGroup.next().makePromise(of: PSQLRowStream.self)
135153

136-
var state = ConnectionStateMachine()
154+
var state = ConnectionStateMachine(requireBackendKeyData: true)
137155
let extendedQueryContext = ExtendedQueryContext(
138156
query: "Select version()",
139157
logger: .psqlTest,

Tests/PostgresNIOTests/New/Extensions/ConnectionAction+TestUtils.swift

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,8 @@ extension ConnectionStateMachine {
7575
}
7676

7777
static func createConnectionContext(transactionState: PostgresBackendMessage.TransactionState = .idle) -> ConnectionContext {
78+
let backendKeyData = BackendKeyData(processID: 2730, secretKey: 882037977)
79+
7880
let paramaters = [
7981
"DateStyle": "ISO, MDY",
8082
"application_name": "",
@@ -90,8 +92,7 @@ extension ConnectionStateMachine {
9092
]
9193

9294
return ConnectionContext(
93-
processID: 2730,
94-
secretKey: 882037977,
95+
backendKeyData: backendKeyData,
9596
parameters: paramaters,
9697
transactionState: transactionState
9798
)

Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,8 @@ class PostgresChannelHandlerTests: XCTestCase {
174174
database: String = "postgres",
175175
password: String = "password",
176176
tls: PostgresConnection.Configuration.TLS = .disable,
177-
connectTimeout: TimeAmount = .seconds(10)
177+
connectTimeout: TimeAmount = .seconds(10),
178+
requireBackendKeyData: Bool = true
178179
) -> PostgresConnection.InternalConfiguration {
179180
let authentication = PostgresConnection.Configuration.Authentication(
180181
username: username,
@@ -186,7 +187,8 @@ class PostgresChannelHandlerTests: XCTestCase {
186187
connection: .unresolved(host: host, port: port),
187188
connectTimeout: connectTimeout,
188189
authentication: authentication,
189-
tls: tls
190+
tls: tls,
191+
requireBackendKeyData: requireBackendKeyData
190192
)
191193
}
192194
}

0 commit comments

Comments
 (0)