Skip to content

Add SSLContext configuration per KMS provider #820

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Nov 9, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions driver-core/src/main/com/mongodb/AutoEncryptionSettings.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@
import com.mongodb.lang.Nullable;
import org.bson.BsonDocument;

import javax.net.ssl.SSLContext;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

import static com.mongodb.assertions.Assertions.notNull;
Expand Down Expand Up @@ -58,6 +60,7 @@ public final class AutoEncryptionSettings {
private final MongoClientSettings keyVaultMongoClientSettings;
private final String keyVaultNamespace;
private final Map<String, Map<String, Object>> kmsProviders;
private final Map<String, SSLContext> kmsProviderSslContextMap;
private final Map<String, BsonDocument> schemaMap;
private final Map<String, Object> extraOptions;
private final boolean bypassAutoEncryption;
Expand All @@ -71,6 +74,7 @@ public static final class Builder {
private MongoClientSettings keyVaultMongoClientSettings;
private String keyVaultNamespace;
private Map<String, Map<String, Object>> kmsProviders;
private Map<String, SSLContext> kmsProviderSslContextMap = new HashMap<>();
private Map<String, BsonDocument> schemaMap = Collections.emptyMap();
private Map<String, Object> extraOptions = Collections.emptyMap();
private boolean bypassAutoEncryption;
Expand Down Expand Up @@ -111,6 +115,19 @@ public Builder kmsProviders(final Map<String, Map<String, Object>> kmsProviders)
return this;
}

/**
* TODO
*
* @param kmsProviderSslContextMap TODO
* @return this
* @see #getKmsProviderSslContextMap()
* @since 4.4
*/
public Builder kmsProviderSslContextMap(final Map<String, SSLContext> kmsProviderSslContextMap) {
this.kmsProviderSslContextMap = notNull("kmsProviderSslContextMap", kmsProviderSslContextMap);
return this;
}

/**
* Sets the map from namespace to local schema document
*
Expand Down Expand Up @@ -253,6 +270,16 @@ public Map<String, Map<String, Object>> getKmsProviders() {
return kmsProviders;
}

/**
* TODO
*
* @return TODO
* @since 4.4
*/
public Map<String, SSLContext> getKmsProviderSslContextMap() {
return kmsProviderSslContextMap;
}

/**
* Gets the map of namespace to local JSON schema.
* <p>
Expand Down Expand Up @@ -321,6 +348,7 @@ private AutoEncryptionSettings(final Builder builder) {
this.keyVaultMongoClientSettings = builder.keyVaultMongoClientSettings;
this.keyVaultNamespace = notNull("keyVaultNamespace", builder.keyVaultNamespace);
this.kmsProviders = notNull("kmsProviders", builder.kmsProviders);
this.kmsProviderSslContextMap = notNull("kmsProviderSslContextMap", builder.kmsProviderSslContextMap);
this.schemaMap = notNull("schemaMap", builder.schemaMap);
this.extraOptions = notNull("extraOptions", builder.extraOptions);
this.bypassAutoEncryption = builder.bypassAutoEncryption;
Expand Down
29 changes: 28 additions & 1 deletion driver-core/src/main/com/mongodb/ClientEncryptionSettings.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

import com.mongodb.annotations.NotThreadSafe;

import javax.net.ssl.SSLContext;
import java.util.HashMap;
import java.util.Map;

import static com.mongodb.assertions.Assertions.notNull;
Expand All @@ -36,7 +38,7 @@ public final class ClientEncryptionSettings {
private final MongoClientSettings keyVaultMongoClientSettings;
private final String keyVaultNamespace;
private final Map<String, Map<String, Object>> kmsProviders;

private final Map<String, SSLContext> kmsProviderSslContextMap;
/**
* A builder for {@code ClientEncryptionSettings} so that {@code ClientEncryptionSettings} can be immutable, and to support easier
* construction through chaining.
Expand All @@ -46,6 +48,7 @@ public static final class Builder {
private MongoClientSettings keyVaultMongoClientSettings;
private String keyVaultNamespace;
private Map<String, Map<String, Object>> kmsProviders;
private Map<String, SSLContext> kmsProviderSslContextMap = new HashMap<>();

/**
* Sets the key vault settings.
Expand Down Expand Up @@ -83,6 +86,19 @@ public Builder kmsProviders(final Map<String, Map<String, Object>> kmsProviders)
return this;
}

/**
* TODO
*
* @param kmsProviderSslContextMap TODO
* @return this
* @see #getKmsProviderSslContextMap()
* @since 4.4
*/
public Builder kmsProviderSslContextMap(final Map<String, SSLContext> kmsProviderSslContextMap) {
this.kmsProviderSslContextMap = notNull("kmsProviderSslContextMap", kmsProviderSslContextMap);
return this;
}

/**
* Build an instance of {@code ClientEncryptionSettings}.
*
Expand Down Expand Up @@ -187,10 +203,21 @@ public Map<String, Map<String, Object>> getKmsProviders() {
return kmsProviders;
}

/**
* TODO
*
* @return TODO
* @since 4.4
*/
public Map<String, SSLContext> getKmsProviderSslContextMap() {
return kmsProviderSslContextMap;
}

private ClientEncryptionSettings(final Builder builder) {
this.keyVaultMongoClientSettings = builder.keyVaultMongoClientSettings;
this.keyVaultNamespace = notNull("keyVaultNamespace", builder.keyVaultNamespace);
this.kmsProviders = notNull("kmsProviders", builder.kmsProviders);
this.kmsProviderSslContextMap = notNull("kmsProviderSslContextMap", builder.kmsProviderSslContextMap);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

import javax.net.ssl.SSLContext;
import java.security.NoSuchAlgorithmException;
import java.util.Map;

import static com.mongodb.internal.capi.MongoCryptHelper.createMongoCryptOptions;

Expand All @@ -54,7 +55,7 @@ public static Crypt createCrypt(final MongoClientImpl client, final AutoEncrypti
options.isBypassAutoEncryption() ? null : new CollectionInfoRetriever(collectionInfoRetrieverClient),
new CommandMarker(options.isBypassAutoEncryption(), options.getExtraOptions()),
new KeyRetriever(keyVaultClient, new MongoNamespace(options.getKeyVaultNamespace())),
createKeyManagementService(),
createKeyManagementService(options.getKmsProviderSslContextMap()),
options.isBypassAutoEncryption(),
internalClient);
}
Expand All @@ -63,11 +64,11 @@ public static Crypt create(final MongoClient keyVaultClient, final ClientEncrypt
return new Crypt(MongoCrypts.create(
createMongoCryptOptions(options.getKmsProviders(), null)),
new KeyRetriever(keyVaultClient, new MongoNamespace(options.getKeyVaultNamespace())),
createKeyManagementService());
createKeyManagementService(options.getKmsProviderSslContextMap()));
}

private static KeyManagementService createKeyManagementService() {
return new KeyManagementService(getSslContext(), 443, 10000);
private static KeyManagementService createKeyManagementService(final Map<String, SSLContext> kmsProviderSslContextMap) {
return new KeyManagementService(kmsProviderSslContextMap, 10000);
}

private static SSLContext getSslContext() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package com.mongodb.reactivestreams.client.internal.crypt;

import com.mongodb.MongoClientException;
import com.mongodb.MongoSocketException;
import com.mongodb.ServerAddress;
import com.mongodb.connection.AsyncCompletionHandler;
Expand All @@ -34,37 +35,41 @@
import javax.net.ssl.SSLContext;
import java.io.Closeable;
import java.nio.channels.CompletionHandler;
import java.security.NoSuchAlgorithmException;
import java.util.List;
import java.util.concurrent.TimeUnit;
import java.util.Map;

import static java.util.Collections.singletonList;
import static java.util.concurrent.TimeUnit.MILLISECONDS;

class KeyManagementService implements Closeable {
private final int defaultPort;
private final Map<String, SSLContext> kmsProviderSslContextMap;
private final int timeoutMillis;
private final TlsChannelStreamFactoryFactory tlsChannelStreamFactoryFactory;
private final StreamFactory streamFactory;

KeyManagementService(final SSLContext sslContext, final int defaultPort, final int timeoutMillis) {
this.defaultPort = defaultPort;
KeyManagementService(final Map<String, SSLContext> kmsProviderSslContextMap, final int timeoutMillis) {
this.kmsProviderSslContextMap = kmsProviderSslContextMap;
this.tlsChannelStreamFactoryFactory = new TlsChannelStreamFactoryFactory();
this.streamFactory = tlsChannelStreamFactoryFactory.create(SocketSettings.builder()
.connectTimeout(timeoutMillis, TimeUnit.MILLISECONDS)
.readTimeout(timeoutMillis, TimeUnit.MILLISECONDS)
.build(),
SslSettings.builder().enabled(true).context(sslContext).build());
this.timeoutMillis = timeoutMillis;
}

public void close() {
tlsChannelStreamFactoryFactory.close();
}

Mono<Void> decryptKey(final MongoKeyDecryptor keyDecryptor) {
SocketSettings socketSettings = SocketSettings.builder()
.connectTimeout(timeoutMillis, MILLISECONDS)
.readTimeout(timeoutMillis, MILLISECONDS)
.build();
StreamFactory streamFactory = tlsChannelStreamFactoryFactory.create(socketSettings,
SslSettings.builder().enabled(true).context(kmsProviderSslContextMap.get(keyDecryptor.getKmsProvider())).build());

return Mono.<Void>create(sink -> {
ServerAddress serverAddress = keyDecryptor.getHostName().contains(":")
? new ServerAddress(keyDecryptor.getHostName())
: new ServerAddress(keyDecryptor.getHostName(), defaultPort);
final Stream stream = streamFactory.create(serverAddress);
: new ServerAddress(keyDecryptor.getHostName(), 443); // TODO: default to 443 is weird?
Stream stream = streamFactory.create(serverAddress);
stream.openAsync(new AsyncCompletionHandler<Void>() {
@Override
public void completed(final Void ignored) {
Expand Down Expand Up @@ -129,6 +134,16 @@ public void failed(final Throwable t, final Void aVoid) {
}
}

private static SSLContext getDefaultSslContext() {
SSLContext sslContext;
try {
sslContext = SSLContext.getDefault();
} catch (NoSuchAlgorithmException e) {
throw new MongoClientException("Unable to create default SSLContext", e);
}
return sslContext;
}

private Throwable unWrapException(final Throwable t) {
return t instanceof MongoSocketException ? t.getCause() : t;
}
Expand Down
3 changes: 2 additions & 1 deletion driver-sync/src/main/com/mongodb/client/internal/Crypt.java
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,8 @@ private void decryptKeys(final MongoCryptContext cryptContext) {
}

private void decryptKey(final MongoKeyDecryptor keyDecryptor) throws IOException {
InputStream inputStream = keyManagementService.stream(keyDecryptor.getHostName(), keyDecryptor.getMessage());
InputStream inputStream = keyManagementService.stream(keyDecryptor.getKmsProvider(), keyDecryptor.getHostName(),
keyDecryptor.getMessage());
try {
int bytesNeeded = keyDecryptor.bytesNeeded();

Expand Down
21 changes: 5 additions & 16 deletions driver-sync/src/main/com/mongodb/client/internal/Crypts.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,14 @@

import com.mongodb.AutoEncryptionSettings;
import com.mongodb.ClientEncryptionSettings;
import com.mongodb.MongoClientException;
import com.mongodb.MongoClientSettings;
import com.mongodb.MongoNamespace;
import com.mongodb.client.MongoClient;
import com.mongodb.client.MongoClients;
import com.mongodb.crypt.capi.MongoCrypts;

import javax.net.ssl.SSLContext;
import java.security.NoSuchAlgorithmException;
import java.util.Map;

import static com.mongodb.internal.capi.MongoCryptHelper.createMongoCryptOptions;

Expand All @@ -50,7 +49,7 @@ public static Crypt createCrypt(final MongoClientImpl client, final AutoEncrypti
options.isBypassAutoEncryption() ? null : new CollectionInfoRetriever(collectionInfoRetrieverClient),
new CommandMarker(options.isBypassAutoEncryption(), options.getExtraOptions()),
new KeyRetriever(keyVaultClient, new MongoNamespace(options.getKeyVaultNamespace())),
createKeyManagementService(),
createKeyManagementService(options.getKmsProviderSslContextMap()),
options.isBypassAutoEncryption(),
internalClient);
}
Expand All @@ -59,26 +58,16 @@ static Crypt create(final MongoClient keyVaultClient, final ClientEncryptionSett
return new Crypt(MongoCrypts.create(
createMongoCryptOptions(options.getKmsProviders(), null)),
createKeyRetriever(keyVaultClient, options.getKeyVaultNamespace()),
createKeyManagementService());
createKeyManagementService(options.getKmsProviderSslContextMap()));
}

private static KeyRetriever createKeyRetriever(final MongoClient keyVaultClient,
final String keyVaultNamespaceString) {
return new KeyRetriever(keyVaultClient, new MongoNamespace(keyVaultNamespaceString));
}

private static KeyManagementService createKeyManagementService() {
return new KeyManagementService(getSslContext(), 443, 10000);
}

private static SSLContext getSslContext() {
SSLContext sslContext;
try {
sslContext = SSLContext.getDefault();
} catch (NoSuchAlgorithmException e) {
throw new MongoClientException("Unable to create default SSLContext", e);
}
return sslContext;
private static KeyManagementService createKeyManagementService(final Map<String, SSLContext> kmsProviderSslContextMap) {
return new KeyManagementService(kmsProviderSslContextMap, 10000);
}

private Crypts() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,36 +17,48 @@
package com.mongodb.client.internal;

import com.mongodb.ServerAddress;
import com.mongodb.diagnostics.logging.Logger;
import com.mongodb.diagnostics.logging.Loggers;
import com.mongodb.internal.connection.SslHelper;

import javax.net.SocketFactory;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLParameters;
import javax.net.ssl.SSLSocket;
import javax.net.ssl.SSLSocketFactory;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.nio.ByteBuffer;
import java.util.Map;

import static com.mongodb.assertions.Assertions.notNull;

class KeyManagementService {
private final SSLContext sslContext;
private final int defaultPort;
private static final Logger LOGGER = Loggers.getLogger("client");
private final Map<String, SSLContext> kmsProviderSslContextMap;
private final int timeoutMillis;

KeyManagementService(final SSLContext sslContext, final int defaultPort, final int timeoutMillis) {
this.sslContext = sslContext;
this.defaultPort = defaultPort;
KeyManagementService(final Map<String, SSLContext> kmsProviderSslContextMap, final int timeoutMillis) {
this.kmsProviderSslContextMap = notNull("kmsProviderSslContextMap", kmsProviderSslContextMap);
this.timeoutMillis = timeoutMillis;
}

public InputStream stream(final String host, final ByteBuffer message) throws IOException {
ServerAddress serverAddress = host.contains(":") ? new ServerAddress(host) : new ServerAddress(host, defaultPort);
SSLSocket socket = (SSLSocket) sslContext.getSocketFactory().createSocket();
public InputStream stream(final String kmsProvider, final String host, final ByteBuffer message) throws IOException {
ServerAddress serverAddress = new ServerAddress(host);

LOGGER.info("Connecting to KMS server at " + serverAddress);
SSLContext sslContext = kmsProviderSslContextMap.get(kmsProvider);

SocketFactory sslSocketFactory = sslContext == null
? SSLSocketFactory.getDefault() : sslContext.getSocketFactory();
SSLSocket socket = (SSLSocket) sslSocketFactory.createSocket();
enableHostNameVerification(socket);

try {
enableHostNameVerification(socket);
socket.setSoTimeout(timeoutMillis);
socket.connect(new InetSocketAddress(InetAddress.getByName(serverAddress.getHost()), serverAddress.getPort()), timeoutMillis);
} catch (IOException e) {
Expand Down Expand Up @@ -83,10 +95,6 @@ private void enableHostNameVerification(final SSLSocket socket) {
socket.setSSLParameters(sslParameters);
}

public int getDefaultPort() {
return defaultPort;
}

private void closeSocket(final Socket socket) {
try {
socket.close();
Expand Down
Loading