Skip to content

Add SSLContext configuration per KMS provider #820

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Nov 9, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 35 additions & 1 deletion driver-core/src/main/com/mongodb/AutoEncryptionSettings.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,13 @@
import com.mongodb.lang.Nullable;
import org.bson.BsonDocument;

import javax.net.ssl.SSLContext;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

import static com.mongodb.assertions.Assertions.notNull;
import static java.util.Collections.unmodifiableMap;

/**
* The client-side automatic encryption settings. Client side encryption enables an application to specify what fields in a collection
Expand Down Expand Up @@ -58,6 +61,7 @@ public final class AutoEncryptionSettings {
private final MongoClientSettings keyVaultMongoClientSettings;
private final String keyVaultNamespace;
private final Map<String, Map<String, Object>> kmsProviders;
private final Map<String, SSLContext> kmsProviderSslContextMap;
private final Map<String, BsonDocument> schemaMap;
private final Map<String, Object> extraOptions;
private final boolean bypassAutoEncryption;
Expand All @@ -71,6 +75,7 @@ public static final class Builder {
private MongoClientSettings keyVaultMongoClientSettings;
private String keyVaultNamespace;
private Map<String, Map<String, Object>> kmsProviders;
private Map<String, SSLContext> kmsProviderSslContextMap = new HashMap<>();
private Map<String, BsonDocument> schemaMap = Collections.emptyMap();
private Map<String, Object> extraOptions = Collections.emptyMap();
private boolean bypassAutoEncryption;
Expand Down Expand Up @@ -111,6 +116,19 @@ public Builder kmsProviders(final Map<String, Map<String, Object>> kmsProviders)
return this;
}

/**
* Sets the KMS provider to SSLContext map
*
* @param kmsProviderSslContextMap the KMS provider to SSLContext map, which may not be null
* @return this
* @see #getKmsProviderSslContextMap()
* @since 4.4
*/
public Builder kmsProviderSslContextMap(final Map<String, SSLContext> kmsProviderSslContextMap) {
this.kmsProviderSslContextMap = notNull("kmsProviderSslContextMap", kmsProviderSslContextMap);
return this;
}

/**
* Sets the map from namespace to local schema document
*
Expand Down Expand Up @@ -250,7 +268,22 @@ public String getKeyVaultNamespace() {
* @return map of KMS provider properties
*/
public Map<String, Map<String, Object>> getKmsProviders() {
return kmsProviders;
return unmodifiableMap(kmsProviders);
}

/**
* Gets the KMS provider to SSLContext map.
*
* <p>
* If a KMS provider is mapped to a non-null {@link SSLContext}, the context will be used to establish a TLS connection to the KMS.
* Otherwise, the default context will be used.
* </p>
*
* @return the KMS provider to SSLContext map
* @since 4.4
*/
public Map<String, SSLContext> getKmsProviderSslContextMap() {
return unmodifiableMap(kmsProviderSslContextMap);
}

/**
Expand Down Expand Up @@ -321,6 +354,7 @@ private AutoEncryptionSettings(final Builder builder) {
this.keyVaultMongoClientSettings = builder.keyVaultMongoClientSettings;
this.keyVaultNamespace = notNull("keyVaultNamespace", builder.keyVaultNamespace);
this.kmsProviders = notNull("kmsProviders", builder.kmsProviders);
this.kmsProviderSslContextMap = notNull("kmsProviderSslContextMap", builder.kmsProviderSslContextMap);
this.schemaMap = notNull("schemaMap", builder.schemaMap);
this.extraOptions = notNull("extraOptions", builder.extraOptions);
this.bypassAutoEncryption = builder.bypassAutoEncryption;
Expand Down
37 changes: 35 additions & 2 deletions driver-core/src/main/com/mongodb/ClientEncryptionSettings.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,12 @@

import com.mongodb.annotations.NotThreadSafe;

import javax.net.ssl.SSLContext;
import java.util.HashMap;
import java.util.Map;

import static com.mongodb.assertions.Assertions.notNull;
import static java.util.Collections.unmodifiableMap;

/**
* The client-side settings for data key creation and explicit encryption.
Expand All @@ -36,7 +39,7 @@ public final class ClientEncryptionSettings {
private final MongoClientSettings keyVaultMongoClientSettings;
private final String keyVaultNamespace;
private final Map<String, Map<String, Object>> kmsProviders;

private final Map<String, SSLContext> kmsProviderSslContextMap;
/**
* A builder for {@code ClientEncryptionSettings} so that {@code ClientEncryptionSettings} can be immutable, and to support easier
* construction through chaining.
Expand All @@ -46,6 +49,7 @@ public static final class Builder {
private MongoClientSettings keyVaultMongoClientSettings;
private String keyVaultNamespace;
private Map<String, Map<String, Object>> kmsProviders;
private Map<String, SSLContext> kmsProviderSslContextMap = new HashMap<>();

/**
* Sets the key vault settings.
Expand Down Expand Up @@ -83,6 +87,19 @@ public Builder kmsProviders(final Map<String, Map<String, Object>> kmsProviders)
return this;
}

/**
* Sets the KMS provider to SSLContext map
*
* @param kmsProviderSslContextMap the KMS provider to SSLContext map, which may not be null
* @return this
* @see #getKmsProviderSslContextMap()
* @since 4.4
*/
public Builder kmsProviderSslContextMap(final Map<String, SSLContext> kmsProviderSslContextMap) {
this.kmsProviderSslContextMap = notNull("kmsProviderSslContextMap", kmsProviderSslContextMap);
return this;
}

/**
* Build an instance of {@code ClientEncryptionSettings}.
*
Expand Down Expand Up @@ -184,13 +201,29 @@ public String getKeyVaultNamespace() {
* @return map of KMS provider properties
*/
public Map<String, Map<String, Object>> getKmsProviders() {
return kmsProviders;
return unmodifiableMap(kmsProviders);
}

/**
* Gets the KMS provider to SSLContext map.
*
* <p>
* If a KMS provider is mapped to a non-null {@link SSLContext}, the context will be used to establish a TLS connection to the KMS.
* Otherwise, the default context will be used.
* </p>
*
* @return the KMS provider to SSLContext map
* @since 4.4
*/
public Map<String, SSLContext> getKmsProviderSslContextMap() {
return unmodifiableMap(kmsProviderSslContextMap);
}

private ClientEncryptionSettings(final Builder builder) {
this.keyVaultMongoClientSettings = builder.keyVaultMongoClientSettings;
this.keyVaultNamespace = notNull("keyVaultNamespace", builder.keyVaultNamespace);
this.kmsProviders = notNull("kmsProviders", builder.kmsProviders);
this.kmsProviderSslContextMap = notNull("kmsProviderSslContextMap", builder.kmsProviderSslContextMap);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

import javax.net.ssl.SSLContext;
import java.security.NoSuchAlgorithmException;
import java.util.Map;

import static com.mongodb.internal.capi.MongoCryptHelper.createMongoCryptOptions;

Expand All @@ -54,7 +55,7 @@ public static Crypt createCrypt(final MongoClientImpl client, final AutoEncrypti
options.isBypassAutoEncryption() ? null : new CollectionInfoRetriever(collectionInfoRetrieverClient),
new CommandMarker(options.isBypassAutoEncryption(), options.getExtraOptions()),
new KeyRetriever(keyVaultClient, new MongoNamespace(options.getKeyVaultNamespace())),
createKeyManagementService(),
createKeyManagementService(options.getKmsProviderSslContextMap()),
options.isBypassAutoEncryption(),
internalClient);
}
Expand All @@ -63,11 +64,11 @@ public static Crypt create(final MongoClient keyVaultClient, final ClientEncrypt
return new Crypt(MongoCrypts.create(
createMongoCryptOptions(options.getKmsProviders(), null)),
new KeyRetriever(keyVaultClient, new MongoNamespace(options.getKeyVaultNamespace())),
createKeyManagementService());
createKeyManagementService(options.getKmsProviderSslContextMap()));
}

private static KeyManagementService createKeyManagementService() {
return new KeyManagementService(getSslContext(), 443, 10000);
private static KeyManagementService createKeyManagementService(final Map<String, SSLContext> kmsProviderSslContextMap) {
return new KeyManagementService(kmsProviderSslContextMap, 10000);
}

private static SSLContext getSslContext() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
import com.mongodb.connection.StreamFactory;
import com.mongodb.connection.TlsChannelStreamFactoryFactory;
import com.mongodb.crypt.capi.MongoKeyDecryptor;
import com.mongodb.diagnostics.logging.Logger;
import com.mongodb.diagnostics.logging.Loggers;
import com.mongodb.internal.connection.AsynchronousChannelStream;
import org.bson.ByteBuf;
import org.bson.ByteBufNIO;
Expand All @@ -35,36 +37,41 @@
import java.io.Closeable;
import java.nio.channels.CompletionHandler;
import java.util.List;
import java.util.concurrent.TimeUnit;
import java.util.Map;

import static java.util.Collections.singletonList;
import static java.util.concurrent.TimeUnit.MILLISECONDS;

class KeyManagementService implements Closeable {
private final int defaultPort;
private static final Logger LOGGER = Loggers.getLogger("client");
private final Map<String, SSLContext> kmsProviderSslContextMap;
private final int timeoutMillis;
private final TlsChannelStreamFactoryFactory tlsChannelStreamFactoryFactory;
private final StreamFactory streamFactory;

KeyManagementService(final SSLContext sslContext, final int defaultPort, final int timeoutMillis) {
this.defaultPort = defaultPort;
KeyManagementService(final Map<String, SSLContext> kmsProviderSslContextMap, final int timeoutMillis) {
this.kmsProviderSslContextMap = kmsProviderSslContextMap;
this.tlsChannelStreamFactoryFactory = new TlsChannelStreamFactoryFactory();
this.streamFactory = tlsChannelStreamFactoryFactory.create(SocketSettings.builder()
.connectTimeout(timeoutMillis, TimeUnit.MILLISECONDS)
.readTimeout(timeoutMillis, TimeUnit.MILLISECONDS)
.build(),
SslSettings.builder().enabled(true).context(sslContext).build());
this.timeoutMillis = timeoutMillis;
}

public void close() {
tlsChannelStreamFactoryFactory.close();
}

Mono<Void> decryptKey(final MongoKeyDecryptor keyDecryptor) {
SocketSettings socketSettings = SocketSettings.builder()
.connectTimeout(timeoutMillis, MILLISECONDS)
.readTimeout(timeoutMillis, MILLISECONDS)
.build();
StreamFactory streamFactory = tlsChannelStreamFactoryFactory.create(socketSettings,
SslSettings.builder().enabled(true).context(kmsProviderSslContextMap.get(keyDecryptor.getKmsProvider())).build());

ServerAddress serverAddress = new ServerAddress(keyDecryptor.getHostName());

LOGGER.info("Connecting to KMS server at " + serverAddress);

return Mono.<Void>create(sink -> {
ServerAddress serverAddress = keyDecryptor.getHostName().contains(":")
? new ServerAddress(keyDecryptor.getHostName())
: new ServerAddress(keyDecryptor.getHostName(), defaultPort);
final Stream stream = streamFactory.create(serverAddress);
Stream stream = streamFactory.create(serverAddress);
stream.openAsync(new AsyncCompletionHandler<Void>() {
@Override
public void completed(final Void ignored) {
Expand Down
3 changes: 2 additions & 1 deletion driver-sync/src/main/com/mongodb/client/internal/Crypt.java
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,8 @@ private void decryptKeys(final MongoCryptContext cryptContext) {
}

private void decryptKey(final MongoKeyDecryptor keyDecryptor) throws IOException {
InputStream inputStream = keyManagementService.stream(keyDecryptor.getHostName(), keyDecryptor.getMessage());
InputStream inputStream = keyManagementService.stream(keyDecryptor.getKmsProvider(), keyDecryptor.getHostName(),
keyDecryptor.getMessage());
try {
int bytesNeeded = keyDecryptor.bytesNeeded();

Expand Down
21 changes: 5 additions & 16 deletions driver-sync/src/main/com/mongodb/client/internal/Crypts.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,14 @@

import com.mongodb.AutoEncryptionSettings;
import com.mongodb.ClientEncryptionSettings;
import com.mongodb.MongoClientException;
import com.mongodb.MongoClientSettings;
import com.mongodb.MongoNamespace;
import com.mongodb.client.MongoClient;
import com.mongodb.client.MongoClients;
import com.mongodb.crypt.capi.MongoCrypts;

import javax.net.ssl.SSLContext;
import java.security.NoSuchAlgorithmException;
import java.util.Map;

import static com.mongodb.internal.capi.MongoCryptHelper.createMongoCryptOptions;

Expand All @@ -50,7 +49,7 @@ public static Crypt createCrypt(final MongoClientImpl client, final AutoEncrypti
options.isBypassAutoEncryption() ? null : new CollectionInfoRetriever(collectionInfoRetrieverClient),
new CommandMarker(options.isBypassAutoEncryption(), options.getExtraOptions()),
new KeyRetriever(keyVaultClient, new MongoNamespace(options.getKeyVaultNamespace())),
createKeyManagementService(),
createKeyManagementService(options.getKmsProviderSslContextMap()),
options.isBypassAutoEncryption(),
internalClient);
}
Expand All @@ -59,26 +58,16 @@ static Crypt create(final MongoClient keyVaultClient, final ClientEncryptionSett
return new Crypt(MongoCrypts.create(
createMongoCryptOptions(options.getKmsProviders(), null)),
createKeyRetriever(keyVaultClient, options.getKeyVaultNamespace()),
createKeyManagementService());
createKeyManagementService(options.getKmsProviderSslContextMap()));
}

private static KeyRetriever createKeyRetriever(final MongoClient keyVaultClient,
final String keyVaultNamespaceString) {
return new KeyRetriever(keyVaultClient, new MongoNamespace(keyVaultNamespaceString));
}

private static KeyManagementService createKeyManagementService() {
return new KeyManagementService(getSslContext(), 443, 10000);
}

private static SSLContext getSslContext() {
SSLContext sslContext;
try {
sslContext = SSLContext.getDefault();
} catch (NoSuchAlgorithmException e) {
throw new MongoClientException("Unable to create default SSLContext", e);
}
return sslContext;
private static KeyManagementService createKeyManagementService(final Map<String, SSLContext> kmsProviderSslContextMap) {
return new KeyManagementService(kmsProviderSslContextMap, 10000);
}

private Crypts() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,36 +17,48 @@
package com.mongodb.client.internal;

import com.mongodb.ServerAddress;
import com.mongodb.diagnostics.logging.Logger;
import com.mongodb.diagnostics.logging.Loggers;
import com.mongodb.internal.connection.SslHelper;

import javax.net.SocketFactory;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLParameters;
import javax.net.ssl.SSLSocket;
import javax.net.ssl.SSLSocketFactory;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.nio.ByteBuffer;
import java.util.Map;

import static com.mongodb.assertions.Assertions.notNull;

class KeyManagementService {
private final SSLContext sslContext;
private final int defaultPort;
private static final Logger LOGGER = Loggers.getLogger("client");
private final Map<String, SSLContext> kmsProviderSslContextMap;
private final int timeoutMillis;

KeyManagementService(final SSLContext sslContext, final int defaultPort, final int timeoutMillis) {
this.sslContext = sslContext;
this.defaultPort = defaultPort;
KeyManagementService(final Map<String, SSLContext> kmsProviderSslContextMap, final int timeoutMillis) {
this.kmsProviderSslContextMap = notNull("kmsProviderSslContextMap", kmsProviderSslContextMap);
this.timeoutMillis = timeoutMillis;
}

public InputStream stream(final String host, final ByteBuffer message) throws IOException {
ServerAddress serverAddress = host.contains(":") ? new ServerAddress(host) : new ServerAddress(host, defaultPort);
SSLSocket socket = (SSLSocket) sslContext.getSocketFactory().createSocket();
public InputStream stream(final String kmsProvider, final String host, final ByteBuffer message) throws IOException {
ServerAddress serverAddress = new ServerAddress(host);

LOGGER.info("Connecting to KMS server at " + serverAddress);
SSLContext sslContext = kmsProviderSslContextMap.get(kmsProvider);

SocketFactory sslSocketFactory = sslContext == null
? SSLSocketFactory.getDefault() : sslContext.getSocketFactory();
SSLSocket socket = (SSLSocket) sslSocketFactory.createSocket();
enableHostNameVerification(socket);

try {
enableHostNameVerification(socket);
socket.setSoTimeout(timeoutMillis);
socket.connect(new InetSocketAddress(InetAddress.getByName(serverAddress.getHost()), serverAddress.getPort()), timeoutMillis);
} catch (IOException e) {
Expand Down Expand Up @@ -83,10 +95,6 @@ private void enableHostNameVerification(final SSLSocket socket) {
socket.setSSLParameters(sslParameters);
}

public int getDefaultPort() {
return defaultPort;
}

private void closeSocket(final Socket socket) {
try {
socket.close();
Expand Down
Loading