|
| 1 | +import NIOCore |
| 2 | +import NIOPosix // inet_pton() et al. |
| 3 | +import NIOSSL |
| 4 | + |
| 5 | +extension PostgresConnection { |
| 6 | + /// A configuration object for a connection |
| 7 | + public struct Configuration { |
| 8 | + |
| 9 | + // MARK: - TLS |
| 10 | + |
| 11 | + /// The possible modes of operation for TLS encapsulation of a connection. |
| 12 | + public struct TLS { |
| 13 | + // MARK: Initializers |
| 14 | + |
| 15 | + /// Do not try to create a TLS connection to the server. |
| 16 | + public static var disable: Self = .init(base: .disable) |
| 17 | + |
| 18 | + /// Try to create a TLS connection to the server. If the server supports TLS, create a TLS connection. |
| 19 | + /// If the server does not support TLS, create an insecure connection. |
| 20 | + public static func prefer(_ sslContext: NIOSSLContext) -> Self { |
| 21 | + self.init(base: .prefer(sslContext)) |
| 22 | + } |
| 23 | + |
| 24 | + /// Try to create a TLS connection to the server. If the server supports TLS, create a TLS connection. |
| 25 | + /// If the server does not support TLS, fail the connection creation. |
| 26 | + public static func require(_ sslContext: NIOSSLContext) -> Self { |
| 27 | + self.init(base: .require(sslContext)) |
| 28 | + } |
| 29 | + |
| 30 | + // MARK: Accessors |
| 31 | + |
| 32 | + /// Whether TLS will be attempted on the connection (`false` only when mode is ``disable``). |
| 33 | + public var isAllowed: Bool { |
| 34 | + if case .disable = self.base { return false } |
| 35 | + else { return true } |
| 36 | + } |
| 37 | + |
| 38 | + /// Whether TLS will be enforced on the connection (`true` only when mode is ``require(_:)``). |
| 39 | + public var isEnforced: Bool { |
| 40 | + if case .require(_) = self.base { return true } |
| 41 | + else { return false } |
| 42 | + } |
| 43 | + |
| 44 | + /// The `NIOSSLContext` that will be used. `nil` when TLS is disabled. |
| 45 | + public var sslContext: NIOSSLContext? { |
| 46 | + switch self.base { |
| 47 | + case .prefer(let context), .require(let context): return context |
| 48 | + case .disable: return nil |
| 49 | + } |
| 50 | + } |
| 51 | + |
| 52 | + // MARK: Implementation details |
| 53 | + |
| 54 | + enum Base { |
| 55 | + case disable |
| 56 | + case prefer(NIOSSLContext) |
| 57 | + case require(NIOSSLContext) |
| 58 | + } |
| 59 | + let base: Base |
| 60 | + private init(base: Base) { self.base = base } |
| 61 | + } |
| 62 | + |
| 63 | + // MARK: - Connection options |
| 64 | + |
| 65 | + /// Describes options affecting how the underlying connection is made. |
| 66 | + public struct Options { |
| 67 | + /// A timeout for connection attempts. Defaults to ten seconds. |
| 68 | + /// |
| 69 | + /// Ignored when using a preexisting communcation channel. (See |
| 70 | + /// ``PostgresConnection/Configuration/init(establishedChannel:username:password:database:)``.) |
| 71 | + public var connectTimeout: TimeAmount |
| 72 | + |
| 73 | + /// The server name to use for certificate validation and SNI (Server Name Indication) when TLS is enabled. |
| 74 | + /// Defaults to none (but see below). |
| 75 | + /// |
| 76 | + /// > When set to `nil`: |
| 77 | + /// If the connection is made to a server over TCP using |
| 78 | + /// ``PostgresConnection/Configuration/init(host:port:username:password:database:tls:)``, the given `host` |
| 79 | + /// is used, unless it was an IP address string. If it _was_ an IP, or the connection is made by any other |
| 80 | + /// method, SNI is disabled. |
| 81 | + public var tlsServerName: String? |
| 82 | + |
| 83 | + /// Whether the connection is required to provide backend key data (internal Postgres stuff). |
| 84 | + /// |
| 85 | + /// This property is provided for compatibility with Amazon RDS Proxy, which requires it to be `false`. |
| 86 | + /// If you are not using Amazon RDS Proxy, you should leave this set to `true` (the default). |
| 87 | + public var requireBackendKeyData: Bool |
| 88 | + |
| 89 | + /// Create an options structure with default values. |
| 90 | + /// |
| 91 | + /// Most users should not need to adjust the defaults. |
| 92 | + public init() { |
| 93 | + self.connectTimeout = .seconds(10) |
| 94 | + self.tlsServerName = nil |
| 95 | + self.requireBackendKeyData = true |
| 96 | + } |
| 97 | + } |
| 98 | + |
| 99 | + // MARK: - Accessors |
| 100 | + |
| 101 | + /// The hostname to connect to for TCP configurations. |
| 102 | + /// |
| 103 | + /// Always `nil` for other configurations. |
| 104 | + public var host: String? { |
| 105 | + if case let .connectTCP(host, _) = self.endpointInfo { return host } |
| 106 | + else { return nil } |
| 107 | + } |
| 108 | + |
| 109 | + /// The port to connect to for TCP configurations. |
| 110 | + /// |
| 111 | + /// Always `nil` for other configurations. |
| 112 | + public var port: Int? { |
| 113 | + if case let .connectTCP(_, port) = self.endpointInfo { return port } |
| 114 | + else { return nil } |
| 115 | + } |
| 116 | + |
| 117 | + /// The socket path to connect to for Unix domain socket connections. |
| 118 | + /// |
| 119 | + /// Always `nil` for other configurations. |
| 120 | + public var unixSocketPath: String? { |
| 121 | + if case let .bindUnixDomainSocket(path) = self.endpointInfo { return path } |
| 122 | + else { return nil } |
| 123 | + } |
| 124 | + |
| 125 | + /// The `Channel` to use in existing-channel configurations. |
| 126 | + /// |
| 127 | + /// Always `nil` for other configurations. |
| 128 | + public var establishedChannel: Channel? { |
| 129 | + if case let .configureChannel(channel) = self.endpointInfo { return channel } |
| 130 | + else { return nil } |
| 131 | + } |
| 132 | + |
| 133 | + /// The TLS mode to use for the connection. Valid for all configurations. |
| 134 | + /// |
| 135 | + /// See ``TLS-swift.struct``. |
| 136 | + public var tls: TLS |
| 137 | + |
| 138 | + /// Options for handling the communication channel. Most users don't need to change these. |
| 139 | + /// |
| 140 | + /// See ``Options-swift.struct``. |
| 141 | + public var options: Options = .init() |
| 142 | + |
| 143 | + /// The username to connect with. |
| 144 | + public var username: String |
| 145 | + |
| 146 | + /// The password, if any, for the user specified by ``username``. |
| 147 | + /// |
| 148 | + /// - Warning: `nil` means "no password provided", whereas `""` (the empty string) is a password of zero |
| 149 | + /// length; these are not the same thing. |
| 150 | + public var password: String? |
| 151 | + |
| 152 | + /// The name of the database to open. |
| 153 | + /// |
| 154 | + /// - Note: If set to `nil` or an empty string, the provided ``username`` is used. |
| 155 | + public var database: String? |
| 156 | + |
| 157 | + // MARK: - Initializers |
| 158 | + |
| 159 | + /// Create a configuration for connecting to a server with a hostname and optional port. |
| 160 | + /// |
| 161 | + /// This specifies a TCP connection. If you're unsure which kind of connection you want, you almost |
| 162 | + /// definitely want this one. |
| 163 | + /// |
| 164 | + /// - Parameters: |
| 165 | + /// - host: The hostname to connect to. |
| 166 | + /// - port: The TCP port to connect to (defaults to 5432). |
| 167 | + /// - tls: The TLS mode to use. |
| 168 | + public init(host: String, port: Int = 5432, username: String, password: String?, database: String?, tls: TLS) { |
| 169 | + self.init(endpointInfo: .connectTCP(host: host, port: port), tls: tls, username: username, password: password, database: database) |
| 170 | + } |
| 171 | + |
| 172 | + /// Create a configuration for connecting to a server through a UNIX domain socket. |
| 173 | + /// |
| 174 | + /// - Parameters: |
| 175 | + /// - path: The filesystem path of the socket to connect to. |
| 176 | + /// - tls: The TLS mode to use. Defaults to ``TLS-swift.struct/disable``. |
| 177 | + public init(unixSocketPath: String, username: String, password: String?, database: String?) { |
| 178 | + self.init(endpointInfo: .bindUnixDomainSocket(path: unixSocketPath), tls: .disable, username: username, password: password, database: database) |
| 179 | + } |
| 180 | + |
| 181 | + /// Create a configuration for establishing a connection to a Postgres server over a preestablished |
| 182 | + /// `NIOCore/Channel`. |
| 183 | + /// |
| 184 | + /// This is provided for calling code which wants to manage the underlying connection transport on its |
| 185 | + /// own, such as when tunneling a connection through SSH. |
| 186 | + /// |
| 187 | + /// - Parameters: |
| 188 | + /// - channel: The `NIOCore/Channel` to use. The channel must already be active and connected to an |
| 189 | + /// endpoint (i.e. `NIOCore/Channel/isActive` must be `true`). |
| 190 | + /// - tls: The TLS mode to use. Defaults to ``TLS-swift.struct/disable``. |
| 191 | + public init(establishedChannel channel: Channel, username: String, password: String?, database: String?) { |
| 192 | + self.init(endpointInfo: .configureChannel(channel), tls: .disable, username: username, password: password, database: database) |
| 193 | + } |
| 194 | + |
| 195 | + // MARK: - Implementation details |
| 196 | + |
| 197 | + enum EndpointInfo { |
| 198 | + case configureChannel(Channel) |
| 199 | + case bindUnixDomainSocket(path: String) |
| 200 | + case connectTCP(host: String, port: Int) |
| 201 | + } |
| 202 | + |
| 203 | + var endpointInfo: EndpointInfo |
| 204 | + |
| 205 | + init(endpointInfo: EndpointInfo, tls: TLS, username: String, password: String?, database: String?) { |
| 206 | + self.endpointInfo = endpointInfo |
| 207 | + self.tls = tls |
| 208 | + self.username = username |
| 209 | + self.password = password |
| 210 | + self.database = database |
| 211 | + } |
| 212 | + } |
| 213 | +} |
| 214 | + |
| 215 | +// MARK: - Internal config details |
| 216 | + |
| 217 | +extension PostgresConnection { |
| 218 | + /// A configuration object to bring the new ``PostgresConnection.Configuration`` together with |
| 219 | + /// the deprecated configuration. |
| 220 | + /// |
| 221 | + /// TODO: Drop with next major release |
| 222 | + struct InternalConfiguration { |
| 223 | + enum Connection { |
| 224 | + case unresolvedTCP(host: String, port: Int) |
| 225 | + case unresolvedUDS(path: String) |
| 226 | + case resolved(address: SocketAddress) |
| 227 | + case bootstrapped(channel: Channel) |
| 228 | + } |
| 229 | + |
| 230 | + let connection: InternalConfiguration.Connection |
| 231 | + let username: String? |
| 232 | + let password: String? |
| 233 | + let database: String? |
| 234 | + var tls: Configuration.TLS |
| 235 | + let options: Configuration.Options |
| 236 | + } |
| 237 | +} |
| 238 | + |
| 239 | +extension PostgresConnection.InternalConfiguration { |
| 240 | + init(_ config: PostgresConnection.Configuration) { |
| 241 | + switch config.endpointInfo { |
| 242 | + case .connectTCP(let host, let port): self.connection = .unresolvedTCP(host: host, port: port) |
| 243 | + case .bindUnixDomainSocket(let path): self.connection = .unresolvedUDS(path: path) |
| 244 | + case .configureChannel(let channel): self.connection = .bootstrapped(channel: channel) |
| 245 | + } |
| 246 | + self.username = config.username |
| 247 | + self.password = config.password |
| 248 | + self.database = config.database |
| 249 | + self.tls = config.tls |
| 250 | + self.options = config.options |
| 251 | + } |
| 252 | + |
| 253 | + var serverNameForTLS: String? { |
| 254 | + // If a name was explicitly configured, always use it. |
| 255 | + if let tlsServerName = self.options.tlsServerName { return tlsServerName } |
| 256 | + |
| 257 | + // Otherwise, if the connection is TCP and the hostname wasn't an IP (not valid in SNI), use that. |
| 258 | + if case .unresolvedTCP(let host, _) = self.connection, !host.isIPAddress() { return host } |
| 259 | + |
| 260 | + // Otherwise, disable SNI |
| 261 | + return nil |
| 262 | + } |
| 263 | +} |
| 264 | + |
| 265 | +// originally taken from NIOSSL |
| 266 | +private extension String { |
| 267 | + func isIPAddress() -> Bool { |
| 268 | + // We need some scratch space to let inet_pton write into. |
| 269 | + var ipv4Addr = in_addr(), ipv6Addr = in6_addr() // inet_pton() assumes the provided address buffer is non-NULL |
| 270 | + |
| 271 | + /// N.B.: ``String/withCString(_:)`` is much more efficient than directly passing `self`, especially twice. |
| 272 | + return self.withCString { ptr in |
| 273 | + inet_pton(AF_INET, ptr, &ipv4Addr) == 1 || inet_pton(AF_INET6, ptr, &ipv6Addr) == 1 |
| 274 | + } |
| 275 | + } |
| 276 | +} |
0 commit comments