Skip to content

Allow client assertion to be a callback and a per-request parameter #482

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Mar 25, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@
<dependency>
<groupId>com.azure</groupId>
<artifactId>azure-security-keyvault-secrets</artifactId>
<version>4.3.6</version>
<version>4.3.5</version>
<scope>test</scope>
</dependency>
<dependency>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import java.security.UnrecoverableKeyException;
import java.security.cert.CertificateException;
import java.util.Collections;
import java.util.concurrent.Callable;

import static com.microsoft.aad.msal4j.TestConstants.KEYVAULT_DEFAULT_SCOPE;

Expand Down Expand Up @@ -48,16 +49,35 @@ public void acquireTokenClientCredentials_ClientSecret() throws Exception {
public void acquireTokenClientCredentials_ClientAssertion() throws Exception {
String clientId = "2afb0add-2f32-4946-ac90-81a02aa4550e";

ClientAssertion clientAssertion = JwtHelper.buildJwt(
clientId,
(ClientCertificate) certificate,
"https://login.microsoftonline.com/common/oauth2/v2.0/token",
true);
ClientAssertion clientAssertion = getClientAssertion(clientId);

IClientCredential credential = ClientCredentialFactory.createFromClientAssertion(clientAssertion.assertion());

assertAcquireTokenCommon(clientId, credential);
}

@Test
public void acquireTokenClientCredentials_Callback() throws Exception {
String clientId = "2afb0add-2f32-4946-ac90-81a02aa4550e";

// Creates a valid client assertion using a callback, and uses it to build the client app and make a request
Callable<String> callable = () -> {
ClientAssertion clientAssertion = getClientAssertion(clientId);

IClientCredential credential = ClientCredentialFactory.createFromClientAssertion(
clientAssertion.assertion());
return clientAssertion.assertion();
};

IClientCredential credential = ClientCredentialFactory.createFromCallback(callable);

assertAcquireTokenCommon(clientId, credential);

// Creates an invalid client assertion to build the application, but overrides it with a valid client assertion
// in the request parameters in order to make a successful token request
ClientAssertion invalidClientAssertion = getClientAssertion("abc");

IClientCredential invalidCredentials = ClientCredentialFactory.createFromClientAssertion(invalidClientAssertion.assertion());

assertAcquireTokenCommon_withParameters(clientId, invalidCredentials, credential);
}

@Test
Expand Down Expand Up @@ -98,6 +118,13 @@ public void acquireTokenClientCredentials_DefaultCacheLookup() throws Exception
Assert.assertNotEquals(result2.accessToken(), result3.accessToken());
}

private ClientAssertion getClientAssertion(String clientId) {
return JwtHelper.buildJwt(
clientId,
(ClientCertificate) certificate,
"https://login.microsoftonline.com/common/oauth2/v2.0/token",
true);
}

private void assertAcquireTokenCommon(String clientId, IClientCredential credential) throws Exception {
ConfidentialClientApplication cca = ConfidentialClientApplication.builder(
Expand All @@ -113,4 +140,20 @@ private void assertAcquireTokenCommon(String clientId, IClientCredential credent
Assert.assertNotNull(result);
Assert.assertNotNull(result.accessToken());
}

private void assertAcquireTokenCommon_withParameters(String clientId, IClientCredential credential, IClientCredential credentialParam) throws Exception {

ConfidentialClientApplication cca = ConfidentialClientApplication.builder(
clientId, credential).
authority(TestConstants.MICROSOFT_AUTHORITY).
build();

IAuthenticationResult result = cca.acquireToken(ClientCredentialParameters
.builder(Collections.singleton(KEYVAULT_DEFAULT_SCOPE)).clientCredential(credentialParam)
.build())
.get();

Assert.assertNotNull(result);
Assert.assertNotNull(result.accessToken());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@
import java.security.cert.CertificateException;
import java.security.cert.X509Certificate;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

import static com.microsoft.aad.msal4j.ParameterValidationUtils.validateNotNull;

Expand All @@ -29,7 +34,7 @@ public static IClientSecret createFromSecret(String secret) {
}

/**
* Static method to create a {@link ClientCertificate} instance from a certificate
* Static method to create a {@link ClientCertificate} instance from a password-protected certificate.
*
* @param pkcs12Certificate InputStream containing PCKS12 formatted certificate
* @param password certificate password
Expand All @@ -48,7 +53,7 @@ public static IClientCertificate createFromCertificate(final InputStream pkcs12C
}

/**
* Static method to create a {@link ClientCertificate} instance.
* Static method to create a {@link ClientCertificate} instance from a private key/public certificate pair.
*
* @param key RSA private key to sign the assertion.
* @param publicKeyCertificate x509 public certificate used for thumbprint
Expand All @@ -61,7 +66,7 @@ public static IClientCertificate createFromCertificate(final PrivateKey key, fin
}

/**
* Static method to create a {@link ClientCertificate} instance.
* Static method to create a {@link ClientCertificate} instance from a certificate chain.
*
* @param key RSA private key to sign the assertion.
* @param publicKeyCertificateChain ordered with the user's certificate first followed by zero or more certificate authorities
Expand All @@ -75,12 +80,26 @@ public static IClientCertificate createFromCertificateChain(PrivateKey key, List
}

/**
* Static method to create a {@link ClientAssertion} instance.
* Static method to create a {@link ClientAssertion} instance from a JWT token encoded as a base64 URL encoded string.
*
* @param clientAssertion Jwt token encoded as a base64 URL encoded string
* @param clientAssertion JWT token encoded as a base64 URL encoded string
* @return {@link ClientAssertion}
*/
public static IClientAssertion createFromClientAssertion(String clientAssertion) {
return new ClientAssertion(clientAssertion);
}

/**
* Static method to create a {@link ClientAssertion} instance from a provided Callable.
*
* @param callable Callable that produces a JWT token encoded as a base64 URL encoded string
* @return {@link ClientAssertion}
*/
public static IClientAssertion createFromCallback(Callable<String> callable) throws ExecutionException, InterruptedException {
ExecutorService executor = Executors.newSingleThreadExecutor();

Future<String> future = executor.submit(callable);

return new ClientAssertion(future.get());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,11 @@ public class ClientCredentialParameters implements IAcquireTokenParameters {
*/
private String tenant;

/**
* Overrides the client credentials for this request
*/
private IClientCredential clientCredential;

private static ClientCredentialParametersBuilder builder() {

return new ClientCredentialParametersBuilder();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ private ClientAuthentication buildValidClientCertificateAuthority() {
return createClientAuthFromClientAssertion(clientAssertion);
}

private ClientAuthentication createClientAuthFromClientAssertion(
protected ClientAuthentication createClientAuthFromClientAssertion(
final ClientAssertion clientAssertion) {
final Map<String, List<String>> map = new HashMap<>();
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,15 @@ OAuthHttpRequest createOauthHttpRequest() throws SerializeException, MalformedUR
oauthHttpRequest.setQuery(URLUtils.serializeParameters(params));

if (msalRequest.application().clientAuthentication() != null) {
msalRequest.application().clientAuthentication().applyTo(oauthHttpRequest);
// If the client application has a client assertion to apply to the request, check if a new client assertion
// was supplied as a request parameter. If so, use the request's assertion instead of the application's
if (msalRequest instanceof ClientCredentialRequest && ((ClientCredentialRequest) msalRequest).parameters.clientCredential() != null) {
((ConfidentialClientApplication) msalRequest.application())
.createClientAuthFromClientAssertion((ClientAssertion) ((ClientCredentialRequest) msalRequest).parameters.clientCredential())
.applyTo(oauthHttpRequest);
} else {
msalRequest.application().clientAuthentication().applyTo(oauthHttpRequest);
}
}
return oauthHttpRequest;
}
Expand Down