Skip to content

Commit 4d744c1

Browse files
authored
Merge pull request #320 from AzureAD/refresh-in
Filter out refresh_in from auth responses
2 parents cf3c99a + 76348db commit 4d744c1

File tree

2 files changed

+46
-32
lines changed

2 files changed

+46
-32
lines changed

msal/application.py

Lines changed: 25 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,12 @@ def _str2bytes(raw):
100100
return raw
101101

102102

103+
def _clean_up(result):
104+
if isinstance(result, dict):
105+
result.pop("refresh_in", None) # MSAL handled refresh_in, customers need not
106+
return result
107+
108+
103109
class ClientApplication(object):
104110

105111
ACQUIRE_TOKEN_SILENT_ID = "84"
@@ -507,7 +513,7 @@ def authorize(): # A controller in a web app
507513
return redirect(url_for("index"))
508514
"""
509515
self._validate_ssh_cert_input_data(kwargs.get("data", {}))
510-
return self.client.obtain_token_by_auth_code_flow(
516+
return _clean_up(self.client.obtain_token_by_auth_code_flow(
511517
auth_code_flow,
512518
auth_response,
513519
scope=decorate_scope(scopes, self.client_id) if scopes else None,
@@ -521,7 +527,7 @@ def authorize(): # A controller in a web app
521527
claims=_merge_claims_challenge_and_capabilities(
522528
self._client_capabilities,
523529
auth_code_flow.pop("claims_challenge", None))),
524-
**kwargs)
530+
**kwargs))
525531

526532
def acquire_token_by_authorization_code(
527533
self,
@@ -580,7 +586,7 @@ def acquire_token_by_authorization_code(
580586
"Change your acquire_token_by_authorization_code() "
581587
"to acquire_token_by_auth_code_flow()", DeprecationWarning)
582588
with warnings.catch_warnings(record=True):
583-
return self.client.obtain_token_by_authorization_code(
589+
return _clean_up(self.client.obtain_token_by_authorization_code(
584590
code, redirect_uri=redirect_uri,
585591
scope=decorate_scope(scopes, self.client_id),
586592
headers={
@@ -593,7 +599,7 @@ def acquire_token_by_authorization_code(
593599
claims=_merge_claims_challenge_and_capabilities(
594600
self._client_capabilities, claims_challenge)),
595601
nonce=nonce,
596-
**kwargs)
602+
**kwargs))
597603

598604
def get_accounts(self, username=None):
599605
"""Get a list of accounts which previously signed in, i.e. exists in cache.
@@ -855,13 +861,13 @@ def _acquire_token_silent_from_cache_and_possibly_refresh_it(
855861
result = self._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family(
856862
authority, decorate_scope(scopes, self.client_id), account,
857863
force_refresh=force_refresh, claims_challenge=claims_challenge, **kwargs)
864+
result = _clean_up(result)
858865
if (result and "error" not in result) or (not access_token_from_cache):
859866
return result
860867
except: # The exact HTTP exception is transportation-layer dependent
861868
logger.exception("Refresh token failed") # Potential AAD outage?
862869
return access_token_from_cache
863870

864-
865871
def _acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family(
866872
self, authority, scopes, account, **kwargs):
867873
query = {
@@ -993,7 +999,7 @@ def acquire_token_by_refresh_token(self, refresh_token, scopes, **kwargs):
993999
* A dict contains no "error" key means migration was successful.
9941000
"""
9951001
self._validate_ssh_cert_input_data(kwargs.get("data", {}))
996-
return self.client.obtain_token_by_refresh_token(
1002+
return _clean_up(self.client.obtain_token_by_refresh_token(
9971003
refresh_token,
9981004
scope=decorate_scope(scopes, self.client_id),
9991005
headers={
@@ -1004,7 +1010,7 @@ def acquire_token_by_refresh_token(self, refresh_token, scopes, **kwargs):
10041010
rt_getter=lambda rt: rt,
10051011
on_updating_rt=False,
10061012
on_removing_rt=lambda rt_item: None, # No OP
1007-
**kwargs)
1013+
**kwargs))
10081014

10091015

10101016
class PublicClientApplication(ClientApplication): # browser app or mobile app
@@ -1081,7 +1087,7 @@ def acquire_token_interactive(
10811087
self._validate_ssh_cert_input_data(kwargs.get("data", {}))
10821088
claims = _merge_claims_challenge_and_capabilities(
10831089
self._client_capabilities, claims_challenge)
1084-
return self.client.obtain_token_by_browser(
1090+
return _clean_up(self.client.obtain_token_by_browser(
10851091
scope=decorate_scope(scopes, self.client_id) if scopes else None,
10861092
extra_scope_to_consent=extra_scopes_to_consent,
10871093
redirect_uri="http://localhost:{port}".format(
@@ -1100,7 +1106,7 @@ def acquire_token_interactive(
11001106
CLIENT_CURRENT_TELEMETRY: _build_current_telemetry_request_header(
11011107
self.ACQUIRE_TOKEN_INTERACTIVE),
11021108
},
1103-
**kwargs)
1109+
**kwargs))
11041110

11051111
def initiate_device_flow(self, scopes=None, **kwargs):
11061112
"""Initiate a Device Flow instance,
@@ -1143,7 +1149,7 @@ def acquire_token_by_device_flow(self, flow, claims_challenge=None, **kwargs):
11431149
- A successful response would contain "access_token" key,
11441150
- an error response would contain "error" and usually "error_description".
11451151
"""
1146-
return self.client.obtain_token_by_device_flow(
1152+
return _clean_up(self.client.obtain_token_by_device_flow(
11471153
flow,
11481154
data=dict(
11491155
kwargs.pop("data", {}),
@@ -1159,7 +1165,7 @@ def acquire_token_by_device_flow(self, flow, claims_challenge=None, **kwargs):
11591165
CLIENT_CURRENT_TELEMETRY: _build_current_telemetry_request_header(
11601166
self.ACQUIRE_TOKEN_BY_DEVICE_FLOW_ID),
11611167
},
1162-
**kwargs)
1168+
**kwargs))
11631169

11641170
def acquire_token_by_username_password(
11651171
self, username, password, scopes, claims_challenge=None, **kwargs):
@@ -1197,15 +1203,15 @@ def acquire_token_by_username_password(
11971203
user_realm_result = self.authority.user_realm_discovery(
11981204
username, correlation_id=headers[CLIENT_REQUEST_ID])
11991205
if user_realm_result.get("account_type") == "Federated":
1200-
return self._acquire_token_by_username_password_federated(
1206+
return _clean_up(self._acquire_token_by_username_password_federated(
12011207
user_realm_result, username, password, scopes=scopes,
12021208
data=data,
1203-
headers=headers, **kwargs)
1204-
return self.client.obtain_token_by_username_password(
1209+
headers=headers, **kwargs))
1210+
return _clean_up(self.client.obtain_token_by_username_password(
12051211
username, password, scope=scopes,
12061212
headers=headers,
12071213
data=data,
1208-
**kwargs)
1214+
**kwargs))
12091215

12101216
def _acquire_token_by_username_password_federated(
12111217
self, user_realm_result, username, password, scopes=None, **kwargs):
@@ -1265,7 +1271,7 @@ def acquire_token_for_client(self, scopes, claims_challenge=None, **kwargs):
12651271
"""
12661272
# TBD: force_refresh behavior
12671273
self._validate_ssh_cert_input_data(kwargs.get("data", {}))
1268-
return self.client.obtain_token_for_client(
1274+
return _clean_up(self.client.obtain_token_for_client(
12691275
scope=scopes, # This grant flow requires no scope decoration
12701276
headers={
12711277
CLIENT_REQUEST_ID: _get_new_correlation_id(),
@@ -1276,7 +1282,7 @@ def acquire_token_for_client(self, scopes, claims_challenge=None, **kwargs):
12761282
kwargs.pop("data", {}),
12771283
claims=_merge_claims_challenge_and_capabilities(
12781284
self._client_capabilities, claims_challenge)),
1279-
**kwargs)
1285+
**kwargs))
12801286

12811287
def acquire_token_on_behalf_of(self, user_assertion, scopes, claims_challenge=None, **kwargs):
12821288
"""Acquires token using on-behalf-of (OBO) flow.
@@ -1306,7 +1312,7 @@ def acquire_token_on_behalf_of(self, user_assertion, scopes, claims_challenge=No
13061312
"""
13071313
# The implementation is NOT based on Token Exchange
13081314
# https://tools.ietf.org/html/draft-ietf-oauth-token-exchange-16
1309-
return self.client.obtain_token_by_assertion( # bases on assertion RFC 7521
1315+
return _clean_up(self.client.obtain_token_by_assertion( # bases on assertion RFC 7521
13101316
user_assertion,
13111317
self.client.GRANT_TYPE_JWT, # IDTs and AAD ATs are all JWTs
13121318
scope=decorate_scope(scopes, self.client_id), # Decoration is used for:
@@ -1325,4 +1331,4 @@ def acquire_token_on_behalf_of(self, user_assertion, scopes, claims_challenge=No
13251331
CLIENT_CURRENT_TELEMETRY: _build_current_telemetry_request_header(
13261332
self.ACQUIRE_TOKEN_ON_BEHALF_OF_ID),
13271333
},
1328-
**kwargs)
1334+
**kwargs))

tests/test_application.py

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -353,19 +353,23 @@ def test_fresh_token_should_be_returned_from_cache(self):
353353
# a.k.a. Return unexpired token that is not above token refresh expiration threshold
354354
access_token = "An access token prepopulated into cache"
355355
self.populate_cache(access_token=access_token, expires_in=900, refresh_in=450)
356-
self.assertEqual(
357-
access_token,
358-
self.app.acquire_token_silent(['s1'], self.account).get("access_token"))
356+
result = self.app.acquire_token_silent(['s1'], self.account)
357+
self.assertEqual(access_token, result.get("access_token"))
358+
self.assertNotIn("refresh_in", result, "Customers need not know refresh_in")
359359

360360
def test_aging_token_and_available_aad_should_return_new_token(self):
361361
# a.k.a. Attempt to refresh unexpired token when AAD available
362362
self.populate_cache(access_token="old AT", expires_in=3599, refresh_in=-1)
363363
new_access_token = "new AT"
364-
self.app._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family = (
365-
lambda *args, **kwargs: {"access_token": new_access_token})
366-
self.assertEqual(
367-
new_access_token,
368-
self.app.acquire_token_silent(['s1'], self.account).get("access_token"))
364+
def mock_post(*args, **kwargs):
365+
return MinimalResponse(status_code=200, text=json.dumps({
366+
"access_token": new_access_token,
367+
"refresh_in": 123,
368+
}))
369+
self.app.http_client.post = mock_post
370+
result = self.app.acquire_token_silent(['s1'], self.account)
371+
self.assertEqual(new_access_token, result.get("access_token"))
372+
self.assertNotIn("refresh_in", result, "Customers need not know refresh_in")
369373

370374
def test_aging_token_and_unavailable_aad_should_return_old_token(self):
371375
# a.k.a. Attempt refresh unexpired token when AAD unavailable
@@ -392,9 +396,13 @@ def test_expired_token_and_available_aad_should_return_new_token(self):
392396
# a.k.a. Attempt refresh expired token when AAD available
393397
self.populate_cache(access_token="expired at", expires_in=-1, refresh_in=-900)
394398
new_access_token = "new AT"
395-
self.app._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family = (
396-
lambda *args, **kwargs: {"access_token": new_access_token})
397-
self.assertEqual(
398-
new_access_token,
399-
self.app.acquire_token_silent(['s1'], self.account).get("access_token"))
399+
def mock_post(*args, **kwargs):
400+
return MinimalResponse(status_code=200, text=json.dumps({
401+
"access_token": new_access_token,
402+
"refresh_in": 123,
403+
}))
404+
self.app.http_client.post = mock_post
405+
result = self.app.acquire_token_silent(['s1'], self.account)
406+
self.assertEqual(new_access_token, result.get("access_token"))
407+
self.assertNotIn("refresh_in", result, "Customers need not know refresh_in")
400408

0 commit comments

Comments
 (0)