Skip to content

Commit 9052ef7

Browse files
committed
acquire_token_for_client() can use regional endpoint
1 parent 5b8f12c commit 9052ef7

File tree

2 files changed

+176
-12
lines changed

2 files changed

+176
-12
lines changed

msal/application.py

Lines changed: 82 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import sys
1010
import warnings
1111
from threading import Lock
12+
import os
1213

1314
import requests
1415

@@ -108,14 +109,21 @@ class ClientApplication(object):
108109
GET_ACCOUNTS_ID = "902"
109110
REMOVE_ACCOUNT_ID = "903"
110111

112+
ATTEMPT_REGION_DISCOVERY = "TryAutoDetect"
113+
111114
def __init__(
112115
self, client_id,
113116
client_credential=None, authority=None, validate_authority=True,
114117
token_cache=None,
115118
http_client=None,
116119
verify=True, proxies=None, timeout=None,
117120
client_claims=None, app_name=None, app_version=None,
118-
client_capabilities=None):
121+
client_capabilities=None,
122+
region=None, # Note: We choose to add this param in this base class,
123+
# despite it is currently only needed by ConfidentialClientApplication.
124+
# This way, it holds the same positional param place for PCA,
125+
# when we would eventually want to add this feature to PCA in future.
126+
):
119127
"""Create an instance of application.
120128
121129
:param str client_id: Your app has a client_id after you register it on AAD.
@@ -220,6 +228,25 @@ def __init__(
220228
MSAL will combine them into
221229
`claims parameter <https://openid.net/specs/openid-connect-core-1_0-final.html#ClaimsParameter`_
222230
which you will later provide via one of the acquire-token request.
231+
232+
:param str region:
233+
Added since MSAL Python 1.12.0.
234+
235+
If enabled, MSAL token requests would remain inside that region.
236+
Currently, regional endpoint only supports using
237+
``acquire_token_for_client()`` for some scopes.
238+
239+
The default value is None, which means region support remains turned off.
240+
241+
App developer can opt in to regional endpoint,
242+
by provide a region name, such as "westus", "eastus2".
243+
244+
An app running inside Azure VM can use a special keyword
245+
``ClientApplication.ATTEMPT_REGION_DISCOVERY`` to auto-detect region.
246+
(Attempting this on a non-VM could hang indefinitely.
247+
Make sure you configure a short timeout,
248+
or provide a custom http_client which has a short timeout.
249+
That way, the latency would be under your control.)
223250
"""
224251
self.client_id = client_id
225252
self.client_credential = client_credential
@@ -249,7 +276,10 @@ def __init__(
249276
self.http_client, validate_authority=validate_authority)
250277
# Here the self.authority is not the same type as authority in input
251278
self.token_cache = token_cache or TokenCache()
252-
self.client = self._build_client(client_credential, self.authority)
279+
self._region_configured = region
280+
self._region_detected = None
281+
self.client, self._regional_client = self._build_client(
282+
client_credential, self.authority)
253283
self.authority_groups = None
254284
self._telemetry_buffer = {}
255285
self._telemetry_lock = Lock()
@@ -260,6 +290,26 @@ def _build_telemetry_context(
260290
self._telemetry_buffer, self._telemetry_lock, api_id,
261291
correlation_id=correlation_id, refresh_reason=refresh_reason)
262292

293+
def _detect_region(self):
294+
return os.environ.get("REGION_NAME") # TODO: or Call IMDS
295+
296+
def _get_regional_authority(self, central_authority):
297+
self._region_detected = self._region_detected or self._detect_region()
298+
if self._region_configured and self._region_detected != self._region_configured:
299+
logger.warning('Region configured ({}) != region detected ({})'.format(
300+
repr(self._region_configured), repr(self._region_detected)))
301+
region_to_use = self._region_configured or self._region_detected
302+
if region_to_use:
303+
logger.info('Region to be used: {}'.format(repr(region_to_use)))
304+
regional_host = ("{}.login.microsoft.com".format(region_to_use)
305+
if central_authority.instance == "login.microsoftonline.com"
306+
else "{}.{}".format(region_to_use, central_authority.instance))
307+
return Authority(
308+
"https://{}/{}".format(regional_host, central_authority.tenant),
309+
self.http_client,
310+
validate_authority=False) # The central_authority has already been validated
311+
return None
312+
263313
def _build_client(self, client_credential, authority):
264314
client_assertion = None
265315
client_assertion_type = None
@@ -298,15 +348,15 @@ def _build_client(self, client_credential, authority):
298348
client_assertion_type = Client.CLIENT_ASSERTION_TYPE_JWT
299349
else:
300350
default_body['client_secret'] = client_credential
301-
server_configuration = {
351+
central_configuration = {
302352
"authorization_endpoint": authority.authorization_endpoint,
303353
"token_endpoint": authority.token_endpoint,
304354
"device_authorization_endpoint":
305355
authority.device_authorization_endpoint or
306356
urljoin(authority.token_endpoint, "devicecode"),
307357
}
308-
return Client(
309-
server_configuration,
358+
central_client = Client(
359+
central_configuration,
310360
self.client_id,
311361
http_client=self.http_client,
312362
default_headers=default_headers,
@@ -318,6 +368,30 @@ def _build_client(self, client_credential, authority):
318368
on_removing_rt=self.token_cache.remove_rt,
319369
on_updating_rt=self.token_cache.update_rt)
320370

371+
regional_client = None
372+
regional_authority = self._get_regional_authority(authority)
373+
if regional_authority:
374+
regional_configuration = {
375+
"authorization_endpoint": regional_authority.authorization_endpoint,
376+
"token_endpoint": regional_authority.token_endpoint,
377+
"device_authorization_endpoint":
378+
regional_authority.device_authorization_endpoint or
379+
urljoin(regional_authority.token_endpoint, "devicecode"),
380+
}
381+
regional_client = Client(
382+
regional_configuration,
383+
self.client_id,
384+
http_client=self.http_client,
385+
default_headers=default_headers,
386+
default_body=default_body,
387+
client_assertion=client_assertion,
388+
client_assertion_type=client_assertion_type,
389+
on_obtaining_tokens=lambda event: self.token_cache.add(dict(
390+
event, environment=authority.instance)),
391+
on_removing_rt=self.token_cache.remove_rt,
392+
on_updating_rt=self.token_cache.update_rt)
393+
return central_client, regional_client
394+
321395
def initiate_auth_code_flow(
322396
self,
323397
scopes, # type: list[str]
@@ -953,7 +1027,7 @@ def _acquire_token_silent_by_finding_specific_refresh_token(
9531027
# target=scopes, # AAD RTs are scope-independent
9541028
query=query)
9551029
logger.debug("Found %d RTs matching %s", len(matches), query)
956-
client = self._build_client(self.client_credential, authority)
1030+
client, _ = self._build_client(self.client_credential, authority)
9571031

9581032
response = None # A distinguishable value to mean cache is empty
9591033
telemetry_context = self._build_telemetry_context(
@@ -1304,7 +1378,8 @@ def acquire_token_for_client(self, scopes, claims_challenge=None, **kwargs):
13041378
self._validate_ssh_cert_input_data(kwargs.get("data", {}))
13051379
telemetry_context = self._build_telemetry_context(
13061380
self.ACQUIRE_TOKEN_FOR_CLIENT_ID)
1307-
response = _clean_up(self.client.obtain_token_for_client(
1381+
client = self._regional_client or self.client
1382+
response = _clean_up(client.obtain_token_for_client(
13081383
scope=scopes, # This grant flow requires no scope decoration
13091384
headers=telemetry_context.generate_headers(),
13101385
data=dict(

tests/test_e2e.py

Lines changed: 94 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,12 @@ def assertCacheWorksForUser(
9999
"We should get an AT from acquire_token_silent(...) call")
100100

101101
def assertCacheWorksForApp(self, result_from_wire, scope):
102+
logger.debug(
103+
"%s: cache = %s, id_token_claims = %s",
104+
self.id(),
105+
json.dumps(self.app.token_cache._cache, indent=4),
106+
json.dumps(result_from_wire.get("id_token_claims"), indent=4),
107+
)
102108
# Going to test acquire_token_silent(...) to locate an AT from cache
103109
result_from_cache = self.app.acquire_token_silent(scope, account=None)
104110
self.assertIsNotNone(result_from_cache)
@@ -345,7 +351,9 @@ def test_device_flow(self):
345351
def get_lab_app(
346352
env_client_id="LAB_APP_CLIENT_ID",
347353
env_client_secret="LAB_APP_CLIENT_SECRET",
348-
):
354+
authority="https://login.microsoftonline.com/"
355+
"72f988bf-86f1-41af-91ab-2d7cd011db47", # Microsoft tenant ID
356+
**kwargs):
349357
"""Returns the lab app as an MSAL confidential client.
350358
351359
Get it from environment variables if defined, otherwise fall back to use MSI.
@@ -367,10 +375,12 @@ def get_lab_app(
367375
env_client_id, env_client_secret)
368376
# See also https://microsoft.sharepoint-df.com/teams/MSIDLABSExtended/SitePages/Programmatically-accessing-LAB-API's.aspx
369377
raise unittest.SkipTest("MSI-based mechanism has not been implemented yet")
370-
return msal.ConfidentialClientApplication(client_id, client_secret,
371-
authority="https://login.microsoftonline.com/"
372-
"72f988bf-86f1-41af-91ab-2d7cd011db47", # Microsoft tenant ID
373-
http_client=MinimalHttpClient())
378+
return msal.ConfidentialClientApplication(
379+
client_id,
380+
client_credential=client_secret,
381+
authority=authority,
382+
http_client=MinimalHttpClient(),
383+
**kwargs)
374384

375385
def get_session(lab_app, scopes): # BTW, this infrastructure tests the confidential client flow
376386
logger.info("Creating session")
@@ -726,6 +736,85 @@ def test_b2c_acquire_token_by_ropc(self):
726736
)
727737

728738

739+
class WorldWideRegionalEndpointTestCase(LabBasedTestCase):
740+
region = "westus"
741+
742+
def test_acquire_token_for_client_should_hit_regional_endpoint(self):
743+
"""This is the only grant supported by regional endpoint, for now"""
744+
self.app = get_lab_app( # Regional endpoint only supports confidential client
745+
## Would fail the OIDC Discovery
746+
#authority="https://westus2.login.microsoftonline.com/"
747+
# "72f988bf-86f1-41af-91ab-2d7cd011db47", # Microsoft tenant ID
748+
749+
#authority="https://westus.login.microsoft.com/microsoft.onmicrosoft.com",
750+
#validate_authority=False,
751+
752+
authority="https://login.microsoftonline.com/microsoft.onmicrosoft.com",
753+
region=self.region, # Explicitly use this region, regardless of detection
754+
)
755+
scopes = ["https://graph.microsoft.com/.default"]
756+
result = self.app.acquire_token_for_client(
757+
scopes,
758+
params={"AllowEstsRNonMsi": "true"}, # For testing regional endpoint
759+
)
760+
self.assertIn('access_token', result)
761+
self.assertCacheWorksForApp(result, scopes)
762+
# TODO: Test the request hit the regional endpoint self.region?
763+
764+
765+
class RegionalEndpointViaEnvVarTestCase(WorldWideRegionalEndpointTestCase):
766+
767+
def setUp(self):
768+
os.environ["REGION_NAME"] = "eastus"
769+
770+
def tearDown(self):
771+
del os.environ["REGION_NAME"]
772+
773+
@unittest.skipUnless(
774+
os.getenv("LAB_OBO_CLIENT_SECRET"),
775+
"Need LAB_OBO_CLIENT_SECRET from https://aka.ms/GetLabSecret?Secret=TodoListServiceV2-OBO")
776+
@unittest.skipUnless(
777+
os.getenv("LAB_OBO_CONFIDENTIAL_CLIENT_ID"),
778+
"Need LAB_OBO_CONFIDENTIAL_CLIENT_ID from https://docs.msidlab.com/flows/onbehalfofflow.html")
779+
@unittest.skipUnless(
780+
os.getenv("LAB_OBO_PUBLIC_CLIENT_ID"),
781+
"Need LAB_OBO_PUBLIC_CLIENT_ID from https://docs.msidlab.com/flows/onbehalfofflow.html")
782+
def test_cca_obo_should_bypass_regional_endpoint_therefore_still_work(self):
783+
"""We test OBO because it is implemented in sub class ConfidentialClientApplication"""
784+
config = self.get_lab_user(usertype="cloud")
785+
786+
config_cca = {}
787+
config_cca.update(config)
788+
config_cca["client_id"] = os.getenv("LAB_OBO_CONFIDENTIAL_CLIENT_ID")
789+
config_cca["scope"] = ["https://graph.microsoft.com/.default"]
790+
config_cca["client_secret"] = os.getenv("LAB_OBO_CLIENT_SECRET")
791+
792+
config_pca = {}
793+
config_pca.update(config)
794+
config_pca["client_id"] = os.getenv("LAB_OBO_PUBLIC_CLIENT_ID")
795+
config_pca["password"] = self.get_lab_user_secret(config_pca["lab_name"])
796+
config_pca["scope"] = ["api://%s/read" % config_cca["client_id"]]
797+
798+
self._test_acquire_token_obo(config_pca, config_cca)
799+
800+
@unittest.skipUnless(
801+
os.getenv("LAB_OBO_CLIENT_SECRET"),
802+
"Need LAB_OBO_CLIENT_SECRET from https://aka.ms/GetLabSecret?Secret=TodoListServiceV2-OBO")
803+
@unittest.skipUnless(
804+
os.getenv("LAB_OBO_CONFIDENTIAL_CLIENT_ID"),
805+
"Need LAB_OBO_CONFIDENTIAL_CLIENT_ID from https://docs.msidlab.com/flows/onbehalfofflow.html")
806+
def test_cca_ropc_should_bypass_regional_endpoint_therefore_still_work(self):
807+
"""We test ROPC because it is implemented in base class ClientApplication"""
808+
config = self.get_lab_user(usertype="cloud")
809+
config["password"] = self.get_lab_user_secret(config["lab_name"])
810+
# We repurpose the obo confidential app to test ROPC
811+
# Swap in the OBO confidential app
812+
config["client_id"] = os.getenv("LAB_OBO_CONFIDENTIAL_CLIENT_ID")
813+
config["scope"] = ["https://graph.microsoft.com/.default"]
814+
config["client_secret"] = os.getenv("LAB_OBO_CLIENT_SECRET")
815+
self._test_username_password(**config)
816+
817+
729818
class ArlingtonCloudTestCase(LabBasedTestCase):
730819
environment = "azureusgovernment"
731820

0 commit comments

Comments
 (0)