Skip to content

Commit 290b543

Browse files
authored
Merge pull request #598 from AzureAD/SJAIN/add-2s-timeout-to-IMDS-call
add 2 seconds timeout while calling IMDS
2 parents 92eace8 + d6ac699 commit 290b543

File tree

3 files changed

+39
-30
lines changed

3 files changed

+39
-30
lines changed

msal4j-sdk/src/integrationtest/java/com.microsoft.aad.msal4j/ClientCredentialsIT.java

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import labapi.AzureEnvironment;
88
import org.testng.Assert;
99
import org.testng.annotations.BeforeClass;
10+
import org.testng.annotations.DataProvider;
1011
import org.testng.annotations.Test;
1112

1213
import java.io.IOException;
@@ -118,13 +119,18 @@ public void acquireTokenClientCredentials_DefaultCacheLookup() throws Exception
118119
Assert.assertNotEquals(result2.accessToken(), result3.accessToken());
119120
}
120121

121-
@Test
122-
public void acquireTokenClientCredentials_Regional() throws Exception {
122+
@DataProvider(name = "regionWithAuthority")
123+
public static Object[][] createData() {
124+
return new Object[][]{{"westus", TestConstants.REGIONAL_MICROSOFT_AUTHORITY_BASIC_HOST_WESTUS},
125+
{"eastus", TestConstants.REGIONAL_MICROSOFT_AUTHORITY_BASIC_HOST_EASTUS}};
126+
}
127+
128+
@Test(dataProvider = "regionWithAuthority")
129+
public void acquireTokenClientCredentials_Regional(String[] regionWithAuthority) throws Exception {
123130
String clientId = "2afb0add-2f32-4946-ac90-81a02aa4550e";
124131

125-
assertAcquireTokenCommon_withRegion(clientId, certificate);
132+
assertAcquireTokenCommon_withRegion(clientId, certificate, regionWithAuthority[0], regionWithAuthority[1]);
126133
}
127-
128134
private ClientAssertion getClientAssertion(String clientId) {
129135
return JwtHelper.buildJwt(
130136
clientId,
@@ -164,15 +170,15 @@ private void assertAcquireTokenCommon_withParameters(String clientId, IClientCre
164170
Assert.assertNotNull(result.accessToken());
165171
}
166172

167-
private void assertAcquireTokenCommon_withRegion(String clientId, IClientCredential credential) throws Exception {
173+
private void assertAcquireTokenCommon_withRegion(String clientId, IClientCredential credential, String region, String regionalAuthority) throws Exception {
168174
ConfidentialClientApplication ccaNoRegion = ConfidentialClientApplication.builder(
169175
clientId, credential).
170176
authority(TestConstants.MICROSOFT_AUTHORITY).
171177
build();
172178

173179
ConfidentialClientApplication ccaRegion = ConfidentialClientApplication.builder(
174180
clientId, credential).
175-
authority("https://login.microsoft.com/microsoft.onmicrosoft.com").azureRegion("westus").
181+
authority("https://login.microsoft.com/microsoft.onmicrosoft.com").azureRegion(region).
176182
build();
177183

178184
//Ensure behavior when region not specified
@@ -193,7 +199,7 @@ private void assertAcquireTokenCommon_withRegion(String clientId, IClientCredent
193199

194200
Assert.assertNotNull(resultRegion);
195201
Assert.assertNotNull(resultRegion.accessToken());
196-
Assert.assertEquals(resultRegion.environment(), TestConstants.REGIONAL_MICROSOFT_AUTHORITY_BASIC_HOST_WESTUS);
202+
Assert.assertEquals(resultRegion.environment(), regionalAuthority);
197203

198204
IAuthenticationResult resultRegionCached = ccaRegion.acquireToken(ClientCredentialParameters
199205
.builder(Collections.singleton(KEYVAULT_DEFAULT_SCOPE))

msal4j-sdk/src/integrationtest/java/com.microsoft.aad.msal4j/TestConstants.java

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -32,14 +32,15 @@ public class TestConstants {
3232
public final static String TENANT_SPECIFIC_AUTHORITY = MICROSOFT_AUTHORITY_HOST + MICROSOFT_AUTHORITY_TENANT;
3333
public final static String REGIONAL_MICROSOFT_AUTHORITY_BASIC_HOST_WESTUS = "westus.login.microsoft.com";
3434

35+
public final static String REGIONAL_MICROSOFT_AUTHORITY_BASIC_HOST_EASTUS = "eastus.login.microsoft.com";
36+
3537
public final static String ARLINGTON_ORGANIZATIONS_AUTHORITY = ARLINGTON_MICROSOFT_AUTHORITY_HOST + "organizations/";
36-
public final static String ARLINGTON_COMMON_AUTHORITY = ARLINGTON_MICROSOFT_AUTHORITY_HOST + "common/";
3738
public final static String ARLINGTON_TENANT_SPECIFIC_AUTHORITY = ARLINGTON_MICROSOFT_AUTHORITY_HOST + ARLINGTON_AUTHORITY_TENANT;
3839
public final static String ARLINGTON_GRAPH_DEFAULT_SCOPE = "https://graph.microsoft.us/.default";
3940

40-
4141
public final static String B2C_AUTHORITY = "https://msidlabb2c.b2clogin.com/msidlabb2c.onmicrosoft.com/";
4242
public final static String B2C_AUTHORITY_LEGACY_FORMAT = "https://msidlabb2c.b2clogin.com/tfp/msidlabb2c.onmicrosoft.com/";
43+
4344
public final static String B2C_ROPC_POLICY = "B2C_1_ROPC_Auth";
4445
public final static String B2C_SIGN_IN_POLICY = "B2C_1_SignInPolicy";
4546
public final static String B2C_AUTHORITY_SIGN_IN = B2C_AUTHORITY + B2C_SIGN_IN_POLICY;
@@ -49,19 +50,13 @@ public class TestConstants {
4950
public final static String B2C_MICROSOFTLOGIN_ROPC = B2C_MICROSOFTLOGIN_AUTHORITY + B2C_ROPC_POLICY;
5051

5152
public final static String LOCALHOST = "http://localhost:";
52-
public final static String LOCAL_FLAG_ENV_VAR = "MSAL_JAVA_RUN_LOCAL";
5353

5454
public final static String ADFS_AUTHORITY = "https://fs.msidlab8.com/adfs/";
5555
public final static String ADFS_SCOPE = USER_READ_SCOPE;
5656
public final static String ADFS_APP_ID = "PublicClientId";
5757

5858
public final static String CLAIMS = "{\"id_token\":{\"auth_time\":{\"essential\":true}}}";
5959
public final static Set<String> CLIENT_CAPABILITIES_EMPTY = new HashSet<>(Collections.emptySet());
60-
public final static Set<String> CLIENT_CAPABILITIES_LLT = new HashSet<>(Collections.singletonList("llt"));
61-
62-
// cross cloud b2b settings
63-
public final static String AUTHORITY_ARLINGTON = "https://login.microsoftonline.us/" + ARLINGTON_AUTHORITY_TENANT;
64-
public final static String AUTHORITY_MOONCAKE = "https://login.chinacloudapi.cn/mncmsidlab1.partner.onmschina.cn";
6560
public final static String AUTHORITY_PUBLIC_TENANT_SPECIFIC = "https://login.microsoftonline.com/" + MICROSOFT_AUTHORITY_TENANT;
6661

6762
public final static String DEFAULT_ACCESS_TOKEN = "defaultAccessToken";

msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/AadInstanceDiscoveryProvider.java

Lines changed: 23 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import java.util.TreeSet;
1515
import java.util.Map;
1616
import java.util.HashMap;
17-
import java.util.concurrent.ConcurrentHashMap;
17+
import java.util.concurrent.*;
1818

1919
class AadInstanceDiscoveryProvider {
2020

@@ -31,6 +31,8 @@ class AadInstanceDiscoveryProvider {
3131
private static final String DEFAULT_API_VERSION = "2020-06-01";
3232
private static final String IMDS_ENDPOINT = "https://169.254.169.254/metadata/instance/compute/location?" + DEFAULT_API_VERSION + "&format=text";
3333

34+
private static final int IMDS_TIMEOUT = 2;
35+
private static final TimeUnit IMDS_TIMEOUT_UNIT = TimeUnit.SECONDS;
3436
static final TreeSet<String> TRUSTED_HOSTS_SET = new TreeSet<>(String.CASE_INSENSITIVE_ORDER);
3537
static final TreeSet<String> TRUSTED_SOVEREIGN_HOSTS_SET = new TreeSet<>(String.CASE_INSENSITIVE_ORDER);
3638

@@ -71,8 +73,8 @@ static InstanceDiscoveryMetadataEntry getMetadataEntry(URL authorityUrl,
7173
//If region autodetection is enabled and a specific region not already set,
7274
// set the application's region to the discovered region so that future requests can skip the IMDS endpoint call
7375
if (null == msalRequest.application().azureRegion() && msalRequest.application().autoDetectRegion()
74-
&& null != detectedRegion) {
75-
msalRequest.application().azureRegion = detectedRegion;
76+
&& null != detectedRegion) {
77+
msalRequest.application().azureRegion = detectedRegion;
7678
}
7779
cacheRegionInstanceMetadata(authorityUrl.getHost(), msalRequest.application().azureRegion());
7880
serviceBundle.getServerSideTelemetry().getCurrentRequest().regionOutcome(
@@ -291,33 +293,39 @@ private static String discoverRegion(MsalRequest msalRequest, ServiceBundle serv
291293
return System.getenv(REGION_NAME);
292294
}
293295

294-
try {
295-
//Check the IMDS endpoint to retrieve current region (will only work if application is running in an Azure VM)
296-
Map<String, String> headers = new HashMap<>();
297-
headers.put("Metadata", "true");
298-
IHttpResponse httpResponse = executeRequest(IMDS_ENDPOINT, headers, msalRequest, serviceBundle);
296+
//Check the IMDS endpoint to retrieve current region (will only work if application is running in an Azure VM)
297+
Map<String, String> headers = new HashMap<>();
298+
headers.put("Metadata", "true");
299+
300+
ExecutorService executor = Executors.newSingleThreadExecutor();
301+
Future<IHttpResponse> future = executor.submit(() -> executeRequest(IMDS_ENDPOINT, headers, msalRequest, serviceBundle));
299302

303+
try {
304+
log.info("Starting call to IMDS endpoint.");
305+
IHttpResponse httpResponse = future.get(IMDS_TIMEOUT, IMDS_TIMEOUT_UNIT);
300306
//If call to IMDS endpoint was successful, return region from response body
301307
if (httpResponse.statusCode() == HttpHelper.HTTP_STATUS_200 && !httpResponse.body().isEmpty()) {
302-
log.info("Region retrieved from IMDS endpoint: " + httpResponse.body());
308+
log.info(String.format("Region retrieved from IMDS endpoint: %s", httpResponse.body()));
303309
currentRequest.regionSource(RegionTelemetry.REGION_SOURCE_IMDS.telemetryValue);
304310

305311
return httpResponse.body();
306312
}
307-
308313
log.warn(String.format("Call to local IMDS failed with status code: %s, or response was empty", httpResponse.statusCode()));
309314
currentRequest.regionSource(RegionTelemetry.REGION_SOURCE_FAILED_AUTODETECT.telemetryValue);
310-
311-
return null;
312-
} catch (Exception e) {
315+
} catch (Exception ex) {
316+
// handle other exceptions
313317
//IMDS call failed, cannot find region
314318
//The IMDS endpoint is only available from within an Azure environment, so the most common cause of this
315319
// exception will likely be java.net.SocketException: Network is unreachable: connect
316-
log.warn(String.format("Exception during call to local IMDS endpoint: %s", e.getMessage()));
320+
log.warn(String.format("Exception during call to local IMDS endpoint: %s", ex.getMessage()));
317321
currentRequest.regionSource(RegionTelemetry.REGION_SOURCE_FAILED_AUTODETECT.telemetryValue);
322+
future.cancel(true);
318323

319-
return null;
324+
} finally {
325+
executor.shutdownNow();
320326
}
327+
328+
return null;
321329
}
322330

323331
private static void doInstanceDiscoveryAndCache(URL authorityUrl,

0 commit comments

Comments
 (0)