Skip to content

Commit c2e9899

Browse files
committed
Expensive IMDS call
1 parent 9052ef7 commit c2e9899

File tree

3 files changed

+64
-37
lines changed

3 files changed

+64
-37
lines changed

msal/application.py

Lines changed: 25 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
import sys
1010
import warnings
1111
from threading import Lock
12-
import os
1312

1413
import requests
1514

@@ -20,6 +19,7 @@
2019
from .wstrust_response import *
2120
from .token_cache import TokenCache
2221
import msal.telemetry
22+
from .region import _detect_region
2323

2424

2525
# The __init__.py will import this. Not the other way around.
@@ -261,7 +261,7 @@ def __init__(
261261
# Requests, does not support session - wide timeout
262262
# But you can patch that (https://github.com/psf/requests/issues/3341):
263263
self.http_client.request = functools.partial(
264-
self.http_client.request, timeout=timeout)
264+
self.http_client.request, timeout=timeout or 2)
265265

266266
# Enable a minimal retry. Better than nothing.
267267
# https://github.com/psf/requests/blob/v2.25.1/requests/adapters.py#L94-L108
@@ -290,11 +290,8 @@ def _build_telemetry_context(
290290
self._telemetry_buffer, self._telemetry_lock, api_id,
291291
correlation_id=correlation_id, refresh_reason=refresh_reason)
292292

293-
def _detect_region(self):
294-
return os.environ.get("REGION_NAME") # TODO: or Call IMDS
295-
296293
def _get_regional_authority(self, central_authority):
297-
self._region_detected = self._region_detected or self._detect_region()
294+
self._region_detected = self._region_detected or _detect_region(self.http_client)
298295
if self._region_configured and self._region_detected != self._region_configured:
299296
logger.warning('Region configured ({}) != region detected ({})'.format(
300297
repr(self._region_configured), repr(self._region_detected)))
@@ -369,27 +366,28 @@ def _build_client(self, client_credential, authority):
369366
on_updating_rt=self.token_cache.update_rt)
370367

371368
regional_client = None
372-
regional_authority = self._get_regional_authority(authority)
373-
if regional_authority:
374-
regional_configuration = {
375-
"authorization_endpoint": regional_authority.authorization_endpoint,
376-
"token_endpoint": regional_authority.token_endpoint,
377-
"device_authorization_endpoint":
378-
regional_authority.device_authorization_endpoint or
379-
urljoin(regional_authority.token_endpoint, "devicecode"),
380-
}
381-
regional_client = Client(
382-
regional_configuration,
383-
self.client_id,
384-
http_client=self.http_client,
385-
default_headers=default_headers,
386-
default_body=default_body,
387-
client_assertion=client_assertion,
388-
client_assertion_type=client_assertion_type,
389-
on_obtaining_tokens=lambda event: self.token_cache.add(dict(
390-
event, environment=authority.instance)),
391-
on_removing_rt=self.token_cache.remove_rt,
392-
on_updating_rt=self.token_cache.update_rt)
369+
if client_credential: # Currently regional endpoint only serves some CCA flows
370+
regional_authority = self._get_regional_authority(authority)
371+
if regional_authority:
372+
regional_configuration = {
373+
"authorization_endpoint": regional_authority.authorization_endpoint,
374+
"token_endpoint": regional_authority.token_endpoint,
375+
"device_authorization_endpoint":
376+
regional_authority.device_authorization_endpoint or
377+
urljoin(regional_authority.token_endpoint, "devicecode"),
378+
}
379+
regional_client = Client(
380+
regional_configuration,
381+
self.client_id,
382+
http_client=self.http_client,
383+
default_headers=default_headers,
384+
default_body=default_body,
385+
client_assertion=client_assertion,
386+
client_assertion_type=client_assertion_type,
387+
on_obtaining_tokens=lambda event: self.token_cache.add(dict(
388+
event, environment=authority.instance)),
389+
on_removing_rt=self.token_cache.remove_rt,
390+
on_updating_rt=self.token_cache.update_rt)
393391
return central_client, regional_client
394392

395393
def initiate_auth_code_flow(

msal/region.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
import os
2+
import json
3+
import logging
4+
5+
logger = logging.getLogger(__name__)
6+
7+
8+
def _detect_region(http_client):
9+
return _detect_region_of_azure_function() or _detect_region_of_azure_vm(http_client)
10+
11+
12+
def _detect_region_of_azure_function():
13+
return os.environ.get("REGION_NAME")
14+
15+
16+
def _detect_region_of_azure_vm(http_client):
17+
url = "http://169.254.169.254/metadata/instance?api-version=2021-01-01"
18+
try:
19+
# https://docs.microsoft.com/en-us/azure/virtual-machines/windows/instance-metadata-service?tabs=linux#instance-metadata
20+
resp = http_client.get(url, headers={"Metadata": "true"})
21+
except:
22+
logger.info("IMDS {} unavailable. Perhaps not running in Azure VM?".format(url))
23+
return None
24+
else:
25+
return json.loads(resp.text)["compute"]["location"]
26+

tests/test_e2e.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def _get_app_and_auth_code(
2727
app = msal.ConfidentialClientApplication(
2828
client_id,
2929
client_credential=client_secret,
30-
authority=authority, http_client=MinimalHttpClient())
30+
authority=authority, http_client=MinimalHttpClient(timeout=2))
3131
else:
3232
app = msal.PublicClientApplication(
3333
client_id, authority=authority, http_client=MinimalHttpClient())
@@ -292,7 +292,7 @@ def test_client_secret(self):
292292
self.config["client_id"],
293293
client_credential=self.config.get("client_secret"),
294294
authority=self.config.get("authority"),
295-
http_client=MinimalHttpClient())
295+
http_client=MinimalHttpClient(timeout=2))
296296
scope = self.config.get("scope", [])
297297
result = self.app.acquire_token_for_client(scope)
298298
self.assertIn('access_token', result)
@@ -307,7 +307,7 @@ def test_client_certificate(self):
307307
self.app = msal.ConfidentialClientApplication(
308308
self.config['client_id'],
309309
{"private_key": private_key, "thumbprint": client_cert["thumbprint"]},
310-
http_client=MinimalHttpClient())
310+
http_client=MinimalHttpClient(timeout=2))
311311
scope = self.config.get("scope", [])
312312
result = self.app.acquire_token_for_client(scope)
313313
self.assertIn('access_token', result)
@@ -330,7 +330,7 @@ def test_subject_name_issuer_authentication(self):
330330
"thumbprint": self.config["thumbprint"],
331331
"public_certificate": public_certificate,
332332
},
333-
http_client=MinimalHttpClient())
333+
http_client=MinimalHttpClient(timeout=2))
334334
scope = self.config.get("scope", [])
335335
result = self.app.acquire_token_for_client(scope)
336336
self.assertIn('access_token', result)
@@ -379,14 +379,17 @@ def get_lab_app(
379379
client_id,
380380
client_credential=client_secret,
381381
authority=authority,
382-
http_client=MinimalHttpClient(),
382+
http_client=MinimalHttpClient(timeout=2),
383383
**kwargs)
384384

385385
def get_session(lab_app, scopes): # BTW, this infrastructure tests the confidential client flow
386386
logger.info("Creating session")
387-
lab_token = lab_app.acquire_token_for_client(scopes)
387+
result = lab_app.acquire_token_for_client(scopes)
388+
assert result.get("access_token"), \
389+
"Unable to obtain token for lab. Encountered {}: {}".format(
390+
result.get("error"), result.get("error_description"))
388391
session = requests.Session()
389-
session.headers.update({"Authorization": "Bearer %s" % lab_token["access_token"]})
392+
session.headers.update({"Authorization": "Bearer %s" % result["access_token"]})
390393
session.hooks["response"].append(lambda r, *args, **kwargs: r.raise_for_status())
391394
return session
392395

@@ -525,7 +528,7 @@ def _test_acquire_token_obo(self, config_pca, config_cca):
525528
config_cca["client_id"],
526529
client_credential=config_cca["client_secret"],
527530
authority=config_cca["authority"],
528-
http_client=MinimalHttpClient(),
531+
http_client=MinimalHttpClient(timeout=2),
529532
# token_cache= ..., # Default token cache is all-tokens-store-in-memory.
530533
# That's fine if OBO app uses short-lived msal instance per session.
531534
# Otherwise, the OBO app need to implement a one-cache-per-user setup.
@@ -553,7 +556,7 @@ def _test_acquire_token_by_client_secret(
553556
assert client_id and client_secret and authority and scope
554557
app = msal.ConfidentialClientApplication(
555558
client_id, client_credential=client_secret, authority=authority,
556-
http_client=MinimalHttpClient())
559+
http_client=MinimalHttpClient(timeout=2))
557560
result = app.acquire_token_for_client(scope)
558561
self.assertIsNotNone(result.get("access_token"), "Got %s instead" % result)
559562

@@ -852,7 +855,7 @@ def test_acquire_token_silent_with_an_empty_cache_should_return_none(self):
852855
usertype="cloud", azureenvironment=self.environment, publicClient="no")
853856
app = msal.ConfidentialClientApplication(
854857
config['client_id'], authority=config['authority'],
855-
http_client=MinimalHttpClient())
858+
http_client=MinimalHttpClient(timeout=2))
856859
result = app.acquire_token_silent(scopes=config['scope'], account=None)
857860
self.assertEqual(result, None)
858861
# Note: An alias in this region is no longer accepting HTTPS traffic.

0 commit comments

Comments
 (0)