Skip to content

Commit 0b84f5e

Browse files
authored
Merge pull request #358 from AzureAD/reginal-endpoint-experiment
Reginal endpoint support
2 parents cb5f36f + babe142 commit 0b84f5e

File tree

4 files changed

+293
-17
lines changed

4 files changed

+293
-17
lines changed

msal/application.py

Lines changed: 133 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from .wstrust_response import *
2020
from .token_cache import TokenCache
2121
import msal.telemetry
22+
from .region import _detect_region
2223

2324

2425
# The __init__.py will import this. Not the other way around.
@@ -108,14 +109,21 @@ class ClientApplication(object):
108109
GET_ACCOUNTS_ID = "902"
109110
REMOVE_ACCOUNT_ID = "903"
110111

112+
ATTEMPT_REGION_DISCOVERY = True # "TryAutoDetect"
113+
111114
def __init__(
112115
self, client_id,
113116
client_credential=None, authority=None, validate_authority=True,
114117
token_cache=None,
115118
http_client=None,
116119
verify=True, proxies=None, timeout=None,
117120
client_claims=None, app_name=None, app_version=None,
118-
client_capabilities=None):
121+
client_capabilities=None,
122+
azure_region=None, # Note: We choose to add this param in this base class,
123+
# despite it is currently only needed by ConfidentialClientApplication.
124+
# This way, it holds the same positional param place for PCA,
125+
# when we would eventually want to add this feature to PCA in future.
126+
):
119127
"""Create an instance of application.
120128
121129
:param str client_id: Your app has a client_id after you register it on AAD.
@@ -220,6 +228,53 @@ def __init__(
220228
MSAL will combine them into
221229
`claims parameter <https://openid.net/specs/openid-connect-core-1_0-final.html#ClaimsParameter`_
222230
which you will later provide via one of the acquire-token request.
231+
232+
:param str azure_region:
233+
Added since MSAL Python 1.12.0.
234+
235+
As of 2021 May, regional service is only available for
236+
``acquire_token_for_client()`` sent by any of the following scenarios::
237+
238+
1. An app powered by a capable MSAL
239+
(MSAL Python 1.12+ will be provisioned)
240+
241+
2. An app with managed identity, which is formerly known as MSI.
242+
(However MSAL Python does not support managed identity,
243+
so this one does not apply.)
244+
245+
3. An app authenticated by
246+
`Subject Name/Issuer (SNI) <https://github.com/AzureAD/microsoft-authentication-library-for-python/issues/60>`_.
247+
248+
4. An app which already onboard to the region's allow-list.
249+
250+
MSAL's default value is None, which means region behavior remains off.
251+
If enabled, the `acquire_token_for_client()`-relevant traffic
252+
would remain inside that region.
253+
254+
App developer can opt in to a regional endpoint,
255+
by provide its region name, such as "westus", "eastus2".
256+
You can find a full list of regions by running
257+
``az account list-locations -o table``, or referencing to
258+
`this doc <https://docs.microsoft.com/en-us/dotnet/api/microsoft.azure.management.resourcemanager.fluent.core.region?view=azure-dotnet>`_.
259+
260+
An app running inside Azure Functions and Azure VM can use a special keyword
261+
``ClientApplication.ATTEMPT_REGION_DISCOVERY`` to auto-detect region.
262+
263+
.. note::
264+
265+
Setting ``azure_region`` to non-``None`` for an app running
266+
outside of Azure Function/VM could hang indefinitely.
267+
268+
You should consider opting in/out region behavior on-demand,
269+
by loading ``azure_region=None`` or ``azure_region="westus"``
270+
or ``azure_region=True`` (which means opt-in and auto-detect)
271+
from your per-deployment configuration, and then do
272+
``app = ConfidentialClientApplication(..., azure_region=azure_region)``.
273+
274+
Alternatively, you can configure a short timeout,
275+
or provide a custom http_client which has a short timeout.
276+
That way, the latency would be under your control,
277+
but still less performant than opting out of region feature.
223278
"""
224279
self.client_id = client_id
225280
self.client_credential = client_credential
@@ -244,12 +299,29 @@ def __init__(
244299

245300
self.app_name = app_name
246301
self.app_version = app_version
247-
self.authority = Authority(
302+
303+
# Here the self.authority will not be the same type as authority in input
304+
try:
305+
self.authority = Authority(
248306
authority or "https://login.microsoftonline.com/common/",
249307
self.http_client, validate_authority=validate_authority)
250-
# Here the self.authority is not the same type as authority in input
308+
except ValueError: # Those are explicit authority validation errors
309+
raise
310+
except Exception: # The rest are typically connection errors
311+
if validate_authority and region:
312+
# Since caller opts in to use region, here we tolerate connection
313+
# errors happened during authority validation at non-region endpoint
314+
self.authority = Authority(
315+
authority or "https://login.microsoftonline.com/common/",
316+
self.http_client, validate_authority=False)
317+
else:
318+
raise
319+
251320
self.token_cache = token_cache or TokenCache()
252-
self.client = self._build_client(client_credential, self.authority)
321+
self._region_configured = azure_region
322+
self._region_detected = None
323+
self.client, self._regional_client = self._build_client(
324+
client_credential, self.authority)
253325
self.authority_groups = None
254326
self._telemetry_buffer = {}
255327
self._telemetry_lock = Lock()
@@ -260,6 +332,32 @@ def _build_telemetry_context(
260332
self._telemetry_buffer, self._telemetry_lock, api_id,
261333
correlation_id=correlation_id, refresh_reason=refresh_reason)
262334

335+
def _get_regional_authority(self, central_authority):
336+
is_region_specified = bool(self._region_configured
337+
and self._region_configured != self.ATTEMPT_REGION_DISCOVERY)
338+
self._region_detected = self._region_detected or _detect_region(
339+
self.http_client if self._region_configured is not None else None)
340+
if (is_region_specified and self._region_configured != self._region_detected):
341+
logger.warning('Region configured ({}) != region detected ({})'.format(
342+
repr(self._region_configured), repr(self._region_detected)))
343+
region_to_use = (
344+
self._region_configured if is_region_specified else self._region_detected)
345+
if region_to_use:
346+
logger.info('Region to be used: {}'.format(repr(region_to_use)))
347+
regional_host = ("{}.login.microsoft.com".format(region_to_use)
348+
if central_authority.instance in (
349+
# The list came from https://github.com/AzureAD/microsoft-authentication-library-for-python/pull/358/files#r629400328
350+
"login.microsoftonline.com",
351+
"login.windows.net",
352+
"sts.windows.net",
353+
)
354+
else "{}.{}".format(region_to_use, central_authority.instance))
355+
return Authority(
356+
"https://{}/{}".format(regional_host, central_authority.tenant),
357+
self.http_client,
358+
validate_authority=False) # The central_authority has already been validated
359+
return None
360+
263361
def _build_client(self, client_credential, authority):
264362
client_assertion = None
265363
client_assertion_type = None
@@ -298,15 +396,15 @@ def _build_client(self, client_credential, authority):
298396
client_assertion_type = Client.CLIENT_ASSERTION_TYPE_JWT
299397
else:
300398
default_body['client_secret'] = client_credential
301-
server_configuration = {
399+
central_configuration = {
302400
"authorization_endpoint": authority.authorization_endpoint,
303401
"token_endpoint": authority.token_endpoint,
304402
"device_authorization_endpoint":
305403
authority.device_authorization_endpoint or
306404
urljoin(authority.token_endpoint, "devicecode"),
307405
}
308-
return Client(
309-
server_configuration,
406+
central_client = Client(
407+
central_configuration,
310408
self.client_id,
311409
http_client=self.http_client,
312410
default_headers=default_headers,
@@ -318,6 +416,31 @@ def _build_client(self, client_credential, authority):
318416
on_removing_rt=self.token_cache.remove_rt,
319417
on_updating_rt=self.token_cache.update_rt)
320418

419+
regional_client = None
420+
if client_credential: # Currently regional endpoint only serves some CCA flows
421+
regional_authority = self._get_regional_authority(authority)
422+
if regional_authority:
423+
regional_configuration = {
424+
"authorization_endpoint": regional_authority.authorization_endpoint,
425+
"token_endpoint": regional_authority.token_endpoint,
426+
"device_authorization_endpoint":
427+
regional_authority.device_authorization_endpoint or
428+
urljoin(regional_authority.token_endpoint, "devicecode"),
429+
}
430+
regional_client = Client(
431+
regional_configuration,
432+
self.client_id,
433+
http_client=self.http_client,
434+
default_headers=default_headers,
435+
default_body=default_body,
436+
client_assertion=client_assertion,
437+
client_assertion_type=client_assertion_type,
438+
on_obtaining_tokens=lambda event: self.token_cache.add(dict(
439+
event, environment=authority.instance)),
440+
on_removing_rt=self.token_cache.remove_rt,
441+
on_updating_rt=self.token_cache.update_rt)
442+
return central_client, regional_client
443+
321444
def initiate_auth_code_flow(
322445
self,
323446
scopes, # type: list[str]
@@ -953,7 +1076,7 @@ def _acquire_token_silent_by_finding_specific_refresh_token(
9531076
# target=scopes, # AAD RTs are scope-independent
9541077
query=query)
9551078
logger.debug("Found %d RTs matching %s", len(matches), query)
956-
client = self._build_client(self.client_credential, authority)
1079+
client, _ = self._build_client(self.client_credential, authority)
9571080

9581081
response = None # A distinguishable value to mean cache is empty
9591082
telemetry_context = self._build_telemetry_context(
@@ -1304,7 +1427,8 @@ def acquire_token_for_client(self, scopes, claims_challenge=None, **kwargs):
13041427
self._validate_ssh_cert_input_data(kwargs.get("data", {}))
13051428
telemetry_context = self._build_telemetry_context(
13061429
self.ACQUIRE_TOKEN_FOR_CLIENT_ID)
1307-
response = _clean_up(self.client.obtain_token_for_client(
1430+
client = self._regional_client or self.client
1431+
response = _clean_up(client.obtain_token_for_client(
13081432
scope=scopes, # This grant flow requires no scope decoration
13091433
headers=telemetry_context.generate_headers(),
13101434
data=dict(

msal/region.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
import os
2+
import logging
3+
4+
logger = logging.getLogger(__name__)
5+
6+
7+
def _detect_region(http_client=None):
8+
region = _detect_region_of_azure_function() # It is cheap, so we do it always
9+
if http_client and not region:
10+
return _detect_region_of_azure_vm(http_client) # It could hang for minutes
11+
return region
12+
13+
14+
def _detect_region_of_azure_function():
15+
return os.environ.get("REGION_NAME")
16+
17+
18+
def _detect_region_of_azure_vm(http_client):
19+
url = (
20+
"http://169.254.169.254/metadata/instance"
21+
22+
# Utilize the "route parameters" feature to obtain region as a string
23+
# https://docs.microsoft.com/en-us/azure/virtual-machines/windows/instance-metadata-service?tabs=linux#route-parameters
24+
"/compute/location?format=text"
25+
26+
# Location info is available since API version 2017-04-02
27+
# https://docs.microsoft.com/en-us/azure/virtual-machines/windows/instance-metadata-service?tabs=linux#response-1
28+
"&api-version=2021-01-01"
29+
)
30+
logger.info(
31+
"Connecting to IMDS {}. "
32+
"It may take a while if you are running outside of Azure. "
33+
"You should consider opting in/out region behavior on-demand, "
34+
'by loading a boolean flag "is_deployed_in_azure" '
35+
'from your per-deployment config and then do '
36+
'"app = ConfidentialClientApplication(..., '
37+
'azure_region=is_deployed_in_azure)"'.format(url))
38+
try:
39+
# https://docs.microsoft.com/en-us/azure/virtual-machines/windows/instance-metadata-service?tabs=linux#instance-metadata
40+
resp = http_client.get(url, headers={"Metadata": "true"})
41+
except:
42+
logger.info(
43+
"IMDS {} unavailable. Perhaps not running in Azure VM?".format(url))
44+
return None
45+
else:
46+
return resp.text.strip()
47+

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@
8484
# We will go with "<4" for now, which is also what our another dependency,
8585
# pyjwt, currently use.
8686

87+
"mock;python_version<'3.3'",
8788
]
8889
)
8990

0 commit comments

Comments
 (0)