Skip to content

Commit 524548f

Browse files
committed
Add azure arc managed identity
1 parent 8eb5a0d commit 524548f

File tree

6 files changed

+230
-1
lines changed

6 files changed

+230
-1
lines changed

msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/AbstractManagedIdentitySource.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ abstract class AbstractManagedIdentitySource {
2121

2222
protected final ManagedIdentityRequest managedIdentityRequest;
2323
protected final ServiceBundle serviceBundle;
24-
private ManagedIdentitySourceType managedIdentitySourceType;
24+
ManagedIdentitySourceType managedIdentitySourceType;
2525

2626
@Getter
2727
@Setter
Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
package com.microsoft.aad.msal4j;
5+
6+
import org.slf4j.Logger;
7+
import org.slf4j.LoggerFactory;
8+
9+
import java.net.HttpURLConnection;
10+
import java.net.URI;
11+
import java.net.URISyntaxException;
12+
import java.util.Collections;
13+
import java.util.HashMap;
14+
15+
class AzureArcManagedIdentitySource extends AbstractManagedIdentitySource{
16+
17+
private final static Logger LOG = LoggerFactory.getLogger(AzureArcManagedIdentitySource.class);
18+
private static final String ARC_API_VERSION = "2019-11-01";
19+
private static final String AZURE_ARC = "Azure Arc";
20+
21+
private final URI MSI_ENDPOINT;
22+
23+
static AbstractManagedIdentitySource create(MsalRequest msalRequest, ServiceBundle serviceBundle)
24+
{
25+
IEnvironmentVariables environmentVariables = getEnvironmentVariables((ManagedIdentityParameters) msalRequest.requestContext().apiParameters());
26+
String identityEndpoint = environmentVariables.getEnvironmentVariable(Constants.IDENTITY_ENDPOINT);
27+
String imdsEndpoint = environmentVariables.getEnvironmentVariable(Constants.IMDS_ENDPOINT);
28+
29+
URI validatedUri = validateAndGetUri(identityEndpoint, imdsEndpoint);
30+
return validatedUri == null ? null : new AzureArcManagedIdentitySource(validatedUri, msalRequest, serviceBundle );
31+
}
32+
33+
private static URI validateAndGetUri(String identityEndpoint, String imdsEndpoint) {
34+
35+
// if BOTH the env vars IDENTITY_ENDPOINT and IMDS_ENDPOINT are set the MsiType is Azure Arc
36+
if (StringHelper.isNullOrBlank(identityEndpoint) || StringHelper.isNullOrBlank(imdsEndpoint))
37+
{
38+
LOG.info("[Managed Identity] Azure Arc managed identity is unavailable.");
39+
return null;
40+
}
41+
42+
URI endpointUri;
43+
try {
44+
endpointUri = new URI(identityEndpoint);
45+
} catch (URISyntaxException e) {
46+
throw new MsalManagedIdentityException(MsalError.INVALID_MANAGED_IDENTITY_ENDPOINT, String.format(
47+
MsalErrorMessage.MANAGED_IDENTITY_ENDPOINT_INVALID_URI_ERROR, "IDENTITY_ENDPOINT", identityEndpoint, AZURE_ARC),
48+
ManagedIdentitySourceType.AzureArc);
49+
}
50+
51+
LOG.info("[Managed Identity] Creating Azure Arc managed identity. Endpoint URI: " + endpointUri);
52+
return endpointUri;
53+
}
54+
55+
private AzureArcManagedIdentitySource(URI endpoint, MsalRequest msalRequest, ServiceBundle serviceBundle){
56+
super(msalRequest, serviceBundle, ManagedIdentitySourceType.AzureArc);
57+
this.MSI_ENDPOINT = endpoint;
58+
59+
ManagedIdentityIdType idType =
60+
((ManagedIdentityApplication) msalRequest.application()).getManagedIdentityId().getIdType();
61+
if (idType != ManagedIdentityIdType.SystemAssigned) {
62+
throw new MsalManagedIdentityException(MsalError.USER_ASSIGNED_MANAGED_IDENTITY_NOT_SUPPORTED,
63+
String.format(MsalErrorMessage.MANAGED_IDENTITY_USER_ASSIGNED_NOT_SUPPORTED, AZURE_ARC),
64+
ManagedIdentitySourceType.CloudShell);
65+
}
66+
}
67+
68+
@Override
69+
public void createManagedIdentityRequest(String resource)
70+
{
71+
managedIdentityRequest.baseEndpoint = MSI_ENDPOINT;
72+
managedIdentityRequest.method = HttpMethod.GET;
73+
74+
managedIdentityRequest.headers = new HashMap<>();
75+
managedIdentityRequest.headers.put("Metadata", "true");
76+
77+
managedIdentityRequest.queryParameters = new HashMap<>();
78+
managedIdentityRequest.queryParameters.put("api-version", Collections.singletonList(ARC_API_VERSION));
79+
managedIdentityRequest.queryParameters.put("resource", Collections.singletonList(resource));
80+
}
81+
82+
@Override
83+
public ManagedIdentityResponse handleResponse(
84+
ManagedIdentityParameters parameters,
85+
IHttpResponse response)
86+
{
87+
LOG.info("[Managed Identity] Response received. Status code: {response.StatusCode}");
88+
89+
if (response.statusCode() == HttpURLConnection.HTTP_UNAUTHORIZED)
90+
{
91+
if(!response.headers().containsKey("WWW-Authenticate")){
92+
LOG.error("[Managed Identity] WWW-Authenticate header is expected but not found.");
93+
throw new MsalManagedIdentityException(MsalError.MANAGED_IDENTITY_REQUEST_FAILED,
94+
MsalErrorMessage.MANAGED_IDENTITY_NO_CHALLENGE_ERROR,
95+
ManagedIdentitySourceType.AzureArc);
96+
}
97+
98+
String challenge = response.headers().get("WWW-Authenticate").get(0);
99+
String[] splitChallenge = challenge.split("=");
100+
101+
if (splitChallenge.length != 2)
102+
{
103+
LOG.error("[Managed Identity] The WWW-Authenticate header for Azure arc managed identity is not an expected format.");
104+
throw new MsalManagedIdentityException(MsalError.MANAGED_IDENTITY_REQUEST_FAILED,
105+
MsalErrorMessage.MANAGED_IDENTITY_INVALID_CHALLENGE,
106+
ManagedIdentitySourceType.AzureArc);
107+
}
108+
109+
String authHeaderValue = "Basic " + splitChallenge[1];
110+
111+
createManagedIdentityRequest(parameters.resource);
112+
113+
LOG.info("[Managed Identity] Adding authorization header to the request.");
114+
115+
managedIdentityRequest.headers.put("Authorization", authHeaderValue);
116+
117+
try {
118+
response = HttpHelper.executeHttpRequest(
119+
new HttpRequest(HttpMethod.GET, managedIdentityRequest.computeURI().toString(),
120+
managedIdentityRequest.headers),
121+
managedIdentityRequest.requestContext(),
122+
serviceBundle);
123+
} catch (URISyntaxException e) {
124+
throw new MsalManagedIdentityException(MsalError.INVALID_MANAGED_IDENTITY_ENDPOINT,
125+
MsalErrorMessage.MANAGED_IDENTITY_ENDPOINT_INVALID_URI_ERROR,
126+
managedIdentitySourceType);
127+
}
128+
129+
return super.handleResponse(parameters, response);
130+
}
131+
132+
return super.handleResponse(parameters, response);
133+
}
134+
}

msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/ManagedIdentityClient.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ private static AbstractManagedIdentitySource createManagedIdentitySource(MsalReq
3838
return managedIdentitySource;
3939
} else if ((managedIdentitySource = CloudShellManagedIdentitySource.create(msalRequest, serviceBundle)) != null) {
4040
return managedIdentitySource;
41+
} else if ((managedIdentitySource = AzureArcManagedIdentitySource.create(msalRequest, serviceBundle)) != null) {
42+
return managedIdentitySource;
4143
} else {
4244
return new IMDSManagedIdentitySource(msalRequest, serviceBundle);
4345
}

msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/ManagedIdentityParameters.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
@AllArgsConstructor(access = AccessLevel.PRIVATE)
2323
public class ManagedIdentityParameters implements IAcquireTokenParameters {
2424

25+
@Getter
2526
String resource;
2627

2728
boolean forceRefresh;

msal4j-sdk/src/test/java/com/microsoft/aad/msal4j/ManagedIdentityTestDataProvider.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@ public static Stream<Arguments> createData() {
2121
ManagedIdentityTests.resource),
2222
Arguments.of(ManagedIdentitySourceType.CloudShell, ManagedIdentityTests.cloudShellEndpoint,
2323
ManagedIdentityTests.resourceDefaultSuffix),
24+
Arguments.of(ManagedIdentitySourceType.AzureArc, ManagedIdentityTests.azureArcEndpoint,
25+
ManagedIdentityTests.resource),
26+
Arguments.of(ManagedIdentitySourceType.AzureArc, ManagedIdentityTests.azureArcEndpoint,
27+
ManagedIdentityTests.resourceDefaultSuffix),
2428
Arguments.of(ManagedIdentitySourceType.Imds, ManagedIdentityTests.IMDS_ENDPOINT,
2529
ManagedIdentityTests.resource),
2630
Arguments.of(ManagedIdentitySourceType.Imds, ManagedIdentityTests.IMDS_ENDPOINT,
@@ -46,6 +50,10 @@ public static Stream<Arguments> createDataUserAssignedNotSupported() {
4650
Arguments.of(ManagedIdentitySourceType.CloudShell, ManagedIdentityTests.cloudShellEndpoint,
4751
ManagedIdentityId.userAssignedClientId(CLIENT_ID)),
4852
Arguments.of(ManagedIdentitySourceType.CloudShell, ManagedIdentityTests.cloudShellEndpoint,
53+
ManagedIdentityId.userAssignedResourceId(RESOURCE_ID)),
54+
Arguments.of(ManagedIdentitySourceType.AzureArc, ManagedIdentityTests.azureArcEndpoint,
55+
ManagedIdentityId.userAssignedClientId(CLIENT_ID)),
56+
Arguments.of(ManagedIdentitySourceType.AzureArc, ManagedIdentityTests.azureArcEndpoint,
4957
ManagedIdentityId.userAssignedResourceId(RESOURCE_ID)));
5058
}
5159

@@ -59,6 +67,10 @@ public static Stream<Arguments> createDataWrongScope() {
5967
"user.read"),
6068
Arguments.of(ManagedIdentitySourceType.CloudShell, ManagedIdentityTests.cloudShellEndpoint,
6169
"https://management.core.windows.net//user_impersonation"),
70+
Arguments.of(ManagedIdentitySourceType.AzureArc, ManagedIdentityTests.azureArcEndpoint,
71+
"user.read"),
72+
Arguments.of(ManagedIdentitySourceType.AzureArc, ManagedIdentityTests.azureArcEndpoint,
73+
"https://management.core.windows.net//user_impersonation"),
6274
Arguments.of(ManagedIdentitySourceType.Imds, ManagedIdentityTests.IMDS_ENDPOINT,
6375
"user.read"),
6476
Arguments.of(ManagedIdentitySourceType.Imds, ManagedIdentityTests.IMDS_ENDPOINT,
@@ -69,6 +81,7 @@ public static Stream<Arguments> createDataError() {
6981
return Stream.of(
7082
Arguments.of(ManagedIdentitySourceType.AppService, ManagedIdentityTests.appServiceEndpoint),
7183
Arguments.of(ManagedIdentitySourceType.CloudShell, ManagedIdentityTests.cloudShellEndpoint),
84+
Arguments.of(ManagedIdentitySourceType.AzureArc, ManagedIdentityTests.azureArcEndpoint),
7285
Arguments.of(ManagedIdentitySourceType.Imds, ManagedIdentityTests.IMDS_ENDPOINT));
7386
}
7487
}

msal4j-sdk/src/test/java/com/microsoft/aad/msal4j/ManagedIdentityTests.java

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
package com.microsoft.aad.msal4j;
55

66
import com.nimbusds.oauth2.sdk.util.URLUtils;
7+
import org.apache.http.HttpStatus;
8+
import org.junit.jupiter.api.Test;
79
import org.junit.jupiter.api.TestInstance;
810
import org.junit.jupiter.api.extension.ExtendWith;
911
import org.junit.jupiter.params.ParameterizedTest;
@@ -80,6 +82,15 @@ private HttpRequest expectedRequest(ManagedIdentitySourceType source, String res
8082
headers.put("Metadata", "true");
8183
break;
8284
}
85+
case AzureArc: {
86+
endpoint = azureArcEndpoint;
87+
88+
queryParameters.put("api-version", Collections.singletonList("2019-11-01"));
89+
queryParameters.put("resource", Collections.singletonList(resource));
90+
91+
headers.put("Metadata", "true");
92+
break;
93+
}
8394
}
8495

8596
switch (id.getIdType()) {
@@ -182,6 +193,7 @@ void managedIdentityTest_UserAssigned_NotSupported(ManagedIdentitySourceType sou
182193
.build()).get();
183194
} catch (Exception e) {
184195
assertNotNull(e);
196+
assertNotNull(e.getCause());
185197
assertInstanceOf(MsalManagedIdentityException.class, e.getCause());
186198

187199
MsalManagedIdentityException msalMsiException = (MsalManagedIdentityException) e.getCause();
@@ -348,4 +360,71 @@ void managedIdentity_RequestFailed_UnreachableNetwork(ManagedIdentitySourceType
348360

349361
fail("MsalManagedIdentityException is expected but not thrown.");
350362
}
363+
364+
@Test
365+
void azureArcManagedIdentity_MissingAuthHeader() throws Exception {
366+
IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(ManagedIdentitySourceType.AzureArc, azureArcEndpoint);
367+
DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class);
368+
369+
HttpResponse response = new HttpResponse();
370+
response.statusCode(HttpStatus.SC_UNAUTHORIZED);
371+
372+
lenient().when(httpClientMock.send(any())).thenReturn(response);
373+
374+
ManagedIdentityApplication miApp = ManagedIdentityApplication
375+
.builder(ManagedIdentityId.systemAssigned())
376+
.httpClient(httpClientMock)
377+
.build();
378+
379+
try {
380+
miApp.acquireTokenForManagedIdentity(
381+
ManagedIdentityParameters.builder(resource)
382+
.environmentVariables(environmentVariables)
383+
.build()).get();
384+
} catch (Exception exception) {
385+
assert(exception.getCause() instanceof MsalManagedIdentityException);
386+
387+
MsalManagedIdentityException miException = (MsalManagedIdentityException) exception.getCause();
388+
assertEquals(ManagedIdentitySourceType.AzureArc, miException.managedIdentitySourceType);
389+
assertEquals(MsalError.MANAGED_IDENTITY_REQUEST_FAILED, miException.errorCode());
390+
assertEquals(MsalErrorMessage.MANAGED_IDENTITY_NO_CHALLENGE_ERROR, miException.getMessage());
391+
return;
392+
}
393+
394+
fail("MsalManagedIdentityException is expected but not thrown.");
395+
}
396+
397+
@Test
398+
void azureArcManagedIdentity_InvalidAuthHeader() throws Exception {
399+
IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(ManagedIdentitySourceType.AzureArc, azureArcEndpoint);
400+
DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class);
401+
402+
HttpResponse response = new HttpResponse();
403+
response.statusCode(HttpStatus.SC_UNAUTHORIZED);
404+
response.headers().put("WWW-Authenticate", Collections.singletonList("Basic realm=filepath=somepath"));
405+
406+
lenient().when(httpClientMock.send(any())).thenReturn(response);
407+
408+
ManagedIdentityApplication miApp = ManagedIdentityApplication
409+
.builder(ManagedIdentityId.systemAssigned())
410+
.httpClient(httpClientMock)
411+
.build();
412+
413+
try {
414+
miApp.acquireTokenForManagedIdentity(
415+
ManagedIdentityParameters.builder(resource)
416+
.environmentVariables(environmentVariables)
417+
.build()).get();
418+
} catch (Exception exception) {
419+
assert(exception.getCause() instanceof MsalManagedIdentityException);
420+
421+
MsalManagedIdentityException miException = (MsalManagedIdentityException) exception.getCause();
422+
assertEquals(ManagedIdentitySourceType.AzureArc, miException.managedIdentitySourceType);
423+
assertEquals(MsalError.MANAGED_IDENTITY_REQUEST_FAILED, miException.errorCode());
424+
assertEquals(MsalErrorMessage.MANAGED_IDENTITY_INVALID_CHALLENGE, miException.getMessage());
425+
return;
426+
}
427+
428+
fail("MsalManagedIdentityException is expected but not thrown.");
429+
}
351430
}

0 commit comments

Comments
 (0)