9
9
import sys
10
10
import warnings
11
11
from threading import Lock
12
+ import os
12
13
13
14
import requests
14
15
@@ -108,14 +109,21 @@ class ClientApplication(object):
108
109
GET_ACCOUNTS_ID = "902"
109
110
REMOVE_ACCOUNT_ID = "903"
110
111
112
+ ATTEMPT_REGION_DISCOVERY = "TryAutoDetect"
113
+
111
114
def __init__ (
112
115
self , client_id ,
113
116
client_credential = None , authority = None , validate_authority = True ,
114
117
token_cache = None ,
115
118
http_client = None ,
116
119
verify = True , proxies = None , timeout = None ,
117
120
client_claims = None , app_name = None , app_version = None ,
118
- client_capabilities = None ):
121
+ client_capabilities = None ,
122
+ region = None , # Note: We choose to add this param in this base class,
123
+ # despite it is currently only needed by ConfidentialClientApplication.
124
+ # This way, it holds the same positional param place for PCA,
125
+ # when we would eventually want to add this feature to PCA in future.
126
+ ):
119
127
"""Create an instance of application.
120
128
121
129
:param str client_id: Your app has a client_id after you register it on AAD.
@@ -220,6 +228,25 @@ def __init__(
220
228
MSAL will combine them into
221
229
`claims parameter <https://openid.net/specs/openid-connect-core-1_0-final.html#ClaimsParameter`_
222
230
which you will later provide via one of the acquire-token request.
231
+
232
+ :param str region:
233
+ Added since MSAL Python 1.12.0.
234
+
235
+ If enabled, MSAL token requests would remain inside that region.
236
+ Currently, regional endpoint only supports using
237
+ ``acquire_token_for_client()`` for some scopes.
238
+
239
+ The default value is None, which means region support remains turned off.
240
+
241
+ App developer can opt in to regional endpoint,
242
+ by provide a region name, such as "westus", "eastus2".
243
+
244
+ An app running inside Azure VM can use a special keyword
245
+ ``ClientApplication.ATTEMPT_REGION_DISCOVERY`` to auto-detect region.
246
+ (Attempting this on a non-VM could hang indefinitely.
247
+ Make sure you configure a short timeout,
248
+ or provide a custom http_client which has a short timeout.
249
+ That way, the latency would be under your control.)
223
250
"""
224
251
self .client_id = client_id
225
252
self .client_credential = client_credential
@@ -249,7 +276,10 @@ def __init__(
249
276
self .http_client , validate_authority = validate_authority )
250
277
# Here the self.authority is not the same type as authority in input
251
278
self .token_cache = token_cache or TokenCache ()
252
- self .client = self ._build_client (client_credential , self .authority )
279
+ self ._region_configured = region
280
+ self ._region_detected = None
281
+ self .client , self ._regional_client = self ._build_client (
282
+ client_credential , self .authority )
253
283
self .authority_groups = None
254
284
self ._telemetry_buffer = {}
255
285
self ._telemetry_lock = Lock ()
@@ -260,6 +290,26 @@ def _build_telemetry_context(
260
290
self ._telemetry_buffer , self ._telemetry_lock , api_id ,
261
291
correlation_id = correlation_id , refresh_reason = refresh_reason )
262
292
293
+ def _detect_region (self ):
294
+ return os .environ .get ("REGION_NAME" ) # TODO: or Call IMDS
295
+
296
+ def _get_regional_authority (self , central_authority ):
297
+ self ._region_detected = self ._region_detected or self ._detect_region ()
298
+ if self ._region_configured and self ._region_detected != self ._region_configured :
299
+ logger .warning ('Region configured ({}) != region detected ({})' .format (
300
+ repr (self ._region_configured ), repr (self ._region_detected )))
301
+ region_to_use = self ._region_configured or self ._region_detected
302
+ if region_to_use :
303
+ logger .info ('Region to be used: {}' .format (repr (region_to_use )))
304
+ regional_host = ("{}.login.microsoft.com" .format (region_to_use )
305
+ if central_authority .instance == "login.microsoftonline.com"
306
+ else "{}.{}" .format (region_to_use , central_authority .instance ))
307
+ return Authority (
308
+ "https://{}/{}" .format (regional_host , central_authority .tenant ),
309
+ self .http_client ,
310
+ validate_authority = False ) # The central_authority has already been validated
311
+ return None
312
+
263
313
def _build_client (self , client_credential , authority ):
264
314
client_assertion = None
265
315
client_assertion_type = None
@@ -298,15 +348,15 @@ def _build_client(self, client_credential, authority):
298
348
client_assertion_type = Client .CLIENT_ASSERTION_TYPE_JWT
299
349
else :
300
350
default_body ['client_secret' ] = client_credential
301
- server_configuration = {
351
+ central_configuration = {
302
352
"authorization_endpoint" : authority .authorization_endpoint ,
303
353
"token_endpoint" : authority .token_endpoint ,
304
354
"device_authorization_endpoint" :
305
355
authority .device_authorization_endpoint or
306
356
urljoin (authority .token_endpoint , "devicecode" ),
307
357
}
308
- return Client (
309
- server_configuration ,
358
+ central_client = Client (
359
+ central_configuration ,
310
360
self .client_id ,
311
361
http_client = self .http_client ,
312
362
default_headers = default_headers ,
@@ -318,6 +368,30 @@ def _build_client(self, client_credential, authority):
318
368
on_removing_rt = self .token_cache .remove_rt ,
319
369
on_updating_rt = self .token_cache .update_rt )
320
370
371
+ regional_client = None
372
+ regional_authority = self ._get_regional_authority (authority )
373
+ if regional_authority :
374
+ regional_configuration = {
375
+ "authorization_endpoint" : regional_authority .authorization_endpoint ,
376
+ "token_endpoint" : regional_authority .token_endpoint ,
377
+ "device_authorization_endpoint" :
378
+ regional_authority .device_authorization_endpoint or
379
+ urljoin (regional_authority .token_endpoint , "devicecode" ),
380
+ }
381
+ regional_client = Client (
382
+ regional_configuration ,
383
+ self .client_id ,
384
+ http_client = self .http_client ,
385
+ default_headers = default_headers ,
386
+ default_body = default_body ,
387
+ client_assertion = client_assertion ,
388
+ client_assertion_type = client_assertion_type ,
389
+ on_obtaining_tokens = lambda event : self .token_cache .add (dict (
390
+ event , environment = authority .instance )),
391
+ on_removing_rt = self .token_cache .remove_rt ,
392
+ on_updating_rt = self .token_cache .update_rt )
393
+ return central_client , regional_client
394
+
321
395
def initiate_auth_code_flow (
322
396
self ,
323
397
scopes , # type: list[str]
@@ -953,7 +1027,7 @@ def _acquire_token_silent_by_finding_specific_refresh_token(
953
1027
# target=scopes, # AAD RTs are scope-independent
954
1028
query = query )
955
1029
logger .debug ("Found %d RTs matching %s" , len (matches ), query )
956
- client = self ._build_client (self .client_credential , authority )
1030
+ client , _ = self ._build_client (self .client_credential , authority )
957
1031
958
1032
response = None # A distinguishable value to mean cache is empty
959
1033
telemetry_context = self ._build_telemetry_context (
@@ -1304,7 +1378,8 @@ def acquire_token_for_client(self, scopes, claims_challenge=None, **kwargs):
1304
1378
self ._validate_ssh_cert_input_data (kwargs .get ("data" , {}))
1305
1379
telemetry_context = self ._build_telemetry_context (
1306
1380
self .ACQUIRE_TOKEN_FOR_CLIENT_ID )
1307
- response = _clean_up (self .client .obtain_token_for_client (
1381
+ client = self ._regional_client or self .client
1382
+ response = _clean_up (client .obtain_token_for_client (
1308
1383
scope = scopes , # This grant flow requires no scope decoration
1309
1384
headers = telemetry_context .generate_headers (),
1310
1385
data = dict (
0 commit comments