Skip to content

Commit 6d850b1

Browse files
committed
add 2 seconds timeout while calling IMDS
1 parent d1cb3be commit 6d850b1

File tree

3 files changed

+56
-32
lines changed

3 files changed

+56
-32
lines changed

msal4j-sdk/src/integrationtest/java/com.microsoft.aad.msal4j/ClientCredentialsIT.java

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import labapi.AzureEnvironment;
88
import org.testng.Assert;
99
import org.testng.annotations.BeforeClass;
10+
import org.testng.annotations.DataProvider;
1011
import org.testng.annotations.Test;
1112

1213
import java.io.IOException;
@@ -118,13 +119,18 @@ public void acquireTokenClientCredentials_DefaultCacheLookup() throws Exception
118119
Assert.assertNotEquals(result2.accessToken(), result3.accessToken());
119120
}
120121

121-
@Test
122-
public void acquireTokenClientCredentials_Regional() throws Exception {
122+
@DataProvider(name = "regionWithAuthority")
123+
public static Object[][] createData() {
124+
return new Object[][]{{"westus", TestConstants.REGIONAL_MICROSOFT_AUTHORITY_BASIC_HOST_WESTUS},
125+
{"eastus", TestConstants.REGIONAL_MICROSOFT_AUTHORITY_BASIC_HOST_EASTUS}};
126+
}
127+
128+
@Test(dataProvider = "regionWithAuthority")
129+
public void acquireTokenClientCredentials_Regional(String[] regionWithAuthority) throws Exception {
123130
String clientId = "2afb0add-2f32-4946-ac90-81a02aa4550e";
124131

125-
assertAcquireTokenCommon_withRegion(clientId, certificate);
132+
assertAcquireTokenCommon_withRegion(clientId, certificate, regionWithAuthority[0], regionWithAuthority[1]);
126133
}
127-
128134
private ClientAssertion getClientAssertion(String clientId) {
129135
return JwtHelper.buildJwt(
130136
clientId,
@@ -164,15 +170,15 @@ private void assertAcquireTokenCommon_withParameters(String clientId, IClientCre
164170
Assert.assertNotNull(result.accessToken());
165171
}
166172

167-
private void assertAcquireTokenCommon_withRegion(String clientId, IClientCredential credential) throws Exception {
173+
private void assertAcquireTokenCommon_withRegion(String clientId, IClientCredential credential, String region, String regionalAuthority) throws Exception {
168174
ConfidentialClientApplication ccaNoRegion = ConfidentialClientApplication.builder(
169175
clientId, credential).
170176
authority(TestConstants.MICROSOFT_AUTHORITY).
171177
build();
172178

173179
ConfidentialClientApplication ccaRegion = ConfidentialClientApplication.builder(
174180
clientId, credential).
175-
authority("https://login.microsoft.com/microsoft.onmicrosoft.com").azureRegion("westus").
181+
authority("https://login.microsoft.com/microsoft.onmicrosoft.com").azureRegion(region).
176182
build();
177183

178184
//Ensure behavior when region not specified
@@ -193,7 +199,7 @@ private void assertAcquireTokenCommon_withRegion(String clientId, IClientCredent
193199

194200
Assert.assertNotNull(resultRegion);
195201
Assert.assertNotNull(resultRegion.accessToken());
196-
Assert.assertEquals(resultRegion.environment(), TestConstants.REGIONAL_MICROSOFT_AUTHORITY_BASIC_HOST_WESTUS);
202+
Assert.assertEquals(resultRegion.environment(), regionalAuthority);
197203

198204
IAuthenticationResult resultRegionCached = ccaRegion.acquireToken(ClientCredentialParameters
199205
.builder(Collections.singleton(KEYVAULT_DEFAULT_SCOPE))

msal4j-sdk/src/integrationtest/java/com.microsoft.aad.msal4j/TestConstants.java

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -32,14 +32,14 @@ public class TestConstants {
3232
public final static String TENANT_SPECIFIC_AUTHORITY = MICROSOFT_AUTHORITY_HOST + MICROSOFT_AUTHORITY_TENANT;
3333
public final static String REGIONAL_MICROSOFT_AUTHORITY_BASIC_HOST_WESTUS = "westus.login.microsoft.com";
3434

35+
public final static String REGIONAL_MICROSOFT_AUTHORITY_BASIC_HOST_EASTUS = "eastus.login.microsoft.com";
36+
3537
public final static String ARLINGTON_ORGANIZATIONS_AUTHORITY = ARLINGTON_MICROSOFT_AUTHORITY_HOST + "organizations/";
36-
public final static String ARLINGTON_COMMON_AUTHORITY = ARLINGTON_MICROSOFT_AUTHORITY_HOST + "common/";
3738
public final static String ARLINGTON_TENANT_SPECIFIC_AUTHORITY = ARLINGTON_MICROSOFT_AUTHORITY_HOST + ARLINGTON_AUTHORITY_TENANT;
3839
public final static String ARLINGTON_GRAPH_DEFAULT_SCOPE = "https://graph.microsoft.us/.default";
3940

4041

4142
public final static String B2C_AUTHORITY = "https://msidlabb2c.b2clogin.com/tfp/msidlabb2c.onmicrosoft.com/";
42-
public final static String B2C_AUTHORITY_URL = "https://msidlabb2c.b2clogin.com/msidlabb2c.onmicrosoft.com/";
4343
public final static String B2C_ROPC_POLICY = "B2C_1_ROPC_Auth";
4444
public final static String B2C_SIGN_IN_POLICY = "B2C_1_SignInPolicy";
4545
public final static String B2C_AUTHORITY_SIGN_IN = B2C_AUTHORITY + B2C_SIGN_IN_POLICY;
@@ -49,19 +49,13 @@ public class TestConstants {
4949
public final static String B2C_MICROSOFTLOGIN_ROPC = B2C_MICROSOFTLOGIN_AUTHORITY + B2C_ROPC_POLICY;
5050

5151
public final static String LOCALHOST = "http://localhost:";
52-
public final static String LOCAL_FLAG_ENV_VAR = "MSAL_JAVA_RUN_LOCAL";
5352

5453
public final static String ADFS_AUTHORITY = "https://fs.msidlab8.com/adfs/";
5554
public final static String ADFS_SCOPE = USER_READ_SCOPE;
5655
public final static String ADFS_APP_ID = "PublicClientId";
5756

5857
public final static String CLAIMS = "{\"id_token\":{\"auth_time\":{\"essential\":true}}}";
5958
public final static Set<String> CLIENT_CAPABILITIES_EMPTY = new HashSet<>(Collections.emptySet());
60-
public final static Set<String> CLIENT_CAPABILITIES_LLT = new HashSet<>(Collections.singletonList("llt"));
61-
62-
// cross cloud b2b settings
63-
public final static String AUTHORITY_ARLINGTON = "https://login.microsoftonline.us/" + ARLINGTON_AUTHORITY_TENANT;
64-
public final static String AUTHORITY_MOONCAKE = "https://login.chinacloudapi.cn/mncmsidlab1.partner.onmschina.cn";
6559
public final static String AUTHORITY_PUBLIC_TENANT_SPECIFIC = "https://login.microsoftonline.com/" + MICROSOFT_AUTHORITY_TENANT;
6660

6761
public final static String DEFAULT_ACCESS_TOKEN = "defaultAccessToken";

msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/AadInstanceDiscoveryProvider.java

Lines changed: 41 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import java.util.TreeSet;
1515
import java.util.Map;
1616
import java.util.HashMap;
17-
import java.util.concurrent.ConcurrentHashMap;
17+
import java.util.concurrent.*;
1818

1919
class AadInstanceDiscoveryProvider {
2020

@@ -60,23 +60,21 @@ static InstanceDiscoveryMetadataEntry getMetadataEntry(URL authorityUrl,
6060
ServiceBundle serviceBundle) {
6161
String host = authorityUrl.getHost();
6262

63-
if (shouldUseRegionalEndpoint(msalRequest)) {
64-
//Server side telemetry requires the result from region discovery when any part of the region API is used
65-
String detectedRegion = discoverRegion(msalRequest, serviceBundle);
63+
ExecutorService executor = Executors.newSingleThreadExecutor();
6664

67-
if (msalRequest.application().azureRegion() != null) {
68-
host = getRegionalizedHost(authorityUrl.getHost(), msalRequest.application().azureRegion());
69-
}
65+
Future<String> future = executor.submit(() -> performRegionalDiscovery(authorityUrl, msalRequest, serviceBundle));
7066

71-
//If region autodetection is enabled and a specific region not already set,
72-
// set the application's region to the discovered region so that future requests can skip the IMDS endpoint call
73-
if (null == msalRequest.application().azureRegion() && msalRequest.application().autoDetectRegion()
74-
&& null != detectedRegion) {
75-
msalRequest.application().azureRegion = detectedRegion;
76-
}
77-
cacheRegionInstanceMetadata(authorityUrl.getHost(), msalRequest.application().azureRegion());
78-
serviceBundle.getServerSideTelemetry().getCurrentRequest().regionOutcome(
79-
determineRegionOutcome(detectedRegion, msalRequest.application().azureRegion(), msalRequest.application().autoDetectRegion()));
67+
try {
68+
log.info("Starting call to IMDS endpoint.");
69+
host = future.get(2, TimeUnit.SECONDS);
70+
} catch (TimeoutException ex) {
71+
log.info("Cancelled call to IMDS endpoint after waiting for 2 seconds");
72+
future.cancel(true);
73+
} catch (Exception ex) {
74+
// handle other exceptions
75+
log.info("Exception while calling IMDS endpoint" + ex.getMessage());
76+
} finally {
77+
executor.shutdownNow();
8078
}
8179

8280
InstanceDiscoveryMetadataEntry result = cache.get(host);
@@ -97,6 +95,32 @@ static InstanceDiscoveryMetadataEntry getMetadataEntry(URL authorityUrl,
9795
return cache.get(host);
9896
}
9997

98+
private static String performRegionalDiscovery(URL authorityUrl, MsalRequest msalRequest, ServiceBundle serviceBundle){
99+
100+
String host = authorityUrl.getHost();
101+
102+
if (shouldUseRegionalEndpoint(msalRequest)) {
103+
//Server side telemetry requires the result from region discovery when any part of the region API is used
104+
String detectedRegion = discoverRegion(msalRequest, serviceBundle);
105+
106+
if (msalRequest.application().azureRegion() != null) {
107+
host = getRegionalizedHost(authorityUrl.getHost(), msalRequest.application().azureRegion());
108+
}
109+
110+
//If region autodetection is enabled and a specific region not already set,
111+
// set the application's region to the discovered region so that future requests can skip the IMDS endpoint call
112+
if (null == msalRequest.application().azureRegion() && msalRequest.application().autoDetectRegion()
113+
&& null != detectedRegion) {
114+
msalRequest.application().azureRegion = detectedRegion;
115+
}
116+
cacheRegionInstanceMetadata(authorityUrl.getHost(), msalRequest.application().azureRegion());
117+
serviceBundle.getServerSideTelemetry().getCurrentRequest().regionOutcome(
118+
determineRegionOutcome(detectedRegion, msalRequest.application().azureRegion(), msalRequest.application().autoDetectRegion()));
119+
}
120+
121+
return host;
122+
}
123+
100124
static Set<String> getAliases(String host) {
101125
if (cache.containsKey(host)) {
102126
return cache.get(host).aliases();
@@ -299,7 +323,7 @@ private static String discoverRegion(MsalRequest msalRequest, ServiceBundle serv
299323

300324
//If call to IMDS endpoint was successful, return region from response body
301325
if (httpResponse.statusCode() == HttpHelper.HTTP_STATUS_200 && !httpResponse.body().isEmpty()) {
302-
log.info("Region retrieved from IMDS endpoint: " + httpResponse.body());
326+
log.info(String.format("Region retrieved from IMDS endpoint: %s", httpResponse.body()));
303327
currentRequest.regionSource(RegionTelemetry.REGION_SOURCE_IMDS.telemetryValue);
304328

305329
return httpResponse.body();

0 commit comments

Comments
 (0)