Skip to content

Commit 6bebbdf

Browse files
Merge pull request #943 from AzureAD/nebharg/PassClaimsAndCapabilities
Token revocation for service fabric
2 parents 2e76834 + 274ffa1 commit 6bebbdf

11 files changed

+214
-14
lines changed

msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/AbstractManagedIdentitySource.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ public ManagedIdentityResponse getManagedIdentityResponse(
4141
ManagedIdentityParameters parameters) {
4242

4343
createManagedIdentityRequest(parameters.resource);
44+
managedIdentityRequest.addTokenRevocationParametersToQuery(parameters);
4445
IHttpResponse response;
4546

4647
try {

msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/AcquireTokenByManagedIdentitySupplier.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,12 @@ AuthenticationResult execute() throws Exception {
8383
result.metadata().tokenSource(TokenSource.CACHE);
8484
return result;
8585
} else {
86+
if (cacheRefreshReason == CacheRefreshReason.CLAIMS) {
87+
LOG.debug("Claims are passed, creating token hash and refreshing the token");
88+
managedIdentityParameters.revokedTokenHash = StringHelper.createSha256HashHexString(result.accessToken());
89+
return fetchNewAccessTokenAndSaveToCache(tokenRequestExecutor, CacheRefreshReason.CLAIMS);
90+
}
91+
8692
LOG.debug(String.format("Refreshing access token. Cache refresh reason: %s", cacheRefreshReason));
8793
return fetchNewAccessTokenAndSaveToCache(tokenRequestExecutor, cacheRefreshReason);
8894
}

msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/AuthenticationErrorCode.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,4 +147,10 @@ public class AuthenticationErrorCode {
147147
* For more information on managed identity see https://aka.ms/msal4j-managed-identity.
148148
*/
149149
public static final String MANAGED_IDENTITY_REQUEST_FAILED = "managed_identity_request_failed";
150+
151+
/**
152+
* Indicates a cryptographic operation error occurred, such as when generating hash values
153+
* or performing other cryptographic functions.
154+
*/
155+
public static final String CRYPTO_ERROR = "crypto_error";
150156
}

msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/Constants.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33

44
package com.microsoft.aad.msal4j;
55

6+
import java.util.HashSet;
7+
import java.util.Set;
8+
69
final class Constants {
710

811
static final String CACHE_KEY_SEPARATOR = "-";
@@ -23,4 +26,13 @@ final class Constants {
2326
public static final String MSI_ENDPOINT = "MSI_ENDPOINT";
2427
public static final String IDENTITY_SERVER_THUMBPRINT = "IDENTITY_SERVER_THUMBPRINT";
2528

29+
// Constants for token revocation and client capabilities
30+
public static final String TOKEN_HASH_CLAIM = "token_sha256_to_refresh";
31+
public static final String CLIENT_CAPABILITY_REQUEST_PARAM = "xms_cc";
32+
33+
// Only Service Fabric managed identity environments support token revocation
34+
public static final Set<ManagedIdentitySourceType> TOKEN_REVOCATION_SUPPORTED_ENVIRONMENTS = new HashSet<ManagedIdentitySourceType>() {{
35+
add(ManagedIdentitySourceType.SERVICE_FABRIC);
36+
}};
37+
2638
}

msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/ManagedIdentityApplication.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
public class ManagedIdentityApplication extends AbstractApplicationBase implements IManagedIdentityApplication {
1818

1919
private final ManagedIdentityId managedIdentityId;
20+
private List<String> clientCapabilities;
2021
static TokenCache sharedTokenCache = new TokenCache();
2122

2223
//Deprecated the field in favor of the static getManagedIdentitySource method
@@ -44,6 +45,7 @@ private ManagedIdentityApplication(Builder builder) {
4445

4546
this.managedIdentityId = builder.managedIdentityId;
4647
this.tenant = Constants.MANAGED_IDENTITY_DEFAULT_TENTANT;
48+
this.clientCapabilities = builder.clientCapabilities;
4749
}
4850

4951
public static TokenCache getSharedTokenCache() {
@@ -57,6 +59,8 @@ static IEnvironmentVariables getEnvironmentVariables() {
5759
public ManagedIdentityId getManagedIdentityId() {
5860
return this.managedIdentityId;
5961
}
62+
63+
public List<String> getClientCapabilities() { return this.clientCapabilities; }
6064

6165
@Override
6266
public CompletableFuture<IAuthenticationResult> acquireTokenForManagedIdentity(ManagedIdentityParameters managedIdentityParameters)
@@ -105,7 +109,7 @@ public Builder resource(String resource) {
105109
/**
106110
* Informs the token issuer that the application is able to perform complex authentication actions.
107111
* For example, "cp1" means that the application is able to perform conditional access evaluation,
108-
* because the application has been setup to parse WWW-Authenticate headers associated with a 401 response from the protected APIs,
112+
* because the application has been set up to parse WWW-Authenticate headers associated with a 401 response from the protected APIs,
109113
* and to retry the request with claims API.
110114
*
111115
* @param clientCapabilities a list of capabilities (e.g., ["cp1"]) recognized by the token service.

msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/ManagedIdentityParameters.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@ public class ManagedIdentityParameters implements IAcquireTokenParameters {
1515
String resource;
1616
boolean forceRefresh;
1717
String claims;
18-
18+
String revokedTokenHash;
19+
1920
private ManagedIdentityParameters(String resource, boolean forceRefresh, String claims) {
2021
this.resource = resource;
2122
this.forceRefresh = forceRefresh;
@@ -78,6 +79,10 @@ public String resource() {
7879
return this.resource;
7980
}
8081

82+
public String revokedTokenHash() {
83+
return this.revokedTokenHash;
84+
}
85+
8186
public static class ManagedIdentityParametersBuilder {
8287
private String resource;
8388
private boolean forceRefresh;

msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/ManagedIdentityRequest.java

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import java.net.URISyntaxException;
1313
import java.net.URL;
1414
import java.util.Collections;
15+
import java.util.HashMap;
1516
import java.util.List;
1617
import java.util.Map;
1718

@@ -75,4 +76,35 @@ void addUserAssignedIdToQuery(ManagedIdentityIdType idType, String userAssignedI
7576
break;
7677
}
7778
}
79+
80+
void addTokenRevocationParametersToQuery(ManagedIdentityParameters parameters) {
81+
// Check if the environment supports token revocation
82+
ManagedIdentitySourceType sourceType = ManagedIdentityClient.getManagedIdentitySource();
83+
boolean supportsTokenRevocation = Constants.TOKEN_REVOCATION_SUPPORTED_ENVIRONMENTS
84+
.contains(sourceType);
85+
86+
// If token revocation is supported, pass the client capabilities and token revocation parameters
87+
if (supportsTokenRevocation) {
88+
ManagedIdentityApplication managedIdentityApplication =
89+
(ManagedIdentityApplication) this.application();
90+
91+
// Pass capabilities if present.
92+
if (managedIdentityApplication.getClientCapabilities() != null &&
93+
!managedIdentityApplication.getClientCapabilities().isEmpty()) {
94+
// Add client capabilities as a comma separated string for all the values in client capabilities
95+
String clientCapabilities = String.join(",", managedIdentityApplication.getClientCapabilities());
96+
97+
queryParameters.put(Constants.CLIENT_CAPABILITY_REQUEST_PARAM, Collections.singletonList(clientCapabilities.toString()));
98+
}
99+
100+
// Pass the token revocation parameter if the claims are present and there is a token to revoke
101+
if (!StringHelper.isNullOrBlank(parameters.claims) && !StringHelper.isNullOrBlank(parameters.revokedTokenHash())) {
102+
LOG.info("[Managed Identity] Adding token revocation parameter to request");
103+
if (queryParameters == null) {
104+
queryParameters = new HashMap<>();
105+
}
106+
queryParameters.put(Constants.TOKEN_HASH_CLAIM, Collections.singletonList(parameters.revokedTokenHash()));
107+
}
108+
}
109+
}
78110
}

msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/ServiceFabricManagedIdentitySource.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ public ManagedIdentityResponse getManagedIdentityResponse(
5757
ManagedIdentityParameters parameters) {
5858

5959
createManagedIdentityRequest(parameters.resource);
60+
managedIdentityRequest.addTokenRevocationParametersToQuery(parameters);
6061
IHttpResponse response;
6162

6263
try {

msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/StringHelper.java

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,40 @@ static private String createSha256Hash(String stringToHash, boolean base64Encode
4141
return res;
4242
}
4343

44-
public static boolean isNullOrBlank(final String str) {
44+
/**
45+
* Creates a SHA-256 hash of the input string and returns it as a lowercase hex string.
46+
* This is used for token revocation and other scenarios requiring hex hash representation.
47+
*
48+
* @param stringToHash The string to hash
49+
* @return The SHA-256 hash of the string as a lowercase hex string
50+
* @throws MsalClientException If the SHA-256 algorithm is not available
51+
*/
52+
static String createSha256HashHexString(String stringToHash) {
53+
if (stringToHash == null || stringToHash.isEmpty()) {
54+
throw new IllegalArgumentException("String to hash cannot be null or empty");
55+
}
56+
57+
try {
58+
MessageDigest messageDigest = MessageDigest.getInstance("SHA-256");
59+
byte[] hash = messageDigest.digest(stringToHash.getBytes(StandardCharsets.UTF_8));
60+
61+
// Convert to hex string
62+
StringBuilder hexString = new StringBuilder();
63+
for (byte b : hash) {
64+
String hex = Integer.toHexString(0xff & b);
65+
if (hex.length() == 1) {
66+
hexString.append('0');
67+
}
68+
hexString.append(hex);
69+
}
70+
return hexString.toString();
71+
} catch (NoSuchAlgorithmException e) {
72+
throw new MsalClientException("Failed to create SHA-256 hash: " + e.getMessage(),
73+
AuthenticationErrorCode.CRYPTO_ERROR);
74+
}
75+
}
76+
77+
static boolean isNullOrBlank(final String str) {
4578
return str == null || str.trim().length() == 0;
4679
}
4780
}

msal4j-sdk/src/test/java/com/microsoft/aad/msal4j/ManagedIdentityTests.java

Lines changed: 68 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -70,16 +70,34 @@ private String getMsiErrorResponseNoRetry() {
7070
return "{\"statusCode\":\"123\",\"message\":\"Not one of the retryable error responses\",\"correlationId\":\"7d0c9763-ff1d-4842-a3f3-6d49e64f4513\"}";
7171
}
7272

73+
private HttpRequest expectedRequest(ManagedIdentitySourceType source, String resource, boolean hasClaims, boolean hasCapabilities, String expectedTokenHash) {
74+
return expectedRequest(source, resource, ManagedIdentityId.systemAssigned(), hasClaims, hasCapabilities, expectedTokenHash);
75+
}
76+
77+
private HttpRequest expectedRequest(ManagedIdentitySourceType source, String resource, ManagedIdentityId id) {
78+
return expectedRequest(source, resource, id, false, false, null);
79+
}
80+
7381
private HttpRequest expectedRequest(ManagedIdentitySourceType source, String resource) {
74-
return expectedRequest(source, resource, ManagedIdentityId.systemAssigned());
82+
return expectedRequest(source, resource, ManagedIdentityId.systemAssigned(), false, false, null);
7583
}
7684

7785
private HttpRequest expectedRequest(ManagedIdentitySourceType source, String resource,
78-
ManagedIdentityId id) {
86+
ManagedIdentityId id, boolean hasClaims, boolean hasCapabilities, String expectedTokenHash) {
7987
String endpoint = null;
8088
Map<String, String> headers = new HashMap<>();
8189
Map<String, List<String>> queryParameters = new HashMap<>();
8290

91+
if (Constants.TOKEN_REVOCATION_SUPPORTED_ENVIRONMENTS.contains(source)) {
92+
if (hasCapabilities) {
93+
queryParameters.put(Constants.CLIENT_CAPABILITY_REQUEST_PARAM, Collections.singletonList("cp1"));
94+
}
95+
96+
if (hasClaims) {
97+
queryParameters.put(Constants.TOKEN_HASH_CLAIM, Collections.singletonList(expectedTokenHash));
98+
}
99+
}
100+
83101
switch (source) {
84102
case APP_SERVICE:
85103
endpoint = appServiceEndpoint;
@@ -93,12 +111,6 @@ private HttpRequest expectedRequest(ManagedIdentitySourceType source, String res
93111
headers.put("Metadata", "true");
94112
queryParameters.put("resource", Collections.singletonList(resource));
95113
break;
96-
case IMDS:
97-
endpoint = IMDS_ENDPOINT;
98-
queryParameters.put("api-version", Collections.singletonList("2018-02-01"));
99-
queryParameters.put("resource", Collections.singletonList(resource));
100-
headers.put("Metadata", "true");
101-
break;
102114
case AZURE_ARC:
103115
endpoint = azureArcEndpoint;
104116
queryParameters.put("api-version", Collections.singletonList("2019-11-01"));
@@ -111,6 +123,7 @@ private HttpRequest expectedRequest(ManagedIdentitySourceType source, String res
111123
queryParameters.put("resource", Collections.singletonList(resource));
112124
headers.put("secret", "secret");
113125
break;
126+
case IMDS:
114127
case NONE:
115128
case DEFAULT_TO_IMDS:
116129
endpoint = IMDS_ENDPOINT;
@@ -657,6 +670,9 @@ void managedIdentityTest_WithClaims(ManagedIdentitySourceType source, String end
657670
assertNotNull(result.accessToken());
658671
assertEquals(TokenSource.CACHE, result.metadata().tokenSource());
659672

673+
String expectedTokenHash = StringHelper.createSha256HashHexString(result.accessToken());
674+
when(httpClientMock.send(expectedRequest(source, resource, true, false, expectedTokenHash))).thenReturn(expectedResponse(200, getSuccessfulResponse(resource)));
675+
660676
// Third call, when claims are passed bypass the cache.
661677
result = miApp.acquireTokenForManagedIdentity(
662678
ManagedIdentityParameters.builder(resource)
@@ -669,6 +685,46 @@ void managedIdentityTest_WithClaims(ManagedIdentitySourceType source, String end
669685
verify(httpClientMock, times(2)).send(any());
670686
}
671687

688+
@ParameterizedTest
689+
@MethodSource("com.microsoft.aad.msal4j.ManagedIdentityTestDataProvider#createDataError")
690+
void managedIdentityTest_WithCapabilitiesOnly(ManagedIdentitySourceType source, String endpoint) throws Exception {
691+
IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(source, endpoint);
692+
ManagedIdentityApplication.setEnvironmentVariables(environmentVariables);
693+
DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class);
694+
if (source == SERVICE_FABRIC) {
695+
ServiceFabricManagedIdentitySource.setHttpClient(httpClientMock);
696+
}
697+
698+
when(httpClientMock.send(expectedRequest(source, resource, false, true, null))).thenReturn(expectedResponse(200, getSuccessfulResponse(resource)));
699+
700+
miApp = ManagedIdentityApplication
701+
.builder(ManagedIdentityId.systemAssigned())
702+
.httpClient(httpClientMock)
703+
.clientCapabilities(singletonList("cp1"))
704+
.build();
705+
706+
// Clear caching to avoid cross test pollution.
707+
miApp.tokenCache().accessTokens.clear();
708+
709+
// First call, get the token from the identity provider.
710+
IAuthenticationResult result = miApp.acquireTokenForManagedIdentity(
711+
ManagedIdentityParameters.builder(resource)
712+
.build()).get();
713+
714+
assertNotNull(result.accessToken());
715+
assertEquals(TokenSource.IDENTITY_PROVIDER, result.metadata().tokenSource());
716+
717+
// Second call, get the token from the cache without passing the claims.
718+
result = miApp.acquireTokenForManagedIdentity(
719+
ManagedIdentityParameters.builder(resource)
720+
.build()).get();
721+
722+
assertNotNull(result.accessToken());
723+
assertEquals(TokenSource.CACHE, result.metadata().tokenSource());
724+
725+
verify(httpClientMock, times(1)).send(any());
726+
}
727+
672728
@ParameterizedTest
673729
@MethodSource("com.microsoft.aad.msal4j.ManagedIdentityTestDataProvider#createDataError")
674730
void managedIdentity_ClaimsAndCapabilities(ManagedIdentitySourceType source, String endpoint) throws Exception {
@@ -679,7 +735,7 @@ void managedIdentity_ClaimsAndCapabilities(ManagedIdentitySourceType source, Str
679735
ServiceFabricManagedIdentitySource.setHttpClient(httpClientMock);
680736
}
681737

682-
when(httpClientMock.send(expectedRequest(source, resource))).thenReturn(expectedResponse(200, getSuccessfulResponse(resource)));
738+
when(httpClientMock.send(expectedRequest(source, resource, false, true, null))).thenReturn(expectedResponse(200, getSuccessfulResponse(resource)));
683739

684740
miApp = ManagedIdentityApplication
685741
.builder(ManagedIdentityId.systemAssigned())
@@ -707,6 +763,9 @@ void managedIdentity_ClaimsAndCapabilities(ManagedIdentitySourceType source, Str
707763
assertNotNull(result.accessToken());
708764
assertEquals(TokenSource.CACHE, result.metadata().tokenSource());
709765

766+
String expectedTokenHash = StringHelper.createSha256HashHexString(result.accessToken());
767+
when(httpClientMock.send(expectedRequest(source, resource, true, true, expectedTokenHash))).thenReturn(expectedResponse(200, getSuccessfulResponse(resource)));
768+
710769
// Third call, when claims are passed bypass the cache.
711770
result = miApp.acquireTokenForManagedIdentity(
712771
ManagedIdentityParameters.builder(resource)
@@ -715,8 +774,6 @@ void managedIdentity_ClaimsAndCapabilities(ManagedIdentitySourceType source, Str
715774

716775
assertNotNull(result.accessToken());
717776
assertEquals(TokenSource.IDENTITY_PROVIDER, result.metadata().tokenSource());
718-
719-
verify(httpClientMock, times(2)).send(any());
720777
}
721778

722779
@ParameterizedTest
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
package com.microsoft.aad.msal4j;
5+
6+
import org.junit.jupiter.api.Test;
7+
import static org.junit.jupiter.api.Assertions.*;
8+
9+
public class TokenRevocationTest {
10+
11+
private static final String TEST_TOKEN = "test_token";
12+
private static final String EXPECTED_TOKEN_HASH = "cc0af97287543b65da2c7e1476426021826cab166f1e063ed012b855ff819656";
13+
private static final String TEST_RESOURCE = "https://management.azure.com";
14+
@Test
15+
public void testConvertTokenToSHA256Hash() {
16+
String hash = StringHelper.createSha256HashHexString(TEST_TOKEN);
17+
assertEquals(EXPECTED_TOKEN_HASH, hash);
18+
}
19+
20+
@Test
21+
public void testTokenToRevokeValidation() {
22+
// Should throw exception when null
23+
assertThrows(IllegalArgumentException.class, () -> {
24+
StringHelper.createSha256HashHexString(null);
25+
});
26+
27+
// Should throw exception when empty
28+
assertThrows(IllegalArgumentException.class, () -> {
29+
StringHelper.createSha256HashHexString("");
30+
});
31+
}
32+
33+
@Test
34+
public void testManagedIdentityParametersBuilder() {
35+
ManagedIdentityParameters params = ManagedIdentityParameters.builder(TEST_RESOURCE)
36+
.build();
37+
38+
params.revokedTokenHash = StringHelper.createSha256HashHexString(TEST_TOKEN);
39+
40+
assertEquals(TEST_RESOURCE, params.resource());
41+
assertEquals(EXPECTED_TOKEN_HASH, params.revokedTokenHash());
42+
}
43+
}

0 commit comments

Comments
 (0)