Skip to content

Commit 5f5e32e

Browse files
Merge pull request #713 from AzureAD/nebharg/MsiCloudShell
Cloud shell MSI
2 parents 76578e2 + 8eb5a0d commit 5f5e32e

File tree

12 files changed

+404
-89
lines changed

12 files changed

+404
-89
lines changed

msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/AbstractManagedIdentitySource.java

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,12 @@
33

44
package com.microsoft.aad.msal4j;
55

6-
import com.nimbusds.oauth2.sdk.ParseException;
7-
import com.nimbusds.oauth2.sdk.SerializeException;
8-
import com.nimbusds.oauth2.sdk.http.HTTPRequest;
9-
import com.nimbusds.oauth2.sdk.http.HTTPResponse;
106
import lombok.Getter;
117
import lombok.Setter;
128
import org.slf4j.Logger;
139
import org.slf4j.LoggerFactory;
1410

15-
import java.beans.Encoder;
16-
import java.io.IOException;
1711
import java.net.HttpURLConnection;
18-
import java.net.MalformedURLException;
1912
import java.net.SocketException;
2013
import java.net.URISyntaxException;
2114

@@ -56,7 +49,15 @@ public ManagedIdentityResponse getManagedIdentityResponse(
5649
IHttpResponse response;
5750

5851
try {
59-
HttpRequest httpRequest = new HttpRequest(HttpMethod.GET, managedIdentityRequest.computeURI().toString(), managedIdentityRequest.headers);
52+
53+
HttpRequest httpRequest = managedIdentityRequest.method.equals(HttpMethod.GET) ?
54+
new HttpRequest(HttpMethod.GET,
55+
managedIdentityRequest.computeURI().toString(),
56+
managedIdentityRequest.headers) :
57+
new HttpRequest(HttpMethod.POST,
58+
managedIdentityRequest.computeURI().toString(),
59+
managedIdentityRequest.headers,
60+
managedIdentityRequest.getBodyAsString());
6061
response = HttpHelper.executeHttpRequest(httpRequest, managedIdentityRequest.requestContext(), serviceBundle);
6162
} catch (URISyntaxException e) {
6263
throw new RuntimeException(e);
@@ -89,7 +90,7 @@ public ManagedIdentityResponse handleResponse(
8990
throw new MsalManagedIdentityException(AuthenticationErrorCode.MANAGED_IDENTITY_REQUEST_FAILED, message, managedIdentitySourceType);
9091
}
9192
} catch (Exception e) {
92-
if (!(e instanceof MsalServiceException)) {
93+
if (!(e instanceof MsalManagedIdentityException)) {
9394
LOG.error(
9495
String.format("[Managed Identity] Exception: %s Http status code: %s", e.getMessage(),
9596
response != null ? response.statusCode() : ""));

msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/AppServiceManagedIdentitySource.java

Lines changed: 15 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,6 @@
1010
import java.net.URISyntaxException;
1111
import java.util.Collections;
1212
import java.util.HashMap;
13-
import java.util.List;
14-
import java.util.Map;
1513

1614
class AppServiceManagedIdentitySource extends AbstractManagedIdentitySource{
1715

@@ -20,25 +18,22 @@ class AppServiceManagedIdentitySource extends AbstractManagedIdentitySource{
2018
// MSI Constants. Docs for MSI are available here https://docs.microsoft.com/azure/app-service/overview-managed-identity
2119
private static final String APP_SERVICE_MSI_API_VERSION = "2019-08-01";
2220
private static final String SECRET_HEADER_NAME = "X-IDENTITY-HEADER";
23-
private static URI endpointUri;
2421

25-
private URI endpoint;
26-
private String secret;
22+
private final URI MSI_ENDPOINT;
23+
private final String SECRET;
2724

2825
@Override
2926
public void createManagedIdentityRequest(String resource) {
30-
managedIdentityRequest.baseEndpoint = endpoint;
27+
managedIdentityRequest.baseEndpoint = MSI_ENDPOINT;
3128
managedIdentityRequest.method = HttpMethod.GET;
3229

3330
managedIdentityRequest.headers = new HashMap<>();
34-
managedIdentityRequest.headers.put(SECRET_HEADER_NAME, secret);
31+
managedIdentityRequest.headers.put(SECRET_HEADER_NAME, SECRET);
3532

3633
managedIdentityRequest.queryParameters = new HashMap<>();
3734
managedIdentityRequest.queryParameters.put("api-version", Collections.singletonList(APP_SERVICE_MSI_API_VERSION));
3835
managedIdentityRequest.queryParameters.put("resource", Collections.singletonList(resource));
3936

40-
String clientId = getManagedIdentityUserAssignedClientId();
41-
String resourceId = getManagedIdentityUserAssignedResourceId();
4237
if (!StringHelper.isNullOrBlank(getManagedIdentityUserAssignedClientId()))
4338
{
4439
LOG.info("[Managed Identity] Adding user assigned client id to the request.");
@@ -52,35 +47,34 @@ public void createManagedIdentityRequest(String resource) {
5247
}
5348
}
5449

55-
private AppServiceManagedIdentitySource(MsalRequest msalRequest, ServiceBundle serviceBundle, URI endpoint, String secret)
50+
private AppServiceManagedIdentitySource(MsalRequest msalRequest, ServiceBundle serviceBundle, URI msiEndpoint, String secret)
5651
{
5752
super(msalRequest, serviceBundle, ManagedIdentitySourceType.AppService);
58-
this.endpoint = endpoint;
59-
this.secret = secret;
53+
this.MSI_ENDPOINT = msiEndpoint;
54+
this.SECRET = secret;
6055
}
6156

62-
protected static AbstractManagedIdentitySource create(MsalRequest msalRequest, ServiceBundle serviceBundle) {
57+
static AbstractManagedIdentitySource create(MsalRequest msalRequest, ServiceBundle serviceBundle) {
6358

6459
IEnvironmentVariables environmentVariables = getEnvironmentVariables((ManagedIdentityParameters) msalRequest.requestContext().apiParameters());
6560
String msiSecret = environmentVariables.getEnvironmentVariable(Constants.IDENTITY_HEADER);
6661
String msiEndpoint = environmentVariables.getEnvironmentVariable(Constants.IDENTITY_ENDPOINT);
6762

68-
return validateEnvironmentVariables(msiEndpoint, msiSecret)
69-
? new AppServiceManagedIdentitySource(msalRequest, serviceBundle, endpointUri, msiSecret)
70-
: null;
63+
URI validatedEndpoint = validateAndGetUri(msiEndpoint, msiSecret);
64+
return validatedEndpoint == null ? null
65+
: new AppServiceManagedIdentitySource(msalRequest, serviceBundle, validatedEndpoint, msiSecret);
7166
}
7267

73-
private static boolean validateEnvironmentVariables(String msiEndpoint, String secret)
68+
private static URI validateAndGetUri(String msiEndpoint, String secret)
7469
{
75-
endpointUri = null;
76-
7770
// if BOTH the env vars endpoint and secret values are null, this MSI provider is unavailable.
7871
if (StringHelper.isNullOrBlank(msiEndpoint) || StringHelper.isNullOrBlank(secret))
7972
{
8073
LOG.info("[Managed Identity] App service managed identity is unavailable.");
81-
return false;
74+
return null;
8275
}
8376

77+
URI endpointUri;
8478
try
8579
{
8680
endpointUri = new URI(msiEndpoint);
@@ -93,7 +87,7 @@ private static boolean validateEnvironmentVariables(String msiEndpoint, String s
9387
}
9488

9589
LOG.info("[Managed Identity] Environment variables validation passed for app service managed identity. Endpoint URI: {endpointUri}. Creating App Service managed identity.");
96-
return true;
90+
return endpointUri;
9791
}
9892

9993
}
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
package com.microsoft.aad.msal4j;
5+
6+
import org.slf4j.Logger;
7+
import org.slf4j.LoggerFactory;
8+
9+
import java.net.URI;
10+
import java.net.URISyntaxException;
11+
import java.util.Collections;
12+
import java.util.HashMap;
13+
14+
class CloudShellManagedIdentitySource extends AbstractManagedIdentitySource{
15+
16+
private static final Logger LOG = LoggerFactory.getLogger(CloudShellManagedIdentitySource.class);
17+
18+
private final URI MSI_ENDPOINT;
19+
20+
@Override
21+
public void createManagedIdentityRequest(String resource) {
22+
managedIdentityRequest.baseEndpoint = MSI_ENDPOINT;
23+
managedIdentityRequest.method = HttpMethod.POST;
24+
25+
managedIdentityRequest.headers = new HashMap<>();
26+
managedIdentityRequest.headers.put("ContentType", "application/x-www-form-urlencoded");
27+
managedIdentityRequest.headers.put("Metadata", "true");
28+
29+
managedIdentityRequest.bodyParameters = new HashMap<>();
30+
managedIdentityRequest.bodyParameters.put("resource", Collections.singletonList(resource));
31+
}
32+
33+
private CloudShellManagedIdentitySource(MsalRequest msalRequest, ServiceBundle serviceBundle, URI msiEndpoint)
34+
{
35+
super(msalRequest, serviceBundle, ManagedIdentitySourceType.CloudShell);
36+
this.MSI_ENDPOINT = msiEndpoint;
37+
38+
ManagedIdentityIdType idType =
39+
((ManagedIdentityApplication) msalRequest.application()).getManagedIdentityId().getIdType();
40+
if (idType != ManagedIdentityIdType.SystemAssigned) {
41+
throw new MsalManagedIdentityException(MsalError.USER_ASSIGNED_MANAGED_IDENTITY_NOT_SUPPORTED,
42+
String.format(MsalErrorMessage.MANAGED_IDENTITY_USER_ASSIGNED_NOT_SUPPORTED, "cloud shell"),
43+
ManagedIdentitySourceType.CloudShell);
44+
}
45+
}
46+
47+
static AbstractManagedIdentitySource create(MsalRequest msalRequest, ServiceBundle serviceBundle) {
48+
49+
IEnvironmentVariables environmentVariables = getEnvironmentVariables((ManagedIdentityParameters) msalRequest.requestContext().apiParameters());
50+
String msiEndpoint = environmentVariables.getEnvironmentVariable(Constants.MSI_ENDPOINT);
51+
52+
53+
// if ONLY the env var MSI_ENDPOINT is set the MsiType is CloudShell
54+
if (StringHelper.isNullOrBlank(msiEndpoint))
55+
{
56+
LOG.info("[Managed Identity] Cloud shell managed identity is unavailable.");
57+
return null;
58+
}
59+
60+
URI validatedUri = validateAndGetUri(msiEndpoint);
61+
return validatedUri == null ? null
62+
: new CloudShellManagedIdentitySource(msalRequest, serviceBundle, validatedUri);
63+
}
64+
65+
private static URI validateAndGetUri(String msiEndpoint)
66+
{
67+
URI endpointUri = null;
68+
69+
try
70+
{
71+
endpointUri = new URI(msiEndpoint);
72+
}
73+
catch (URISyntaxException ex)
74+
{
75+
throw new MsalManagedIdentityException(MsalError.INVALID_MANAGED_IDENTITY_ENDPOINT, String.format(
76+
MsalErrorMessage.MANAGED_IDENTITY_ENDPOINT_INVALID_URI_ERROR, "MSI_ENDPOINT", msiEndpoint, "Cloud Shell"),
77+
ManagedIdentitySourceType.CloudShell);
78+
}
79+
80+
LOG.info("[Managed Identity] Environment variables validation passed for cloud shell managed identity. Endpoint URI: " + endpointUri + ". Creating cloud shell managed identity.");
81+
return endpointUri;
82+
}
83+
84+
}

msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/IMDSManagedIdentitySource.java

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
import java.net.URISyntaxException;
1212
import java.util.Collections;
1313
import java.util.HashMap;
14-
import java.util.Map;
1514

1615
class IMDSManagedIdentitySource extends AbstractManagedIdentitySource{
1716

@@ -27,12 +26,8 @@ class IMDSManagedIdentitySource extends AbstractManagedIdentitySource{
2726
}
2827
}
2928

30-
private String imdsTokenPath = "/metadata/identity/oauth2/token";
31-
private String imdsApiVersion = "2018-02-01";
32-
private static String defaultMessage = "[Managed Identity] Service request failed.";
33-
34-
private static String identityUnavailableError = "[Managed Identity] Authentication unavailable. The requested identity has not been assigned to this resource.";
35-
private static String gatewayError = "[Managed Identity] Authentication unavailable. The request failed due to a gateway error.";
29+
private static final String IMDS_TOKEN_PATH = "/metadata/identity/oauth2/token";
30+
private static final String IMDS_API_VERSION = "2018-02-01";
3631

3732
private URI imdsEndpoint;
3833

@@ -52,7 +47,7 @@ public IMDSManagedIdentitySource(MsalRequest msalRequest,
5247
}
5348

5449
StringBuilder builder = new StringBuilder(environmentVariables.getEnvironmentVariable(Constants.AZURE_POD_IDENTITY_AUTHORITY_HOST));
55-
builder.append("/" + imdsTokenPath);
50+
builder.append("/" + IMDS_TOKEN_PATH);
5651
try {
5752
imdsEndpoint = new URI(builder.toString());
5853
} catch (URISyntaxException e) {
@@ -82,7 +77,7 @@ public void createManagedIdentityRequest(String resource) {
8277
managedIdentityRequest.headers.put("Metadata", "true");
8378

8479
managedIdentityRequest.queryParameters = new HashMap<>();
85-
managedIdentityRequest.queryParameters.put("api-version", Collections.singletonList(imdsApiVersion));
80+
managedIdentityRequest.queryParameters.put("api-version", Collections.singletonList(IMDS_API_VERSION));
8681
managedIdentityRequest.queryParameters.put("resource", Collections.singletonList(resource));
8782

8883
String clientId = getManagedIdentityUserAssignedClientId();
@@ -109,10 +104,10 @@ public ManagedIdentityResponse handleResponse(
109104
String baseMessage;
110105

111106
if(response.statusCode()== HttpURLConnection.HTTP_BAD_REQUEST){
112-
baseMessage = identityUnavailableError;
107+
baseMessage = MsalErrorMessage.IDENTITY_UNAVAILABLE_ERROR;
113108
}else if(response.statusCode()== HttpURLConnection.HTTP_BAD_GATEWAY ||
114109
response.statusCode()== HttpURLConnection.HTTP_GATEWAY_TIMEOUT){
115-
baseMessage = gatewayError;
110+
baseMessage = MsalErrorMessage.GATEWAY_ERROR;
116111
}else{
117112
baseMessage = null;
118113
}
@@ -125,8 +120,8 @@ public ManagedIdentityResponse handleResponse(
125120

126121
message = message + " " + errorContentMessage;
127122

128-
LOG.error("Error message: {message} Http status code: {response.StatusCode}");
129-
throw new MsalManagedIdentityException("managed_identity_request_failed", message,
123+
LOG.error(String.format("Error message: %s Http status code: %s"), message, response.statusCode());
124+
throw new MsalManagedIdentityException(MsalError.MANAGED_IDENTITY_REQUEST_FAILED, message,
130125
ManagedIdentitySourceType.Imds);
131126
}
132127

@@ -138,7 +133,7 @@ private static String createRequestFailedMessage(IHttpResponse response, String
138133
{
139134
StringBuilder messageBuilder = new StringBuilder();
140135

141-
messageBuilder.append(StringHelper.isNullOrBlank(message) ? defaultMessage : message);
136+
messageBuilder.append(StringHelper.isNullOrBlank(message) ? MsalErrorMessage.DEFAULT_MESSAGE : message);
142137
messageBuilder.append("Status: ");
143138
messageBuilder.append(response.statusCode());
144139

msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/ManagedIdentityClient.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,12 @@ public ManagedIdentityResponse getManagedIdentityResponse(ManagedIdentityParamet
3232

3333
// This method tries to create managed identity source for different sources, if none is created then defaults to IMDS.
3434
private static AbstractManagedIdentitySource createManagedIdentitySource(MsalRequest msalRequest,
35-
ServiceBundle serviceBundle) throws Exception {
35+
ServiceBundle serviceBundle) {
3636
AbstractManagedIdentitySource managedIdentitySource;
3737
if ((managedIdentitySource = AppServiceManagedIdentitySource.create(msalRequest, serviceBundle)) != null) {
3838
return managedIdentitySource;
39+
} else if ((managedIdentitySource = CloudShellManagedIdentitySource.create(msalRequest, serviceBundle)) != null) {
40+
return managedIdentitySource;
3941
} else {
4042
return new IMDSManagedIdentitySource(msalRequest, serviceBundle);
4143
}

msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/ManagedIdentityErrorResponse.java

Lines changed: 2 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44
package com.microsoft.aad.msal4j;
55

66
import com.fasterxml.jackson.annotation.JsonProperty;
7+
import lombok.Getter;
78

9+
@Getter
810
public class ManagedIdentityErrorResponse {
911

1012
@JsonProperty("message")
@@ -18,36 +20,4 @@ public class ManagedIdentityErrorResponse {
1820

1921
@JsonProperty("error_description")
2022
private String errorDescription;
21-
22-
public String getMessage() {
23-
return message;
24-
}
25-
26-
public void setMessage(String message) {
27-
this.message = message;
28-
}
29-
30-
public String getCorrelationId() {
31-
return correlationId;
32-
}
33-
34-
public void setCorrelationId(String correlationId) {
35-
this.correlationId = correlationId;
36-
}
37-
38-
public String getError() {
39-
return error;
40-
}
41-
42-
public void setError(String error) {
43-
this.error = error;
44-
}
45-
46-
public String getErrorDescription() {
47-
return errorDescription;
48-
}
49-
50-
public void setErrorDescription(String errorDescription) {
51-
this.errorDescription = errorDescription;
52-
}
5323
}

msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/ManagedIdentityRequest.java

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,21 @@ class ManagedIdentityRequest extends MsalRequest {
2222

2323
Map<String, String> headers;
2424

25-
Map<String, String> bodyParameters;
25+
Map<String, List<String>> bodyParameters;
2626

2727
Map<String, List<String>> queryParameters;
2828

2929
public ManagedIdentityRequest(ManagedIdentityApplication managedIdentityApplication, RequestContext requestContext) {
3030
super(managedIdentityApplication, requestContext);
3131
}
3232

33+
public String getBodyAsString() {
34+
if (bodyParameters == null || bodyParameters.isEmpty())
35+
return "";
36+
37+
return URLUtils.serializeParameters(bodyParameters);
38+
}
39+
3340
public URL computeURI() throws URISyntaxException {
3441
String endpoint = this.appendQueryParametersToBaseEndpoint();
3542
try {
@@ -40,7 +47,7 @@ public URL computeURI() throws URISyntaxException {
4047
}
4148

4249
private String appendQueryParametersToBaseEndpoint() {
43-
if (queryParameters.isEmpty()) {
50+
if (queryParameters == null || queryParameters.isEmpty()) {
4451
return baseEndpoint.toString();
4552
}
4653

0 commit comments

Comments
 (0)