Skip to content

Commit 0913e04

Browse files
committed
Add code and unit tests for cloud shell
1 parent 76578e2 commit 0913e04

File tree

8 files changed

+173
-23
lines changed

8 files changed

+173
-23
lines changed

msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/AbstractManagedIdentitySource.java

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,12 @@
33

44
package com.microsoft.aad.msal4j;
55

6-
import com.nimbusds.oauth2.sdk.ParseException;
7-
import com.nimbusds.oauth2.sdk.SerializeException;
8-
import com.nimbusds.oauth2.sdk.http.HTTPRequest;
9-
import com.nimbusds.oauth2.sdk.http.HTTPResponse;
106
import lombok.Getter;
117
import lombok.Setter;
128
import org.slf4j.Logger;
139
import org.slf4j.LoggerFactory;
1410

15-
import java.beans.Encoder;
16-
import java.io.IOException;
1711
import java.net.HttpURLConnection;
18-
import java.net.MalformedURLException;
1912
import java.net.SocketException;
2013
import java.net.URISyntaxException;
2114

@@ -56,7 +49,15 @@ public ManagedIdentityResponse getManagedIdentityResponse(
5649
IHttpResponse response;
5750

5851
try {
59-
HttpRequest httpRequest = new HttpRequest(HttpMethod.GET, managedIdentityRequest.computeURI().toString(), managedIdentityRequest.headers);
52+
53+
HttpRequest httpRequest = managedIdentityRequest.method.equals(HttpMethod.GET) ?
54+
new HttpRequest(HttpMethod.GET,
55+
managedIdentityRequest.computeURI().toString(),
56+
managedIdentityRequest.headers) :
57+
new HttpRequest(HttpMethod.POST,
58+
managedIdentityRequest.computeURI().toString(),
59+
managedIdentityRequest.headers,
60+
managedIdentityRequest.getBodyAsString());
6061
response = HttpHelper.executeHttpRequest(httpRequest, managedIdentityRequest.requestContext(), serviceBundle);
6162
} catch (URISyntaxException e) {
6263
throw new RuntimeException(e);
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
package com.microsoft.aad.msal4j;
5+
6+
import com.nimbusds.oauth2.sdk.util.URLUtils;
7+
import org.slf4j.Logger;
8+
import org.slf4j.LoggerFactory;
9+
10+
import java.net.URI;
11+
import java.net.URISyntaxException;
12+
import java.util.Collections;
13+
import java.util.HashMap;
14+
15+
class CloudShellManagedIdentitySource extends AbstractManagedIdentitySource{
16+
17+
private static final Logger LOG = LoggerFactory.getLogger(CloudShellManagedIdentitySource.class);
18+
19+
// MSI Constants. Docs for MSI are available here https://learn.microsoft.com/en-us/azure/cloud-shell/msi-authorization
20+
private static URI endpointUri;
21+
22+
private URI endpoint;
23+
24+
@Override
25+
public void createManagedIdentityRequest(String resource) {
26+
managedIdentityRequest.baseEndpoint = endpoint;
27+
managedIdentityRequest.method = HttpMethod.POST;
28+
29+
managedIdentityRequest.headers = new HashMap<>();
30+
managedIdentityRequest.headers.put("ContentType", "application/x-www-form-urlencoded");
31+
managedIdentityRequest.headers.put("Metadata", "true");
32+
33+
managedIdentityRequest.bodyParameters = new HashMap<>();
34+
managedIdentityRequest.bodyParameters.put("resource", Collections.singletonList(resource));
35+
}
36+
37+
private CloudShellManagedIdentitySource(MsalRequest msalRequest, ServiceBundle serviceBundle, URI endpoint)
38+
{
39+
super(msalRequest, serviceBundle, ManagedIdentitySourceType.CloudShell);
40+
this.endpoint = endpoint;
41+
42+
ManagedIdentityIdType idType =
43+
((ManagedIdentityApplication) msalRequest.application()).getManagedIdentityId().getIdType();
44+
if (idType != ManagedIdentityIdType.SystemAssigned) {
45+
throw new MsalManagedIdentityException(MsalError.USER_ASSIGNED_MANAGED_IDENTITY_NOT_SUPPORTED,
46+
String.format(MsalErrorMessage.MANAGED_IDENTITY_USER_ASSIGNED_NOT_SUPPORTED, "cloud shell"),
47+
ManagedIdentitySourceType.CloudShell);
48+
}
49+
}
50+
51+
protected static AbstractManagedIdentitySource create(MsalRequest msalRequest, ServiceBundle serviceBundle) {
52+
53+
IEnvironmentVariables environmentVariables = getEnvironmentVariables((ManagedIdentityParameters) msalRequest.requestContext().apiParameters());
54+
String msiEndpoint = environmentVariables.getEnvironmentVariable(Constants.MSI_ENDPOINT);
55+
56+
57+
// if ONLY the env var MSI_ENDPOINT is set the MsiType is CloudShell
58+
if (StringHelper.isNullOrBlank(msiEndpoint))
59+
{
60+
LOG.info("[Managed Identity] Cloud shell managed identity is unavailable.");
61+
return null;
62+
}
63+
64+
return validateEnvironmentVariables(msiEndpoint)
65+
? new CloudShellManagedIdentitySource(msalRequest, serviceBundle, endpointUri)
66+
: null;
67+
}
68+
69+
private static boolean validateEnvironmentVariables(String msiEndpoint)
70+
{
71+
endpointUri = null;
72+
73+
try
74+
{
75+
endpointUri = new URI(msiEndpoint);
76+
}
77+
catch (URISyntaxException ex)
78+
{
79+
throw new MsalManagedIdentityException(MsalError.INVALID_MANAGED_IDENTITY_ENDPOINT, String.format(
80+
MsalErrorMessage.MANAGED_IDENTITY_ENDPOINT_INVALID_URI_ERROR, "MSI_ENDPOINT", msiEndpoint, "Cloud Shell"),
81+
ManagedIdentitySourceType.CloudShell);
82+
}
83+
84+
LOG.info("[Managed Identity] Environment variables validation passed for cloud shell managed identity. Endpoint URI: " + endpointUri + ". Creating cloud shell managed identity.");
85+
return true;
86+
}
87+
88+
}

msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/IMDSManagedIdentitySource.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
import java.net.URISyntaxException;
1212
import java.util.Collections;
1313
import java.util.HashMap;
14-
import java.util.Map;
1514

1615
class IMDSManagedIdentitySource extends AbstractManagedIdentitySource{
1716

@@ -125,7 +124,7 @@ public ManagedIdentityResponse handleResponse(
125124

126125
message = message + " " + errorContentMessage;
127126

128-
LOG.error("Error message: {message} Http status code: {response.StatusCode}");
127+
LOG.error(String.format("Error message: %s Http status code: %s"), message, response.statusCode());
129128
throw new MsalManagedIdentityException("managed_identity_request_failed", message,
130129
ManagedIdentitySourceType.Imds);
131130
}

msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/ManagedIdentityClient.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,12 @@ public ManagedIdentityResponse getManagedIdentityResponse(ManagedIdentityParamet
3232

3333
// This method tries to create managed identity source for different sources, if none is created then defaults to IMDS.
3434
private static AbstractManagedIdentitySource createManagedIdentitySource(MsalRequest msalRequest,
35-
ServiceBundle serviceBundle) throws Exception {
35+
ServiceBundle serviceBundle) {
3636
AbstractManagedIdentitySource managedIdentitySource;
3737
if ((managedIdentitySource = AppServiceManagedIdentitySource.create(msalRequest, serviceBundle)) != null) {
3838
return managedIdentitySource;
39+
} else if ((managedIdentitySource = CloudShellManagedIdentitySource.create(msalRequest, serviceBundle)) != null) {
40+
return managedIdentitySource;
3941
} else {
4042
return new IMDSManagedIdentitySource(msalRequest, serviceBundle);
4143
}

msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/ManagedIdentityRequest.java

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,21 @@ class ManagedIdentityRequest extends MsalRequest {
2222

2323
Map<String, String> headers;
2424

25-
Map<String, String> bodyParameters;
25+
Map<String, List<String>> bodyParameters;
2626

2727
Map<String, List<String>> queryParameters;
2828

2929
public ManagedIdentityRequest(ManagedIdentityApplication managedIdentityApplication, RequestContext requestContext) {
3030
super(managedIdentityApplication, requestContext);
3131
}
3232

33+
public String getBodyAsString() {
34+
if (bodyParameters == null || bodyParameters.isEmpty())
35+
return "";
36+
37+
return URLUtils.serializeParameters(bodyParameters);
38+
}
39+
3340
public URL computeURI() throws URISyntaxException {
3441
String endpoint = this.appendQueryParametersToBaseEndpoint();
3542
try {
@@ -40,7 +47,7 @@ public URL computeURI() throws URISyntaxException {
4047
}
4148

4249
private String appendQueryParametersToBaseEndpoint() {
43-
if (queryParameters.isEmpty()) {
50+
if (queryParameters == null || queryParameters.isEmpty()) {
4451
return baseEndpoint.toString();
4552
}
4653

msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/MsalErrorMessage.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ class MsalErrorMessage {
1515

1616
public static final String MANAGED_IDENTITY_USER_ASSIGNED_NOT_CONFIGURABLE_AT_RUNTIME = "[Managed Identity] Service Fabric user assigned managed identity ClientId or ResourceId is not configurable at runtime.";
1717

18-
public static final String MANAGED_IDENTITY_USER_ASSIGNED_NOT_SUPPORTED = "[Managed Identity] User assigned identity is not supported by the %s Managed Identity. To authenticate with the system assigned identity omit the client id in ManagedIdentityApplicationBuilder.Create().";
18+
public static final String MANAGED_IDENTITY_USER_ASSIGNED_NOT_SUPPORTED = "[Managed Identity] User assigned identity is not supported by the %s Managed Identity. To authenticate with the system assigned identity use ManagedIdentityApplication.builder(ManagedIdentityId.systemAssigned()).build().";
1919

2020
public static final String SCOPES_REQUIRED = "At least one scope needs to be requested for this authentication flow. ";
2121
}

msal4j-sdk/src/test/java/com/microsoft/aad/msal4j/ManagedIdentityTestDataProvider.java

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,10 @@ public static Stream<Arguments> createData() {
1717
ManagedIdentityTests.resource),
1818
Arguments.of(ManagedIdentitySourceType.AppService, ManagedIdentityTests.appServiceEndpoint,
1919
ManagedIdentityTests.resourceDefaultSuffix),
20+
Arguments.of(ManagedIdentitySourceType.CloudShell, ManagedIdentityTests.cloudShellEndpoint,
21+
ManagedIdentityTests.resource),
22+
Arguments.of(ManagedIdentitySourceType.CloudShell, ManagedIdentityTests.cloudShellEndpoint,
23+
ManagedIdentityTests.resourceDefaultSuffix),
2024
Arguments.of(ManagedIdentitySourceType.Imds, ManagedIdentityTests.IMDS_ENDPOINT,
2125
ManagedIdentityTests.resource),
2226
Arguments.of(ManagedIdentitySourceType.Imds, ManagedIdentityTests.IMDS_ENDPOINT,
@@ -37,12 +41,24 @@ public static Stream<Arguments> createDataUserAssigned() {
3741
ManagedIdentityId.userAssignedResourceId(RESOURCE_ID)));
3842
}
3943

44+
public static Stream<Arguments> createDataUserAssignedNotSupported() {
45+
return Stream.of(
46+
Arguments.of(ManagedIdentitySourceType.CloudShell, ManagedIdentityTests.cloudShellEndpoint,
47+
ManagedIdentityId.userAssignedClientId(CLIENT_ID)),
48+
Arguments.of(ManagedIdentitySourceType.CloudShell, ManagedIdentityTests.cloudShellEndpoint,
49+
ManagedIdentityId.userAssignedResourceId(RESOURCE_ID)));
50+
}
51+
4052
public static Stream<Arguments> createDataWrongScope() {
4153
return Stream.of(
4254
Arguments.of(ManagedIdentitySourceType.AppService, ManagedIdentityTests.appServiceEndpoint,
4355
"user.read"),
4456
Arguments.of(ManagedIdentitySourceType.AppService, ManagedIdentityTests.appServiceEndpoint,
4557
"https://management.core.windows.net//user_impersonation"),
58+
Arguments.of(ManagedIdentitySourceType.CloudShell, ManagedIdentityTests.cloudShellEndpoint,
59+
"user.read"),
60+
Arguments.of(ManagedIdentitySourceType.CloudShell, ManagedIdentityTests.cloudShellEndpoint,
61+
"https://management.core.windows.net//user_impersonation"),
4662
Arguments.of(ManagedIdentitySourceType.Imds, ManagedIdentityTests.IMDS_ENDPOINT,
4763
"user.read"),
4864
Arguments.of(ManagedIdentitySourceType.Imds, ManagedIdentityTests.IMDS_ENDPOINT,
@@ -52,6 +68,7 @@ public static Stream<Arguments> createDataWrongScope() {
5268
public static Stream<Arguments> createDataError() {
5369
return Stream.of(
5470
Arguments.of(ManagedIdentitySourceType.AppService, ManagedIdentityTests.appServiceEndpoint),
71+
Arguments.of(ManagedIdentitySourceType.CloudShell, ManagedIdentityTests.cloudShellEndpoint),
5572
Arguments.of(ManagedIdentitySourceType.Imds, ManagedIdentityTests.IMDS_ENDPOINT));
5673
}
5774
}

msal4j-sdk/src/test/java/com/microsoft/aad/msal4j/ManagedIdentityTests.java

Lines changed: 45 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
import org.mockito.junit.jupiter.MockitoExtension;
1212

1313
import java.net.SocketException;
14-
import java.net.URISyntaxException;
1514
import java.time.Instant;
1615
import java.time.temporal.ChronoUnit;
1716
import java.util.Collections;
@@ -36,17 +35,15 @@ public class ManagedIdentityTests {
3635

3736
private String getSuccessfulResponse(String resource) {
3837
long expiresOn = Instant.now().plus(1, ChronoUnit.HOURS).getEpochSecond();
39-
String response = "{\"access_token\":\"accesstoken\",\"expires_on\":\"" + expiresOn + "\",\"resource\":\"" + resource + "\",\"token_type\":" +
38+
return "{\"access_token\":\"accesstoken\",\"expires_on\":\"" + expiresOn + "\",\"resource\":\"" + resource + "\",\"token_type\":" +
4039
"\"Bearer\",\"client_id\":\"client_id\"}";
41-
42-
return response;
4340
}
4441

4542
private String getMsiErrorResponse() {
4643
return "{\"statusCode\":\"500\",\"message\":\"An unexpected error occured while fetching the AAD Token.\",\"correlationId\":\"7d0c9763-ff1d-4842-a3f3-6d49e64f4513\"}";
4744
}
4845

49-
private HttpRequest expectedRequest(ManagedIdentitySourceType source, String resource) throws URISyntaxException {
46+
private HttpRequest expectedRequest(ManagedIdentitySourceType source, String resource) {
5047
return expectedRequest(source, resource, ManagedIdentityId.systemAssigned());
5148
}
5249

@@ -55,17 +52,27 @@ private HttpRequest expectedRequest(ManagedIdentitySourceType source, String res
5552
String endpoint = null;
5653
Map<String, String> headers = new HashMap<>();
5754
Map<String, List<String>> queryParameters = new HashMap<>();
55+
Map<String, List<String>> bodyParameters = new HashMap<>();
5856

5957
switch (source) {
6058
case AppService: {
6159
endpoint = appServiceEndpoint;
62-
queryParameters = new HashMap<>();
60+
6361
queryParameters.put("api-version", Collections.singletonList("2019-08-01"));
6462
queryParameters.put("resource", Collections.singletonList(resource));
65-
headers = new HashMap<>();
63+
6664
headers.put("X-IDENTITY-HEADER", "secret");
6765
break;
6866
}
67+
case CloudShell: {
68+
endpoint = cloudShellEndpoint;
69+
70+
headers.put("ContentType", "application/x-www-form-urlencoded");
71+
headers.put("Metadata", "true");
72+
73+
bodyParameters.put("resource", Collections.singletonList(resource));
74+
return new HttpRequest(HttpMethod.POST, computeUri(endpoint, queryParameters), headers, URLUtils.serializeParameters(bodyParameters));
75+
}
6976
case Imds: {
7077
endpoint = IMDS_ENDPOINT;
7178
queryParameters.put("api-version", Collections.singletonList("2018-02-01"));
@@ -89,12 +96,12 @@ private HttpRequest expectedRequest(ManagedIdentitySourceType source, String res
8996

9097
private String computeUri(String endpoint, Map<String, List<String>> queryParameters) {
9198
if (queryParameters.isEmpty()) {
92-
return endpoint.toString();
99+
return endpoint;
93100
}
94101

95102
String queryString = URLUtils.serializeParameters(queryParameters);
96103

97-
return endpoint.toString() + "?" + queryString;
104+
return endpoint + "?" + queryString;
98105
}
99106

100107
private HttpResponse expectedResponse(int statusCode, String response) {
@@ -157,6 +164,35 @@ void managedIdentityTest_UserAssigned_SuccessfulResponse(ManagedIdentitySourceTy
157164
assertNotNull(result.accessToken());
158165
}
159166

167+
@ParameterizedTest
168+
@MethodSource("com.microsoft.aad.msal4j.ManagedIdentityTestDataProvider#createDataUserAssignedNotSupported")
169+
void managedIdentityTest_UserAssigned_NotSupported(ManagedIdentitySourceType source, String endpoint, ManagedIdentityId id) throws Exception {
170+
IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(source, endpoint);
171+
DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class);
172+
173+
ManagedIdentityApplication miApp = ManagedIdentityApplication
174+
.builder(id)
175+
.httpClient(httpClientMock)
176+
.build();
177+
178+
try {
179+
IAuthenticationResult result = miApp.acquireTokenForManagedIdentity(
180+
ManagedIdentityParameters.builder(resource)
181+
.environmentVariables(environmentVariables)
182+
.build()).get();
183+
} catch (Exception e) {
184+
assertNotNull(e);
185+
assertInstanceOf(MsalManagedIdentityException.class, e.getCause());
186+
187+
MsalManagedIdentityException msalMsiException = (MsalManagedIdentityException) e.getCause();
188+
assertEquals(ManagedIdentitySourceType.CloudShell, msalMsiException.managedIdentitySourceType);
189+
assertEquals(MsalError.USER_ASSIGNED_MANAGED_IDENTITY_NOT_SUPPORTED, msalMsiException.errorCode());
190+
return;
191+
}
192+
193+
fail("MsalManagedIdentityException is expected but not thrown.");
194+
}
195+
160196
@ParameterizedTest
161197
@MethodSource("com.microsoft.aad.msal4j.ManagedIdentityTestDataProvider#createData")
162198
void managedIdentityTest_DifferentScopes_RequestsNewToken(ManagedIdentitySourceType source, String endpoint) throws Exception {

0 commit comments

Comments
 (0)