Skip to content

Commit 76348db

Browse files
committed
Filter out refresh_in from auth responses
1 parent 36365ac commit 76348db

File tree

2 files changed

+46
-32
lines changed

2 files changed

+46
-32
lines changed

msal/application.py

Lines changed: 25 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,12 @@ def _str2bytes(raw):
100100
return raw
101101

102102

103+
def _clean_up(result):
104+
if isinstance(result, dict):
105+
result.pop("refresh_in", None) # MSAL handled refresh_in, customers need not
106+
return result
107+
108+
103109
class ClientApplication(object):
104110

105111
ACQUIRE_TOKEN_SILENT_ID = "84"
@@ -507,7 +513,7 @@ def authorize(): # A controller in a web app
507513
return redirect(url_for("index"))
508514
"""
509515
self._validate_ssh_cert_input_data(kwargs.get("data", {}))
510-
return self.client.obtain_token_by_auth_code_flow(
516+
return _clean_up(self.client.obtain_token_by_auth_code_flow(
511517
auth_code_flow,
512518
auth_response,
513519
scope=decorate_scope(scopes, self.client_id) if scopes else None,
@@ -521,7 +527,7 @@ def authorize(): # A controller in a web app
521527
claims=_merge_claims_challenge_and_capabilities(
522528
self._client_capabilities,
523529
auth_code_flow.pop("claims_challenge", None))),
524-
**kwargs)
530+
**kwargs))
525531

526532
def acquire_token_by_authorization_code(
527533
self,
@@ -580,7 +586,7 @@ def acquire_token_by_authorization_code(
580586
"Change your acquire_token_by_authorization_code() "
581587
"to acquire_token_by_auth_code_flow()", DeprecationWarning)
582588
with warnings.catch_warnings(record=True):
583-
return self.client.obtain_token_by_authorization_code(
589+
return _clean_up(self.client.obtain_token_by_authorization_code(
584590
code, redirect_uri=redirect_uri,
585591
scope=decorate_scope(scopes, self.client_id),
586592
headers={
@@ -593,7 +599,7 @@ def acquire_token_by_authorization_code(
593599
claims=_merge_claims_challenge_and_capabilities(
594600
self._client_capabilities, claims_challenge)),
595601
nonce=nonce,
596-
**kwargs)
602+
**kwargs))
597603

598604
def get_accounts(self, username=None):
599605
"""Get a list of accounts which previously signed in, i.e. exists in cache.
@@ -855,13 +861,13 @@ def _acquire_token_silent_from_cache_and_possibly_refresh_it(
855861
result = self._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family(
856862
authority, decorate_scope(scopes, self.client_id), account,
857863
force_refresh=force_refresh, claims_challenge=claims_challenge, **kwargs)
864+
result = _clean_up(result)
858865
if (result and "error" not in result) or (not access_token_from_cache):
859866
return result
860867
except: # The exact HTTP exception is transportation-layer dependent
861868
logger.exception("Refresh token failed") # Potential AAD outage?
862869
return access_token_from_cache
863870

864-
865871
def _acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family(
866872
self, authority, scopes, account, **kwargs):
867873
query = {
@@ -987,7 +993,7 @@ def acquire_token_by_refresh_token(self, refresh_token, scopes, **kwargs):
987993
* A dict contains no "error" key means migration was successful.
988994
"""
989995
self._validate_ssh_cert_input_data(kwargs.get("data", {}))
990-
return self.client.obtain_token_by_refresh_token(
996+
return _clean_up(self.client.obtain_token_by_refresh_token(
991997
refresh_token,
992998
scope=decorate_scope(scopes, self.client_id),
993999
headers={
@@ -998,7 +1004,7 @@ def acquire_token_by_refresh_token(self, refresh_token, scopes, **kwargs):
9981004
rt_getter=lambda rt: rt,
9991005
on_updating_rt=False,
10001006
on_removing_rt=lambda rt_item: None, # No OP
1001-
**kwargs)
1007+
**kwargs))
10021008

10031009

10041010
class PublicClientApplication(ClientApplication): # browser app or mobile app
@@ -1072,7 +1078,7 @@ def acquire_token_interactive(
10721078
self._validate_ssh_cert_input_data(kwargs.get("data", {}))
10731079
claims = _merge_claims_challenge_and_capabilities(
10741080
self._client_capabilities, claims_challenge)
1075-
return self.client.obtain_token_by_browser(
1081+
return _clean_up(self.client.obtain_token_by_browser(
10761082
scope=decorate_scope(scopes, self.client_id) if scopes else None,
10771083
extra_scope_to_consent=extra_scopes_to_consent,
10781084
redirect_uri="http://localhost:{port}".format(
@@ -1091,7 +1097,7 @@ def acquire_token_interactive(
10911097
CLIENT_CURRENT_TELEMETRY: _build_current_telemetry_request_header(
10921098
self.ACQUIRE_TOKEN_INTERACTIVE),
10931099
},
1094-
**kwargs)
1100+
**kwargs))
10951101

10961102
def initiate_device_flow(self, scopes=None, **kwargs):
10971103
"""Initiate a Device Flow instance,
@@ -1134,7 +1140,7 @@ def acquire_token_by_device_flow(self, flow, claims_challenge=None, **kwargs):
11341140
- A successful response would contain "access_token" key,
11351141
- an error response would contain "error" and usually "error_description".
11361142
"""
1137-
return self.client.obtain_token_by_device_flow(
1143+
return _clean_up(self.client.obtain_token_by_device_flow(
11381144
flow,
11391145
data=dict(
11401146
kwargs.pop("data", {}),
@@ -1150,7 +1156,7 @@ def acquire_token_by_device_flow(self, flow, claims_challenge=None, **kwargs):
11501156
CLIENT_CURRENT_TELEMETRY: _build_current_telemetry_request_header(
11511157
self.ACQUIRE_TOKEN_BY_DEVICE_FLOW_ID),
11521158
},
1153-
**kwargs)
1159+
**kwargs))
11541160

11551161
def acquire_token_by_username_password(
11561162
self, username, password, scopes, claims_challenge=None, **kwargs):
@@ -1188,15 +1194,15 @@ def acquire_token_by_username_password(
11881194
user_realm_result = self.authority.user_realm_discovery(
11891195
username, correlation_id=headers[CLIENT_REQUEST_ID])
11901196
if user_realm_result.get("account_type") == "Federated":
1191-
return self._acquire_token_by_username_password_federated(
1197+
return _clean_up(self._acquire_token_by_username_password_federated(
11921198
user_realm_result, username, password, scopes=scopes,
11931199
data=data,
1194-
headers=headers, **kwargs)
1195-
return self.client.obtain_token_by_username_password(
1200+
headers=headers, **kwargs))
1201+
return _clean_up(self.client.obtain_token_by_username_password(
11961202
username, password, scope=scopes,
11971203
headers=headers,
11981204
data=data,
1199-
**kwargs)
1205+
**kwargs))
12001206

12011207
def _acquire_token_by_username_password_federated(
12021208
self, user_realm_result, username, password, scopes=None, **kwargs):
@@ -1256,7 +1262,7 @@ def acquire_token_for_client(self, scopes, claims_challenge=None, **kwargs):
12561262
"""
12571263
# TBD: force_refresh behavior
12581264
self._validate_ssh_cert_input_data(kwargs.get("data", {}))
1259-
return self.client.obtain_token_for_client(
1265+
return _clean_up(self.client.obtain_token_for_client(
12601266
scope=scopes, # This grant flow requires no scope decoration
12611267
headers={
12621268
CLIENT_REQUEST_ID: _get_new_correlation_id(),
@@ -1267,7 +1273,7 @@ def acquire_token_for_client(self, scopes, claims_challenge=None, **kwargs):
12671273
kwargs.pop("data", {}),
12681274
claims=_merge_claims_challenge_and_capabilities(
12691275
self._client_capabilities, claims_challenge)),
1270-
**kwargs)
1276+
**kwargs))
12711277

12721278
def acquire_token_on_behalf_of(self, user_assertion, scopes, claims_challenge=None, **kwargs):
12731279
"""Acquires token using on-behalf-of (OBO) flow.
@@ -1297,7 +1303,7 @@ def acquire_token_on_behalf_of(self, user_assertion, scopes, claims_challenge=No
12971303
"""
12981304
# The implementation is NOT based on Token Exchange
12991305
# https://tools.ietf.org/html/draft-ietf-oauth-token-exchange-16
1300-
return self.client.obtain_token_by_assertion( # bases on assertion RFC 7521
1306+
return _clean_up(self.client.obtain_token_by_assertion( # bases on assertion RFC 7521
13011307
user_assertion,
13021308
self.client.GRANT_TYPE_JWT, # IDTs and AAD ATs are all JWTs
13031309
scope=decorate_scope(scopes, self.client_id), # Decoration is used for:
@@ -1316,4 +1322,4 @@ def acquire_token_on_behalf_of(self, user_assertion, scopes, claims_challenge=No
13161322
CLIENT_CURRENT_TELEMETRY: _build_current_telemetry_request_header(
13171323
self.ACQUIRE_TOKEN_ON_BEHALF_OF_ID),
13181324
},
1319-
**kwargs)
1325+
**kwargs))

tests/test_application.py

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -354,19 +354,23 @@ def test_fresh_token_should_be_returned_from_cache(self):
354354
# a.k.a. Return unexpired token that is not above token refresh expiration threshold
355355
access_token = "An access token prepopulated into cache"
356356
self.populate_cache(access_token=access_token, expires_in=900, refresh_in=450)
357-
self.assertEqual(
358-
access_token,
359-
self.app.acquire_token_silent(['s1'], self.account).get("access_token"))
357+
result = self.app.acquire_token_silent(['s1'], self.account)
358+
self.assertEqual(access_token, result.get("access_token"))
359+
self.assertNotIn("refresh_in", result, "Customers need not know refresh_in")
360360

361361
def test_aging_token_and_available_aad_should_return_new_token(self):
362362
# a.k.a. Attempt to refresh unexpired token when AAD available
363363
self.populate_cache(access_token="old AT", expires_in=3599, refresh_in=-1)
364364
new_access_token = "new AT"
365-
self.app._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family = (
366-
lambda *args, **kwargs: {"access_token": new_access_token})
367-
self.assertEqual(
368-
new_access_token,
369-
self.app.acquire_token_silent(['s1'], self.account).get("access_token"))
365+
def mock_post(*args, **kwargs):
366+
return MinimalResponse(status_code=200, text=json.dumps({
367+
"access_token": new_access_token,
368+
"refresh_in": 123,
369+
}))
370+
self.app.http_client.post = mock_post
371+
result = self.app.acquire_token_silent(['s1'], self.account)
372+
self.assertEqual(new_access_token, result.get("access_token"))
373+
self.assertNotIn("refresh_in", result, "Customers need not know refresh_in")
370374

371375
def test_aging_token_and_unavailable_aad_should_return_old_token(self):
372376
# a.k.a. Attempt refresh unexpired token when AAD unavailable
@@ -393,9 +397,13 @@ def test_expired_token_and_available_aad_should_return_new_token(self):
393397
# a.k.a. Attempt refresh expired token when AAD available
394398
self.populate_cache(access_token="expired at", expires_in=-1, refresh_in=-900)
395399
new_access_token = "new AT"
396-
self.app._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family = (
397-
lambda *args, **kwargs: {"access_token": new_access_token})
398-
self.assertEqual(
399-
new_access_token,
400-
self.app.acquire_token_silent(['s1'], self.account).get("access_token"))
400+
def mock_post(*args, **kwargs):
401+
return MinimalResponse(status_code=200, text=json.dumps({
402+
"access_token": new_access_token,
403+
"refresh_in": 123,
404+
}))
405+
self.app.http_client.post = mock_post
406+
result = self.app.acquire_token_silent(['s1'], self.account)
407+
self.assertEqual(new_access_token, result.get("access_token"))
408+
self.assertNotIn("refresh_in", result, "Customers need not know refresh_in")
401409

0 commit comments

Comments
 (0)