Skip to content

Commit 8c26f4c

Browse files
committed
Fix failing tests
1 parent 9a286ab commit 8c26f4c

File tree

1 file changed

+59
-54
lines changed

1 file changed

+59
-54
lines changed

msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/AadInstanceDiscoveryProvider.java

Lines changed: 59 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -62,27 +62,23 @@ static InstanceDiscoveryMetadataEntry getMetadataEntry(URL authorityUrl,
6262
ServiceBundle serviceBundle) {
6363
String host = authorityUrl.getHost();
6464

65-
ExecutorService executor = Executors.newSingleThreadExecutor();
66-
67-
Future<String> future = executor.submit(() -> performRegionalDiscovery(authorityUrl, msalRequest, serviceBundle));
65+
if (shouldUseRegionalEndpoint(msalRequest)) {
66+
//Server side telemetry requires the result from region discovery when any part of the region API is used
67+
String detectedRegion = discoverRegion(msalRequest, serviceBundle);
6868

69-
try {
70-
log.info("Starting call to IMDS endpoint.");
71-
host = future.get(IMDS_TIMEOUT, IMDS_TIMEOUT_UNIT);
72-
} catch (TimeoutException ex) {
73-
log.info("Cancelled call to IMDS endpoint after waiting for 2 seconds");
74-
future.cancel(true);
7569
if (msalRequest.application().azureRegion() != null) {
7670
host = getRegionalizedHost(authorityUrl.getHost(), msalRequest.application().azureRegion());
7771
}
78-
} catch (Exception ex) {
79-
// handle other exceptions
80-
log.info("Exception while calling IMDS endpoint" + ex.getMessage());
81-
if (msalRequest.application().azureRegion() != null) {
82-
host = getRegionalizedHost(authorityUrl.getHost(), msalRequest.application().azureRegion());
72+
73+
//If region autodetection is enabled and a specific region not already set,
74+
// set the application's region to the discovered region so that future requests can skip the IMDS endpoint call
75+
if (null == msalRequest.application().azureRegion() && msalRequest.application().autoDetectRegion()
76+
&& null != detectedRegion) {
77+
msalRequest.application().azureRegion = detectedRegion;
8378
}
84-
} finally {
85-
executor.shutdownNow();
79+
cacheRegionInstanceMetadata(authorityUrl.getHost(), msalRequest.application().azureRegion());
80+
serviceBundle.getServerSideTelemetry().getCurrentRequest().regionOutcome(
81+
determineRegionOutcome(detectedRegion, msalRequest.application().azureRegion(), msalRequest.application().autoDetectRegion()));
8682
}
8783

8884
InstanceDiscoveryMetadataEntry result = cache.get(host);
@@ -103,32 +99,6 @@ static InstanceDiscoveryMetadataEntry getMetadataEntry(URL authorityUrl,
10399
return cache.get(host);
104100
}
105101

106-
private static String performRegionalDiscovery(URL authorityUrl, MsalRequest msalRequest, ServiceBundle serviceBundle){
107-
108-
String host = authorityUrl.getHost();
109-
110-
if (shouldUseRegionalEndpoint(msalRequest)) {
111-
//Server side telemetry requires the result from region discovery when any part of the region API is used
112-
String detectedRegion = discoverRegion(msalRequest, serviceBundle);
113-
114-
if (msalRequest.application().azureRegion() != null) {
115-
host = getRegionalizedHost(authorityUrl.getHost(), msalRequest.application().azureRegion());
116-
}
117-
118-
//If region autodetection is enabled and a specific region not already set,
119-
// set the application's region to the discovered region so that future requests can skip the IMDS endpoint call
120-
if (null == msalRequest.application().azureRegion() && msalRequest.application().autoDetectRegion()
121-
&& null != detectedRegion) {
122-
msalRequest.application().azureRegion = detectedRegion;
123-
}
124-
cacheRegionInstanceMetadata(host, authorityUrl.getHost());
125-
serviceBundle.getServerSideTelemetry().getCurrentRequest().regionOutcome(
126-
determineRegionOutcome(detectedRegion, msalRequest.application().azureRegion(), msalRequest.application().autoDetectRegion()));
127-
}
128-
129-
return host;
130-
}
131-
132102
static Set<String> getAliases(String host) {
133103
if (cache.containsKey(host)) {
134104
return cache.get(host).aliases();
@@ -192,10 +162,11 @@ private static boolean shouldUseRegionalEndpoint(MsalRequest msalRequest){
192162
return false;
193163
}
194164

195-
static void cacheRegionInstanceMetadata(String regionalHost, String host) {
165+
static void cacheRegionInstanceMetadata(String host, String region) {
196166

197167
Set<String> aliases = new HashSet<>();
198168
aliases.add(host);
169+
String regionalHost = getRegionalizedHost(host, region);
199170

200171
cache.putIfAbsent(regionalHost, InstanceDiscoveryMetadataEntry.builder().
201172
preferredCache(host).
@@ -322,33 +293,67 @@ private static String discoverRegion(MsalRequest msalRequest, ServiceBundle serv
322293
return System.getenv(REGION_NAME);
323294
}
324295

325-
try {
326-
//Check the IMDS endpoint to retrieve current region (will only work if application is running in an Azure VM)
327-
Map<String, String> headers = new HashMap<>();
328-
headers.put("Metadata", "true");
329-
IHttpResponse httpResponse = executeRequest(IMDS_ENDPOINT, headers, msalRequest, serviceBundle);
296+
// try {
297+
// //Check the IMDS endpoint to retrieve current region (will only work if application is running in an Azure VM)
298+
// Map<String, String> headers = new HashMap<>();
299+
// headers.put("Metadata", "true");
300+
// IHttpResponse httpResponse = executeRequest(IMDS_ENDPOINT, headers, msalRequest, serviceBundle);
301+
//
302+
// //If call to IMDS endpoint was successful, return region from response body
303+
// if (httpResponse.statusCode() == HttpHelper.HTTP_STATUS_200 && !httpResponse.body().isEmpty()) {
304+
// log.info(String.format("Region retrieved from IMDS endpoint: %s", httpResponse.body()));
305+
// currentRequest.regionSource(RegionTelemetry.REGION_SOURCE_IMDS.telemetryValue);
306+
//
307+
// return httpResponse.body();
308+
// }
309+
//
310+
// log.warn(String.format("Call to local IMDS failed with status code: %s, or response was empty", httpResponse.statusCode()));
311+
// currentRequest.regionSource(RegionTelemetry.REGION_SOURCE_FAILED_AUTODETECT.telemetryValue);
312+
//
313+
// return null;
314+
// } catch (Exception e) {
315+
// //IMDS call failed, cannot find region
316+
// //The IMDS endpoint is only available from within an Azure environment, so the most common cause of this
317+
// // exception will likely be java.net.SocketException: Network is unreachable: connect
318+
// log.warn(String.format("Exception during call to local IMDS endpoint: %s", e.getMessage()));
319+
// currentRequest.regionSource(RegionTelemetry.REGION_SOURCE_FAILED_AUTODETECT.telemetryValue);
320+
//
321+
// return null;
322+
// }
323+
324+
//Check the IMDS endpoint to retrieve current region (will only work if application is running in an Azure VM)
325+
Map<String, String> headers = new HashMap<>();
326+
headers.put("Metadata", "true");
327+
328+
ExecutorService executor = Executors.newSingleThreadExecutor();
329+
Future<IHttpResponse> future = executor.submit(() -> executeRequest(IMDS_ENDPOINT, headers, msalRequest, serviceBundle));
330330

331+
try {
332+
log.info("Starting call to IMDS endpoint.");
333+
IHttpResponse httpResponse = future.get(IMDS_TIMEOUT, IMDS_TIMEOUT_UNIT);
331334
//If call to IMDS endpoint was successful, return region from response body
332335
if (httpResponse.statusCode() == HttpHelper.HTTP_STATUS_200 && !httpResponse.body().isEmpty()) {
333336
log.info(String.format("Region retrieved from IMDS endpoint: %s", httpResponse.body()));
334337
currentRequest.regionSource(RegionTelemetry.REGION_SOURCE_IMDS.telemetryValue);
335338

336339
return httpResponse.body();
337340
}
338-
339341
log.warn(String.format("Call to local IMDS failed with status code: %s, or response was empty", httpResponse.statusCode()));
340342
currentRequest.regionSource(RegionTelemetry.REGION_SOURCE_FAILED_AUTODETECT.telemetryValue);
341-
342-
return null;
343-
} catch (Exception e) {
343+
} catch (Exception ex) {
344+
// handle other exceptions
344345
//IMDS call failed, cannot find region
345346
//The IMDS endpoint is only available from within an Azure environment, so the most common cause of this
346347
// exception will likely be java.net.SocketException: Network is unreachable: connect
347-
log.warn(String.format("Exception during call to local IMDS endpoint: %s", e.getMessage()));
348+
log.warn(String.format("Exception during call to local IMDS endpoint: %s", ex.getMessage()));
348349
currentRequest.regionSource(RegionTelemetry.REGION_SOURCE_FAILED_AUTODETECT.telemetryValue);
350+
future.cancel(true);
349351

350-
return null;
352+
} finally {
353+
executor.shutdownNow();
351354
}
355+
356+
return null;
352357
}
353358

354359
private static void doInstanceDiscoveryAndCache(URL authorityUrl,

0 commit comments

Comments
 (0)