Skip to content

Add azure arc managed identity #730

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Nov 9, 2023
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ abstract class AbstractManagedIdentitySource {

protected final ManagedIdentityRequest managedIdentityRequest;
protected final ServiceBundle serviceBundle;
private ManagedIdentitySourceType managedIdentitySourceType;
ManagedIdentitySourceType managedIdentitySourceType;

@Getter
@Setter
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

package com.microsoft.aad.msal4j;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.net.HttpURLConnection;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.Collections;
import java.util.HashMap;

class AzureArcManagedIdentitySource extends AbstractManagedIdentitySource{

private final static Logger LOG = LoggerFactory.getLogger(AzureArcManagedIdentitySource.class);
private static final String ARC_API_VERSION = "2019-11-01";
private static final String AZURE_ARC = "Azure Arc";

private final URI MSI_ENDPOINT;

static AbstractManagedIdentitySource create(MsalRequest msalRequest, ServiceBundle serviceBundle)
{
IEnvironmentVariables environmentVariables = getEnvironmentVariables((ManagedIdentityParameters) msalRequest.requestContext().apiParameters());
String identityEndpoint = environmentVariables.getEnvironmentVariable(Constants.IDENTITY_ENDPOINT);
String imdsEndpoint = environmentVariables.getEnvironmentVariable(Constants.IMDS_ENDPOINT);

URI validatedUri = validateAndGetUri(identityEndpoint, imdsEndpoint);
return validatedUri == null ? null : new AzureArcManagedIdentitySource(validatedUri, msalRequest, serviceBundle );
}

private static URI validateAndGetUri(String identityEndpoint, String imdsEndpoint) {

// if BOTH the env vars IDENTITY_ENDPOINT and IMDS_ENDPOINT are set the MsiType is Azure Arc
if (StringHelper.isNullOrBlank(identityEndpoint) || StringHelper.isNullOrBlank(imdsEndpoint))
{
LOG.info("[Managed Identity] Azure Arc managed identity is unavailable.");
return null;
}

URI endpointUri;
try {
endpointUri = new URI(identityEndpoint);
} catch (URISyntaxException e) {
throw new MsalManagedIdentityException(MsalError.INVALID_MANAGED_IDENTITY_ENDPOINT, String.format(
MsalErrorMessage.MANAGED_IDENTITY_ENDPOINT_INVALID_URI_ERROR, "IDENTITY_ENDPOINT", identityEndpoint, AZURE_ARC),
ManagedIdentitySourceType.AzureArc);
}

LOG.info("[Managed Identity] Creating Azure Arc managed identity. Endpoint URI: " + endpointUri);
return endpointUri;
}

private AzureArcManagedIdentitySource(URI endpoint, MsalRequest msalRequest, ServiceBundle serviceBundle){
super(msalRequest, serviceBundle, ManagedIdentitySourceType.AzureArc);
this.MSI_ENDPOINT = endpoint;

ManagedIdentityIdType idType =
((ManagedIdentityApplication) msalRequest.application()).getManagedIdentityId().getIdType();
if (idType != ManagedIdentityIdType.SystemAssigned) {
throw new MsalManagedIdentityException(MsalError.USER_ASSIGNED_MANAGED_IDENTITY_NOT_SUPPORTED,
String.format(MsalErrorMessage.MANAGED_IDENTITY_USER_ASSIGNED_NOT_SUPPORTED, AZURE_ARC),
ManagedIdentitySourceType.CloudShell);
}
}

@Override
public void createManagedIdentityRequest(String resource)
{
managedIdentityRequest.baseEndpoint = MSI_ENDPOINT;
managedIdentityRequest.method = HttpMethod.GET;

managedIdentityRequest.headers = new HashMap<>();
managedIdentityRequest.headers.put("Metadata", "true");

managedIdentityRequest.queryParameters = new HashMap<>();
managedIdentityRequest.queryParameters.put("api-version", Collections.singletonList(ARC_API_VERSION));
managedIdentityRequest.queryParameters.put("resource", Collections.singletonList(resource));
}

@Override
public ManagedIdentityResponse handleResponse(
ManagedIdentityParameters parameters,
IHttpResponse response)
{
LOG.info("[Managed Identity] Response received. Status code: {response.StatusCode}");

if (response.statusCode() == HttpURLConnection.HTTP_UNAUTHORIZED)
{
if(!response.headers().containsKey("WWW-Authenticate")){
LOG.error("[Managed Identity] WWW-Authenticate header is expected but not found.");
throw new MsalManagedIdentityException(MsalError.MANAGED_IDENTITY_REQUEST_FAILED,
MsalErrorMessage.MANAGED_IDENTITY_NO_CHALLENGE_ERROR,
ManagedIdentitySourceType.AzureArc);
}

String challenge = response.headers().get("WWW-Authenticate").get(0);
String[] splitChallenge = challenge.split("=");

if (splitChallenge.length != 2)
{
LOG.error("[Managed Identity] The WWW-Authenticate header for Azure arc managed identity is not an expected format.");
throw new MsalManagedIdentityException(MsalError.MANAGED_IDENTITY_REQUEST_FAILED,
MsalErrorMessage.MANAGED_IDENTITY_INVALID_CHALLENGE,
ManagedIdentitySourceType.AzureArc);
}

String authHeaderValue = "Basic " + splitChallenge[1];

createManagedIdentityRequest(parameters.resource);

LOG.info("[Managed Identity] Adding authorization header to the request.");

managedIdentityRequest.headers.put("Authorization", authHeaderValue);

try {
response = HttpHelper.executeHttpRequest(
new HttpRequest(HttpMethod.GET, managedIdentityRequest.computeURI().toString(),
managedIdentityRequest.headers),
managedIdentityRequest.requestContext(),
serviceBundle);
} catch (URISyntaxException e) {
throw new MsalManagedIdentityException(MsalError.INVALID_MANAGED_IDENTITY_ENDPOINT,
MsalErrorMessage.MANAGED_IDENTITY_ENDPOINT_INVALID_URI_ERROR,
managedIdentitySourceType);
}

return super.handleResponse(parameters, response);
}

return super.handleResponse(parameters, response);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ private static AbstractManagedIdentitySource createManagedIdentitySource(MsalReq
return managedIdentitySource;
} else if ((managedIdentitySource = CloudShellManagedIdentitySource.create(msalRequest, serviceBundle)) != null) {
return managedIdentitySource;
} else if ((managedIdentitySource = AzureArcManagedIdentitySource.create(msalRequest, serviceBundle)) != null) {
return managedIdentitySource;
} else {
return new IMDSManagedIdentitySource(msalRequest, serviceBundle);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
@AllArgsConstructor(access = AccessLevel.PRIVATE)
public class ManagedIdentityParameters implements IAcquireTokenParameters {

@Getter
String resource;

boolean forceRefresh;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ public static Stream<Arguments> createData() {
ManagedIdentityTests.resource),
Arguments.of(ManagedIdentitySourceType.CLOUD_SHELL, ManagedIdentityTests.cloudShellEndpoint,
ManagedIdentityTests.resourceDefaultSuffix),
Arguments.of(ManagedIdentitySourceType.AzureArc, ManagedIdentityTests.azureArcEndpoint,
ManagedIdentityTests.resource),
Arguments.of(ManagedIdentitySourceType.AzureArc, ManagedIdentityTests.azureArcEndpoint,
ManagedIdentityTests.resourceDefaultSuffix),
Arguments.of(ManagedIdentitySourceType.IMDS, ManagedIdentityTests.IMDS_ENDPOINT,
ManagedIdentityTests.resource),
Arguments.of(ManagedIdentitySourceType.IMDS, ManagedIdentityTests.IMDS_ENDPOINT,
Expand All @@ -45,7 +49,11 @@ public static Stream<Arguments> createDataUserAssignedNotSupported() {
return Stream.of(
Arguments.of(ManagedIdentitySourceType.CLOUD_SHELL, ManagedIdentityTests.cloudShellEndpoint,
ManagedIdentityId.userAssignedClientId(CLIENT_ID)),
Arguments.of(ManagedIdentitySourceType.CLOUD_SHELL, ManagedIdentityTests.cloudShellEndpoint,
Arguments.of(ManagedIdentitySourceType.CloudShell, ManagedIdentityTests.cloudShellEndpoint,
ManagedIdentityId.userAssignedResourceId(RESOURCE_ID)),
Arguments.of(ManagedIdentitySourceType.AzureArc, ManagedIdentityTests.azureArcEndpoint,
ManagedIdentityId.userAssignedClientId(CLIENT_ID)),
Arguments.of(ManagedIdentitySourceType.AzureArc, ManagedIdentityTests.azureArcEndpoint,
ManagedIdentityId.userAssignedResourceId(RESOURCE_ID)));
}

Expand All @@ -59,6 +67,10 @@ public static Stream<Arguments> createDataWrongScope() {
"user.read"),
Arguments.of(ManagedIdentitySourceType.CLOUD_SHELL, ManagedIdentityTests.cloudShellEndpoint,
"https://management.core.windows.net//user_impersonation"),
Arguments.of(ManagedIdentitySourceType.AzureArc, ManagedIdentityTests.azureArcEndpoint,
"user.read"),
Arguments.of(ManagedIdentitySourceType.AzureArc, ManagedIdentityTests.azureArcEndpoint,
"https://management.core.windows.net//user_impersonation"),
Arguments.of(ManagedIdentitySourceType.IMDS, ManagedIdentityTests.IMDS_ENDPOINT,
"user.read"),
Arguments.of(ManagedIdentitySourceType.IMDS, ManagedIdentityTests.IMDS_ENDPOINT,
Expand All @@ -67,6 +79,7 @@ public static Stream<Arguments> createDataWrongScope() {

public static Stream<Arguments> createDataError() {
return Stream.of(
Arguments.of(ManagedIdentitySourceType.AzureArc, ManagedIdentityTests.azureArcEndpoint),
Arguments.of(ManagedIdentitySourceType.APP_SERVICE, ManagedIdentityTests.appServiceEndpoint),
Arguments.of(ManagedIdentitySourceType.CLOUD_SHELL, ManagedIdentityTests.cloudShellEndpoint),
Arguments.of(ManagedIdentitySourceType.IMDS, ManagedIdentityTests.IMDS_ENDPOINT));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
package com.microsoft.aad.msal4j;

import com.nimbusds.oauth2.sdk.util.URLUtils;
import org.apache.http.HttpStatus;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
Expand Down Expand Up @@ -80,6 +82,15 @@ private HttpRequest expectedRequest(ManagedIdentitySourceType source, String res
headers.put("Metadata", "true");
break;
}
case AzureArc: {
endpoint = azureArcEndpoint;

queryParameters.put("api-version", Collections.singletonList("2019-11-01"));
queryParameters.put("resource", Collections.singletonList(resource));

headers.put("Metadata", "true");
break;
}
}

switch (id.getIdType()) {
Expand Down Expand Up @@ -182,6 +193,7 @@ void managedIdentityTest_UserAssigned_NotSupported(ManagedIdentitySourceType sou
.build()).get();
} catch (Exception e) {
assertNotNull(e);
assertNotNull(e.getCause());
assertInstanceOf(MsalManagedIdentityException.class, e.getCause());

MsalManagedIdentityException msalMsiException = (MsalManagedIdentityException) e.getCause();
Expand Down Expand Up @@ -349,6 +361,39 @@ void managedIdentity_RequestFailed_UnreachableNetwork(ManagedIdentitySourceType
fail("MsalManagedIdentityException is expected but not thrown.");
}

@Test
void azureArcManagedIdentity_MissingAuthHeader() throws Exception {
IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(ManagedIdentitySourceType.AzureArc, azureArcEndpoint);
DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class);

HttpResponse response = new HttpResponse();
response.statusCode(HttpStatus.SC_UNAUTHORIZED);

lenient().when(httpClientMock.send(any())).thenReturn(response);

ManagedIdentityApplication miApp = ManagedIdentityApplication
.builder(ManagedIdentityId.systemAssigned())
.httpClient(httpClientMock)
.build();

try {
miApp.acquireTokenForManagedIdentity(
ManagedIdentityParameters.builder(resource)
.environmentVariables(environmentVariables)
.build()).get();
} catch (Exception exception) {
assert(exception.getCause() instanceof MsalManagedIdentityException);

MsalManagedIdentityException miException = (MsalManagedIdentityException) exception.getCause();
assertEquals(ManagedIdentitySourceType.AzureArc, miException.managedIdentitySourceType);
assertEquals(MsalError.MANAGED_IDENTITY_REQUEST_FAILED, miException.errorCode());
assertEquals(MsalErrorMessage.MANAGED_IDENTITY_NO_CHALLENGE_ERROR, miException.getMessage());
return;
}

fail("MsalManagedIdentityException is expected but not thrown.");
}

@ParameterizedTest
@MethodSource("com.microsoft.aad.msal4j.ManagedIdentityTestDataProvider#createDataError")
void managedIdentity_SharedCache(ManagedIdentitySourceType source, String endpoint) throws Exception {
Expand All @@ -361,13 +406,13 @@ void managedIdentity_SharedCache(ManagedIdentitySourceType source, String endpoi
.builder(ManagedIdentityId.systemAssigned())
.httpClient(httpClientMock)
.build();

ManagedIdentityApplication miApp2 = ManagedIdentityApplication
.builder(ManagedIdentityId.systemAssigned())
.httpClient(httpClientMock)
.build();

IAuthenticationResult resultMiApp1 = miApp1.acquireTokenForManagedIdentity(
IAuthenticationResult resultMiApp1 = miApp1.acquireTokenForManagedIdentity(
ManagedIdentityParameters.builder(resource)
.environmentVariables(environmentVariables)
.build()).get();
Expand All @@ -386,4 +431,38 @@ void managedIdentity_SharedCache(ManagedIdentitySourceType source, String endpoi
// should return the same token
assertEquals(resultMiApp1.accessToken(), resultMiApp2.accessToken());
}

@Test
void azureArcManagedIdentity_InvalidAuthHeader() throws Exception {
IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(ManagedIdentitySourceType.AzureArc, azureArcEndpoint);
DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class);

HttpResponse response = new HttpResponse();
response.statusCode(HttpStatus.SC_UNAUTHORIZED);
response.headers().put("WWW-Authenticate", Collections.singletonList("Basic realm=filepath=somepath"));

lenient().when(httpClientMock.send(any())).thenReturn(response);

ManagedIdentityApplication miApp = ManagedIdentityApplication
.builder(ManagedIdentityId.systemAssigned())
.httpClient(httpClientMock)
.build();

try {
miApp.acquireTokenForManagedIdentity(
ManagedIdentityParameters.builder(resource)
.environmentVariables(environmentVariables)
.build()).get();
} catch (Exception exception) {
assert(exception.getCause() instanceof MsalManagedIdentityException);

MsalManagedIdentityException miException = (MsalManagedIdentityException) exception.getCause();
assertEquals(ManagedIdentitySourceType.AzureArc, miException.managedIdentitySourceType);
assertEquals(MsalError.MANAGED_IDENTITY_REQUEST_FAILED, miException.errorCode());
assertEquals(MsalErrorMessage.MANAGED_IDENTITY_INVALID_CHALLENGE, miException.getMessage());
return;
}

fail("MsalManagedIdentityException is expected but not thrown.");
}
}