Skip to content

Add Human OIDC Workflow #1316

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Mar 4, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions driver-core/src/main/com/mongodb/ConnectionString.java
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,10 @@
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static com.mongodb.MongoCredential.ALLOWED_HOSTS_KEY;
import static com.mongodb.internal.connection.OidcAuthenticator.OidcValidator.validateCreateOidcCredential;
import static java.lang.String.format;
import static java.util.Arrays.asList;
Expand Down Expand Up @@ -272,6 +275,9 @@ public class ConnectionString {
private static final Set<String> ALLOWED_OPTIONS_IN_TXT_RECORD =
new HashSet<>(asList("authsource", "replicaset", "loadbalanced"));
private static final Logger LOGGER = Loggers.getLogger("uri");
private static final List<String> MECHANISM_KEYS_DISALLOWED_IN_CONNECTION_STRING = Stream.of(ALLOWED_HOSTS_KEY)
.map(k -> k.toLowerCase())
.collect(Collectors.toList());

private final MongoCredential credential;
private final boolean isSrvProtocol;
Expand Down Expand Up @@ -902,6 +908,11 @@ private MongoCredential createCredentials(final Map<String, List<String>> option
}
String key = mechanismPropertyKeyValue[0].trim().toLowerCase();
String value = mechanismPropertyKeyValue[1].trim();
if (MECHANISM_KEYS_DISALLOWED_IN_CONNECTION_STRING.contains(key)) {
throw new IllegalArgumentException(format("The connection string contains disallowed mechanism properties. "
+ "'%s' must be set on the credential programmatically.", key));
}

if (key.equals("canonicalize_host_name")) {
credential = credential.withMechanismProperty(key, Boolean.valueOf(value));
} else {
Expand Down
25 changes: 17 additions & 8 deletions driver-core/src/main/com/mongodb/MongoCredential.java
Original file line number Diff line number Diff line change
Expand Up @@ -245,13 +245,13 @@ public final class MongoCredential {
* The list of allowed hosts that will be used if no
* {@link MongoCredential#ALLOWED_HOSTS_KEY} value is supplied.
* The default allowed hosts are:
* {@code "*.mongodb.net", "*.mongodb-dev.net", "*.mongodbgov.net", "localhost", "127.0.0.1", "::1"}
* {@code "*.mongodb.net", "*.mongodb-qa.net", "*.mongodb-dev.net", "*.mongodbgov.net", "localhost", "127.0.0.1", "::1"}
*
* @see #createOidcCredential(String)
* @since 4.10
*/
public static final List<String> DEFAULT_ALLOWED_HOSTS = Collections.unmodifiableList(Arrays.asList(
"*.mongodb.net", "*.mongodb-dev.net", "*.mongodbgov.net", "localhost", "127.0.0.1", "::1"));
"*.mongodb.net", "*.mongodb-qa.net", "*.mongodb-dev.net", "*.mongodbgov.net", "localhost", "127.0.0.1", "::1"));

/**
* Creates a MongoCredential instance with an unspecified mechanism. The client will negotiate the best mechanism based on the
Expand Down Expand Up @@ -708,28 +708,37 @@ public interface IdpInfo {
/**
* The response produced by an OIDC Identity Provider.
*/
@Evolving
public static final class OidcCallbackResult {

private final String accessToken;

@Nullable
private final Duration expiresIn;

@Nullable
private final String refreshToken;

/**
* @param accessToken The OIDC access token
* @param accessToken The OIDC access token.
* @param expiresIn Time until the access token expires. 0 is an infinite duration.
*/
public OidcCallbackResult(final String accessToken) {
this(accessToken, null);
public OidcCallbackResult(final String accessToken, final Duration expiresIn) {
this(accessToken, expiresIn, null);
}

/**
* @param accessToken The OIDC access token
* @param accessToken The OIDC access token.
* @param expiresIn Time until the access token expires. 0 is an infinite duration.
* @param refreshToken The refresh token. If null, refresh will not be attempted.
*/
public OidcCallbackResult(final String accessToken, @Nullable final String refreshToken) {
public OidcCallbackResult(final String accessToken, @Nullable final Duration expiresIn,
@Nullable final String refreshToken) {
notNull("accessToken", accessToken);
if (expiresIn != null && expiresIn.isNegative()) {
throw new IllegalArgumentException("expiresIn must not be a negative value");
}
this.accessToken = accessToken;
this.expiresIn = expiresIn;
this.refreshToken = refreshToken;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,20 +157,20 @@ private boolean isAutomaticAuthentication() {
}

private boolean isHumanCallback() {
return getMechanismProperty(OIDC_HUMAN_CALLBACK_KEY) != null;
return getOidcCallbackMechanismProperty(OIDC_HUMAN_CALLBACK_KEY) != null;
}

@Nullable
private OidcCallback getMechanismProperty(final String key) {
private OidcCallback getOidcCallbackMechanismProperty(final String key) {
return getMongoCredentialWithCache()
.getCredential()
.getMechanismProperty(key, null);
}

@Nullable
private OidcCallback getRequestCallback() {
OidcCallback machine = getMechanismProperty(OIDC_CALLBACK_KEY);
return machine != null ? machine : getMechanismProperty(OIDC_HUMAN_CALLBACK_KEY);
OidcCallback machine = getOidcCallbackMechanismProperty(OIDC_CALLBACK_KEY);
return machine != null ? machine : getOidcCallbackMechanismProperty(OIDC_HUMAN_CALLBACK_KEY);
}

@Override
Expand Down Expand Up @@ -259,7 +259,6 @@ private byte[] evaluate(final byte[] challenge) {
// original IDP info will be present, if refresh token present
assertNotNull(cachedIdpInfo);
// Invoke Callback using cached Refresh Token
validateAllowedHosts(getMongoCredential());
fallbackState = FallbackState.PHASE_2_REFRESH_CALLBACK_TOKEN;
OidcCallbackResult result = requestCallback.onRequest(new OidcCallbackContextImpl(
CALLBACK_TIMEOUT, cachedIdpInfo, cachedRefreshToken));
Expand Down Expand Up @@ -335,26 +334,30 @@ private boolean clientIsComplete() {
}

private boolean shouldRetryHandler() {
MongoCredentialWithCache mongoCredentialWithCache = getMongoCredentialWithCache();
OidcCacheEntry cacheEntry = mongoCredentialWithCache.getOidcCacheEntry();
if (fallbackState == FallbackState.PHASE_1_CACHED_TOKEN) {
// a cached access token failed
mongoCredentialWithCache.setOidcCacheEntry(cacheEntry
.clearAccessToken());
return true;
} else if (fallbackState == FallbackState.PHASE_2_REFRESH_CALLBACK_TOKEN) {
// a refresh token failed
mongoCredentialWithCache.setOidcCacheEntry(cacheEntry
.clearAccessToken()
.clearRefreshToken());
return true;
} else {
// a clean-restart failed
mongoCredentialWithCache.setOidcCacheEntry(cacheEntry
.clearAccessToken()
.clearRefreshToken());
return false;
}
boolean[] result = new boolean[1];
Locks.withLock(getMongoCredentialWithCache().getOidcLock(), () -> {
MongoCredentialWithCache mongoCredentialWithCache = getMongoCredentialWithCache();
OidcCacheEntry cacheEntry = mongoCredentialWithCache.getOidcCacheEntry();
if (fallbackState == FallbackState.PHASE_1_CACHED_TOKEN) {
// a cached access token failed
mongoCredentialWithCache.setOidcCacheEntry(cacheEntry
.clearAccessToken());
result[0] = true;
} else if (fallbackState == FallbackState.PHASE_2_REFRESH_CALLBACK_TOKEN) {
// a refresh token failed
mongoCredentialWithCache.setOidcCacheEntry(cacheEntry
.clearAccessToken()
.clearRefreshToken());
result[0] = true;
} else {
// a clean-restart failed
mongoCredentialWithCache.setOidcCacheEntry(cacheEntry
.clearAccessToken()
.clearRefreshToken());
result[0] = false;
}
});
return result[0];
}

@Nullable
Expand Down Expand Up @@ -516,7 +519,8 @@ private void validateAllowedHosts(final MongoCredential credential) {
});
if (!permitted) {
throw new MongoSecurityException(
credential, "Host not permitted by " + ALLOWED_HOSTS_KEY + ": " + host);
credential, "Host " + host + " not permitted by " + ALLOWED_HOSTS_KEY
+ ", values: " + allowedHosts);
}
}

Expand Down Expand Up @@ -568,30 +572,29 @@ public static void validateCreateOidcCredential(@Nullable final char[] password)
public static void validateBeforeUse(final MongoCredential credential) {
String userName = credential.getUserName();
Object providerName = credential.getMechanismProperty(PROVIDER_NAME_KEY, null);
Object requestCallback = credential.getMechanismProperty(OIDC_CALLBACK_KEY, null);
Object refreshCallback = credential.getMechanismProperty(OIDC_HUMAN_CALLBACK_KEY, null);
Object machineCallback = credential.getMechanismProperty(OIDC_CALLBACK_KEY, null);
Object humanCallback = credential.getMechanismProperty(OIDC_HUMAN_CALLBACK_KEY, null);
if (providerName == null) {
// callback
if (requestCallback == null && refreshCallback == null) {
if (machineCallback == null && humanCallback == null) {
throw new IllegalArgumentException("Either " + PROVIDER_NAME_KEY
+ " or " + OIDC_CALLBACK_KEY
+ " or " + OIDC_HUMAN_CALLBACK_KEY
+ " must be specified");
}
if (requestCallback != null && refreshCallback != null) {
if (machineCallback != null && humanCallback != null) {
throw new IllegalArgumentException("Both " + OIDC_CALLBACK_KEY
+ " and " + OIDC_HUMAN_CALLBACK_KEY
+ " must not be specified");
}
} else {
// automatic
if (userName != null) {
throw new IllegalArgumentException("user name must not be specified when " + PROVIDER_NAME_KEY + " is specified");
}
if (requestCallback != null) {
if (machineCallback != null) {
throw new IllegalArgumentException(OIDC_CALLBACK_KEY + " must not be specified when " + PROVIDER_NAME_KEY + " is specified");
}
if (refreshCallback != null) {
if (humanCallback != null) {
throw new IllegalArgumentException(OIDC_HUMAN_CALLBACK_KEY + " must not be specified when " + PROVIDER_NAME_KEY + " is specified");
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -530,7 +530,7 @@ private void initClient(final BsonDocument entity, final String id,
} catch (IOException e) {
throw new RuntimeException(e);
}
return new MongoCredential.OidcCallbackResult(accessToken);
return new MongoCredential.OidcCallbackResult(accessToken, null);
}));
break;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,12 @@
import java.net.URISyntaxException;
import java.util.Collection;

import static com.mongodb.ClusterFixture.isServerlessTest;
import static org.junit.Assume.assumeFalse;

public class UnifiedAuthTest extends UnifiedSyncTest {
public UnifiedAuthTest(@SuppressWarnings("unused") final String fileDescription,
@SuppressWarnings("unused") final String testDescription,
final String schemaVersion, final BsonArray runOnRequirements, final BsonArray entitiesArray,
final BsonArray initialData, final BsonDocument definition) {
super(schemaVersion, runOnRequirements, entitiesArray, initialData, definition);
assumeFalse(isServerlessTest());
}

@Parameterized.Parameters(name = "{0}: {1}")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ public void test2p3CallbackReturnsMissingData() {
// conforming to the OIDCRequestTokenResult with missing field(s).
OidcCallback onRequest = (context) -> {
//noinspection ConstantConditions
return new OidcCallbackResult(null);
return new OidcCallbackResult(null, Duration.ZERO);
};
// we ensure that the error is propagated
MongoClientSettings clientSettings = createSettings(getOidcUri(), onRequest, null);
Expand Down Expand Up @@ -268,7 +268,7 @@ public void test3p1AuthFailsWithCachedToken() throws ExecutionException, Interru
@Test
public void test3p2AuthFailsWithoutCachedToken() {
MongoClientSettings clientSettings = createSettings(getOidcUri(),
(x) -> new OidcCallbackResult("invalid_token"), null);
(x) -> new OidcCallbackResult("invalid_token", Duration.ZERO), null);
try (MongoClient mongoClient = createMongoClient(clientSettings)) {
try {
performFind(mongoClient);
Expand Down Expand Up @@ -358,7 +358,7 @@ public void testh1p6AllowedHostsBlocked() {
MongoClientSettings settings1 = createSettings(
getOidcUri(),
createHumanCallback(), null, OIDC_HUMAN_CALLBACK_KEY, Collections.emptyList());
performFind(settings1, MongoSecurityException.class, "Host not permitted by ALLOWED_HOSTS");
performFind(settings1, MongoSecurityException.class, "not permitted by ALLOWED_HOSTS");

//- Create a client that uses the URL
// ``mongodb://localhost/?authMechanism=MONGODB-OIDC&ignored=example.com``, a
Expand All @@ -367,7 +367,16 @@ public void testh1p6AllowedHostsBlocked() {
MongoClientSettings settings2 = createSettings(
getOidcUri() + "&ignored=example.com",
createHumanCallback(), null, OIDC_HUMAN_CALLBACK_KEY, Arrays.asList("example.com"));
performFind(settings2, MongoSecurityException.class, "Host not permitted by ALLOWED_HOSTS");
performFind(settings2, MongoSecurityException.class, "not permitted by ALLOWED_HOSTS");
}

// Not a prose test
@Test
public void testAllowedHostsDisallowedInConnectionString() {
String string = "mongodb://localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=ALLOWED_HOSTS:localhost";
assertCause(IllegalArgumentException.class,
"connection string contains disallowed mechanism properties",
() -> new ConnectionString(string));
}

@Test
Expand Down Expand Up @@ -397,13 +406,13 @@ public void testh2p2HumanCallbackReturnsMissingData() {
"Result of callback must not be null");

//noinspection ConstantConditions
OidcCallback onRequest = (context) -> new OidcCallbackResult(null);
OidcCallback onRequest = (context) -> new OidcCallbackResult(null, Duration.ZERO);
performFind(createHumanSettings(getOidcUri(), onRequest, null),
IllegalArgumentException.class,
"accessToken can not be null");

// additionally, check validation for refresh in machine workflow:
OidcCallback onRequestMachineRefresh = (context) -> new OidcCallbackResult("access", "exists");
OidcCallback onRequestMachineRefresh = (context) -> new OidcCallbackResult("access", Duration.ZERO, "exists");
performFind(createSettings(getOidcUri(), onRequestMachineRefresh, null),
MongoConfigurationException.class,
"Refresh token must only be provided in human workflow");
Expand Down Expand Up @@ -789,7 +798,7 @@ private OidcCallbackResult callback() {
if (testListener != null) {
testListener.add("read access token: " + path.getFileName());
}
return new OidcCallbackResult(accessToken, refreshToken);
return new OidcCallbackResult(accessToken, Duration.ZERO, refreshToken);
} finally {
if (concurrentTracker != null) {
concurrentTracker.decrementAndGet();
Expand Down