Skip to content

Commit 6f1778d

Browse files
Merge pull request #823 from AzureAD/nebharg/MsiAddApiToDetectEnv
Add API to get managed identity source
2 parents fb02867 + 9a07472 commit 6f1778d

12 files changed

+124
-49
lines changed

msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/AbstractManagedIdentitySource.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,8 @@ protected String getMessageFromErrorResponse(IHttpResponse response) {
133133
managedIdentityErrorResponse.getError(), managedIdentityErrorResponse.getErrorDescription());
134134
}
135135

136-
protected static IEnvironmentVariables getEnvironmentVariables(ManagedIdentityParameters parameters) {
137-
return parameters.environmentVariables == null ? new EnvironmentVariables() : parameters.environmentVariables;
136+
protected static IEnvironmentVariables getEnvironmentVariables() {
137+
return ManagedIdentityApplication.environmentVariables == null ?
138+
new EnvironmentVariables() : ManagedIdentityApplication.environmentVariables;
138139
}
139140
}

msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/AppServiceManagedIdentitySource.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ private AppServiceManagedIdentitySource(MsalRequest msalRequest, ServiceBundle s
5656

5757
static AbstractManagedIdentitySource create(MsalRequest msalRequest, ServiceBundle serviceBundle) {
5858

59-
IEnvironmentVariables environmentVariables = getEnvironmentVariables((ManagedIdentityParameters) msalRequest.requestContext().apiParameters());
59+
IEnvironmentVariables environmentVariables = getEnvironmentVariables();
6060
String msiSecret = environmentVariables.getEnvironmentVariable(Constants.IDENTITY_HEADER);
6161
String msiEndpoint = environmentVariables.getEnvironmentVariable(Constants.IDENTITY_ENDPOINT);
6262

msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/AzureArcManagedIdentitySource.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ class AzureArcManagedIdentitySource extends AbstractManagedIdentitySource{
2727

2828
static AbstractManagedIdentitySource create(MsalRequest msalRequest, ServiceBundle serviceBundle)
2929
{
30-
IEnvironmentVariables environmentVariables = getEnvironmentVariables((ManagedIdentityParameters) msalRequest.requestContext().apiParameters());
30+
IEnvironmentVariables environmentVariables = getEnvironmentVariables();
3131
String identityEndpoint = environmentVariables.getEnvironmentVariable(Constants.IDENTITY_ENDPOINT);
3232
String imdsEndpoint = environmentVariables.getEnvironmentVariable(Constants.IMDS_ENDPOINT);
3333

msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/CloudShellManagedIdentitySource.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ private CloudShellManagedIdentitySource(MsalRequest msalRequest, ServiceBundle s
4646

4747
static AbstractManagedIdentitySource create(MsalRequest msalRequest, ServiceBundle serviceBundle) {
4848

49-
IEnvironmentVariables environmentVariables = getEnvironmentVariables((ManagedIdentityParameters) msalRequest.requestContext().apiParameters());
49+
IEnvironmentVariables environmentVariables = getEnvironmentVariables();
5050
String msiEndpoint = environmentVariables.getEnvironmentVariable(Constants.MSI_ENDPOINT);
5151

5252

msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/IMDSManagedIdentitySource.java

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,7 @@ public IMDSManagedIdentitySource(MsalRequest msalRequest,
3535
ServiceBundle serviceBundle) {
3636
super(msalRequest, serviceBundle, ManagedIdentitySourceType.IMDS);
3737
ManagedIdentityParameters parameters = (ManagedIdentityParameters) msalRequest.requestContext().apiParameters();
38-
IEnvironmentVariables environmentVariables = ((ManagedIdentityParameters) msalRequest.requestContext().apiParameters()).environmentVariables == null ?
39-
new EnvironmentVariables() :
40-
parameters.environmentVariables;
38+
IEnvironmentVariables environmentVariables = getEnvironmentVariables();
4139
if (!StringHelper.isNullOrBlank(environmentVariables.getEnvironmentVariable(Constants.AZURE_POD_IDENTITY_AUTHORITY_HOST))){
4240
LOG.info(String.format("[Managed Identity] Environment variable AZURE_POD_IDENTITY_AUTHORITY_HOST for IMDS returned endpoint: %s", environmentVariables.getEnvironmentVariable(Constants.AZURE_POD_IDENTITY_AUTHORITY_HOST)));
4341
try {

msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/ManagedIdentityApplication.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33

44
package com.microsoft.aad.msal4j;
55

6+
import lombok.AccessLevel;
67
import lombok.Getter;
8+
import lombok.Setter;
79
import org.slf4j.LoggerFactory;
810

911
import java.util.concurrent.CompletableFuture;
@@ -22,6 +24,16 @@ public class ManagedIdentityApplication extends AbstractApplicationBase implemen
2224
@Getter
2325
static TokenCache sharedTokenCache = new TokenCache();
2426

27+
@Getter(value = AccessLevel.PUBLIC)
28+
static ManagedIdentitySourceType managedIdentitySource = ManagedIdentityClient.getManagedIdentitySource();
29+
30+
@Getter(value = AccessLevel.PACKAGE)
31+
static IEnvironmentVariables environmentVariables;
32+
33+
static void setEnvironmentVariables(IEnvironmentVariables environmentVariables) {
34+
ManagedIdentityApplication.environmentVariables = environmentVariables;
35+
}
36+
2537
private ManagedIdentityApplication(Builder builder) {
2638
super(builder);
2739

msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/ManagedIdentityClient.java

Lines changed: 49 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33

44
package com.microsoft.aad.msal4j;
55

6+
import lombok.AccessLevel;
7+
import lombok.Getter;
68
import org.slf4j.Logger;
79
import org.slf4j.LoggerFactory;
810

@@ -12,6 +14,37 @@
1214
class ManagedIdentityClient {
1315
private static final Logger LOG = LoggerFactory.getLogger(ManagedIdentityClient.class);
1416

17+
private static ManagedIdentitySourceType managedIdentitySourceType;
18+
19+
protected static void resetManagedIdentitySourceType() {
20+
managedIdentitySourceType = ManagedIdentitySourceType.NONE;
21+
}
22+
23+
static ManagedIdentitySourceType getManagedIdentitySource() {
24+
if (managedIdentitySourceType != null && managedIdentitySourceType != ManagedIdentitySourceType.NONE) {
25+
return managedIdentitySourceType;
26+
}
27+
28+
IEnvironmentVariables environmentVariables = AbstractManagedIdentitySource.getEnvironmentVariables();
29+
30+
if (!StringHelper.isNullOrBlank(environmentVariables.getEnvironmentVariable(Constants.IDENTITY_ENDPOINT)) &&
31+
!StringHelper.isNullOrBlank(environmentVariables.getEnvironmentVariable(Constants.IDENTITY_HEADER))) {
32+
if (!StringHelper.isNullOrBlank(environmentVariables.getEnvironmentVariable(Constants.IDENTITY_SERVER_THUMBPRINT))) {
33+
managedIdentitySourceType = ManagedIdentitySourceType.SERVICE_FABRIC;
34+
} else
35+
managedIdentitySourceType = ManagedIdentitySourceType.APP_SERVICE;
36+
} else if (!StringHelper.isNullOrBlank(environmentVariables.getEnvironmentVariable(Constants.MSI_ENDPOINT))) {
37+
managedIdentitySourceType = ManagedIdentitySourceType.CLOUD_SHELL;
38+
} else if (!StringHelper.isNullOrBlank(environmentVariables.getEnvironmentVariable(Constants.IDENTITY_ENDPOINT)) &&
39+
!StringHelper.isNullOrBlank(environmentVariables.getEnvironmentVariable(Constants.IMDS_ENDPOINT))) {
40+
managedIdentitySourceType = ManagedIdentitySourceType.AZURE_ARC;
41+
} else {
42+
managedIdentitySourceType = ManagedIdentitySourceType.DEFAULT_TO_IMDS;
43+
}
44+
45+
return managedIdentitySourceType;
46+
}
47+
1548
AbstractManagedIdentitySource managedIdentitySource;
1649

1750
ManagedIdentityClient(MsalRequest msalRequest, ServiceBundle serviceBundle) {
@@ -38,16 +71,22 @@ ManagedIdentityResponse getManagedIdentityResponse(ManagedIdentityParameters par
3871
private static AbstractManagedIdentitySource createManagedIdentitySource(MsalRequest msalRequest,
3972
ServiceBundle serviceBundle) {
4073
AbstractManagedIdentitySource managedIdentitySource;
41-
if ((managedIdentitySource = ServiceFabricManagedIdentitySource.create(msalRequest, serviceBundle)) != null) {
42-
return managedIdentitySource;
43-
} else if ((managedIdentitySource = AppServiceManagedIdentitySource.create(msalRequest, serviceBundle)) != null) {
44-
return managedIdentitySource;
45-
} else if ((managedIdentitySource = CloudShellManagedIdentitySource.create(msalRequest, serviceBundle)) != null) {
46-
return managedIdentitySource;
47-
} else if ((managedIdentitySource = AzureArcManagedIdentitySource.create(msalRequest, serviceBundle)) != null) {
48-
return managedIdentitySource;
49-
} else {
50-
return new IMDSManagedIdentitySource(msalRequest, serviceBundle);
74+
75+
if (managedIdentitySourceType == null || managedIdentitySourceType == ManagedIdentitySourceType.NONE) {
76+
managedIdentitySourceType = getManagedIdentitySource();
77+
}
78+
79+
switch (managedIdentitySourceType) {
80+
case SERVICE_FABRIC:
81+
return ServiceFabricManagedIdentitySource.create(msalRequest, serviceBundle);
82+
case APP_SERVICE:
83+
return AppServiceManagedIdentitySource.create(msalRequest, serviceBundle);
84+
case CLOUD_SHELL:
85+
return CloudShellManagedIdentitySource.create(msalRequest, serviceBundle);
86+
case AZURE_ARC:
87+
return AzureArcManagedIdentitySource.create(msalRequest, serviceBundle);
88+
default:
89+
return new IMDSManagedIdentitySource(msalRequest, serviceBundle);
5190
}
5291
}
5392
}

msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/ManagedIdentityParameters.java

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,6 @@ public class ManagedIdentityParameters implements IAcquireTokenParameters {
2727

2828
boolean forceRefresh;
2929

30-
IEnvironmentVariables environmentVariables;
31-
3230
@Override
3331
public Set<String> scopes() {
3432
return null;
@@ -54,10 +52,6 @@ public Map<String, String> extraQueryParameters() {
5452
return null;
5553
}
5654

57-
void setEnvironmentVariablesConfig(IEnvironmentVariables environmentVariables) {
58-
this.environmentVariables = environmentVariables;
59-
}
60-
6155
private static ManagedIdentityParametersBuilder builder() {
6256
return new ManagedIdentityParametersBuilder();
6357
}

msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/ManagedIdentitySourceType.java

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,16 @@
66
enum ManagedIdentitySourceType {
77
// Default.
88
NONE,
9-
// The source to acquire token for managed identity is IMDS.
9+
// The source used to acquire token for managed identity is IMDS.
1010
IMDS,
11-
// The source to acquire token for managed identity is App Service.
11+
// The source used to acquire token for managed identity is App Service.
1212
APP_SERVICE,
13-
// The source to acquire token for managed identity is Azure Arc.
13+
// The source used to acquire token for managed identity is Azure Arc.
1414
AZURE_ARC,
15-
// The source to acquire token for managed identity is Cloud Shell.
15+
// The source used to acquire token for managed identity is Cloud Shell.
1616
CLOUD_SHELL,
17-
// The source to acquire token for managed identity is Service Fabric.
18-
SERVICE_FABRIC
17+
// The source used to acquire token for managed identity is Service Fabric.
18+
SERVICE_FABRIC,
19+
// The source to acquire token for managed identity is defaulted to IMDS when no environment variables are set.
20+
DEFAULT_TO_IMDS
1921
}

msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/ServiceFabricManagedIdentitySource.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,12 +95,11 @@ public ManagedIdentityResponse getManagedIdentityResponse(
9595

9696
static AbstractManagedIdentitySource create(MsalRequest msalRequest, ServiceBundle serviceBundle) {
9797

98-
IEnvironmentVariables environmentVariables = getEnvironmentVariables((ManagedIdentityParameters) msalRequest.requestContext().apiParameters());
98+
IEnvironmentVariables environmentVariables = getEnvironmentVariables();
9999
String identityEndpoint = environmentVariables.getEnvironmentVariable(Constants.IDENTITY_ENDPOINT);
100100
String identityHeader = environmentVariables.getEnvironmentVariable(Constants.IDENTITY_HEADER);
101101
String identityServerThumbprint = environmentVariables.getEnvironmentVariable(Constants.IDENTITY_SERVER_THUMBPRINT);
102102

103-
104103
if (StringHelper.isNullOrBlank(identityEndpoint) || StringHelper.isNullOrBlank(identityHeader) || StringHelper.isNullOrBlank(identityServerThumbprint))
105104
{
106105
LOG.info("[Managed Identity] Service fabric managed identity is unavailable.");

msal4j-sdk/src/test/java/com/microsoft/aad/msal4j/ManagedIdentityTestDataProvider.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,4 +97,14 @@ public static Stream<Arguments> createDataError() {
9797
Arguments.of(ManagedIdentitySourceType.IMDS, ManagedIdentityTests.IMDS_ENDPOINT),
9898
Arguments.of(ManagedIdentitySourceType.SERVICE_FABRIC, ManagedIdentityTests.serviceFabricEndpoint));
9999
}
100+
101+
public static Stream<Arguments> createDataGetSource() {
102+
return Stream.of(
103+
Arguments.of(ManagedIdentitySourceType.AZURE_ARC, ManagedIdentityTests.azureArcEndpoint, ManagedIdentitySourceType.AZURE_ARC),
104+
Arguments.of(ManagedIdentitySourceType.APP_SERVICE, ManagedIdentityTests.appServiceEndpoint, ManagedIdentitySourceType.APP_SERVICE),
105+
Arguments.of(ManagedIdentitySourceType.CLOUD_SHELL, ManagedIdentityTests.cloudShellEndpoint, ManagedIdentitySourceType.CLOUD_SHELL),
106+
Arguments.of(ManagedIdentitySourceType.IMDS, ManagedIdentityTests.IMDS_ENDPOINT, ManagedIdentitySourceType.DEFAULT_TO_IMDS),
107+
Arguments.of(ManagedIdentitySourceType.IMDS, "", ManagedIdentitySourceType.DEFAULT_TO_IMDS),
108+
Arguments.of(ManagedIdentitySourceType.SERVICE_FABRIC, ManagedIdentityTests.serviceFabricEndpoint, ManagedIdentitySourceType.SERVICE_FABRIC));
109+
}
100110
}

0 commit comments

Comments
 (0)