Skip to content

Expose refresh_on (if any) to fresh or cached response #723

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions msal/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,11 +104,14 @@ def _clean_up(result):
"msalruntime_telemetry": result.get("_msalruntime_telemetry"),
"msal_python_telemetry": result.get("_msal_python_telemetry"),
}, separators=(",", ":"))
return {
return_value = {
k: result[k] for k in result
if k != "refresh_in" # MSAL handled refresh_in, customers need not
and not k.startswith('_') # Skim internal properties
}
if "refresh_in" in result: # To encourage proactive refresh
return_value["refresh_on"] = int(time.time() + result["refresh_in"])
return return_value
return result # It could be None


Expand Down Expand Up @@ -1507,9 +1510,11 @@ def _acquire_token_silent_from_cache_and_possibly_refresh_it(
"expires_in": int(expires_in), # OAuth2 specs defines it as int
self._TOKEN_SOURCE: self._TOKEN_SOURCE_CACHE,
}
if "refresh_on" in entry and int(entry["refresh_on"]) < now: # aging
refresh_reason = msal.telemetry.AT_AGING
break # With a fallback in hand, we break here to go refresh
if "refresh_on" in entry:
access_token_from_cache["refresh_on"] = int(entry["refresh_on"])
if int(entry["refresh_on"]) < now: # aging
refresh_reason = msal.telemetry.AT_AGING
break # With a fallback in hand, we break here to go refresh
self._build_telemetry_context(-1).hit_an_access_token()
return access_token_from_cache # It is still good as new
else:
Expand Down
8 changes: 6 additions & 2 deletions msal/managed_identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,8 +273,10 @@ def acquire_token_for_client(self, *, resource): # We may support scope in the
"token_type": entry.get("token_type", "Bearer"),
"expires_in": int(expires_in), # OAuth2 specs defines it as int
}
if "refresh_on" in entry and int(entry["refresh_on"]) < now: # aging
break # With a fallback in hand, we break here to go refresh
if "refresh_on" in entry:
access_token_from_cache["refresh_on"] = int(entry["refresh_on"])
if int(entry["refresh_on"]) < now: # aging
break # With a fallback in hand, we break here to go refresh
return access_token_from_cache # It is still good as new
try:
result = _obtain_token(self._http_client, self._managed_identity, resource)
Expand All @@ -290,6 +292,8 @@ def acquire_token_for_client(self, *, resource): # We may support scope in the
params={},
data={},
))
if "refresh_in" in result:
result["refresh_on"] = int(now + result["refresh_in"])
if (result and "error" not in result) or (not access_token_from_cache):
return result
except: # The exact HTTP exception is transportation-layer dependent
Expand Down
25 changes: 21 additions & 4 deletions tests/test_application.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Note: Since Aug 2019 we move all e2e tests into test_e2e.py,
# so this test_application file contains only unit tests without dependency.
import sys
import time
from msal.application import *
from msal.application import _str2bytes
import msal
Expand Down Expand Up @@ -353,10 +354,18 @@ def populate_cache(self, access_token="at", expires_in=86400, refresh_in=43200):
uid=self.uid, utid=self.utid, refresh_token=self.rt),
})

def assertRefreshOn(self, result, refresh_in):
refresh_on = int(time.time() + refresh_in)
self.assertTrue(
refresh_on - 1 < result.get("refresh_on", 0) < refresh_on + 1,
"refresh_on should be set properly")

def test_fresh_token_should_be_returned_from_cache(self):
# a.k.a. Return unexpired token that is not above token refresh expiration threshold
refresh_in = 450
access_token = "An access token prepopulated into cache"
self.populate_cache(access_token=access_token, expires_in=900, refresh_in=450)
self.populate_cache(
access_token=access_token, expires_in=900, refresh_in=refresh_in)
result = self.app.acquire_token_silent(
['s1'], self.account,
post=lambda url, *args, **kwargs: # Utilize the undocumented test feature
Expand All @@ -365,32 +374,38 @@ def test_fresh_token_should_be_returned_from_cache(self):
self.assertEqual(result[self.app._TOKEN_SOURCE], self.app._TOKEN_SOURCE_CACHE)
self.assertEqual(access_token, result.get("access_token"))
self.assertNotIn("refresh_in", result, "Customers need not know refresh_in")
self.assertRefreshOn(result, refresh_in)

def test_aging_token_and_available_aad_should_return_new_token(self):
# a.k.a. Attempt to refresh unexpired token when AAD available
self.populate_cache(access_token="old AT", expires_in=3599, refresh_in=-1)
new_access_token = "new AT"
new_refresh_in = 123
def mock_post(url, headers=None, *args, **kwargs):
self.assertEqual("4|84,4|", (headers or {}).get(CLIENT_CURRENT_TELEMETRY))
return MinimalResponse(status_code=200, text=json.dumps({
"access_token": new_access_token,
"refresh_in": 123,
"refresh_in": new_refresh_in,
}))
result = self.app.acquire_token_silent(['s1'], self.account, post=mock_post)
self.assertEqual(result[self.app._TOKEN_SOURCE], self.app._TOKEN_SOURCE_IDP)
self.assertEqual(new_access_token, result.get("access_token"))
self.assertNotIn("refresh_in", result, "Customers need not know refresh_in")
self.assertRefreshOn(result, new_refresh_in)

def test_aging_token_and_unavailable_aad_should_return_old_token(self):
# a.k.a. Attempt refresh unexpired token when AAD unavailable
refresh_in = -1
old_at = "old AT"
self.populate_cache(access_token=old_at, expires_in=3599, refresh_in=-1)
self.populate_cache(
access_token=old_at, expires_in=3599, refresh_in=refresh_in)
def mock_post(url, headers=None, *args, **kwargs):
self.assertEqual("4|84,4|", (headers or {}).get(CLIENT_CURRENT_TELEMETRY))
return MinimalResponse(status_code=400, text=json.dumps({"error": "foo"}))
result = self.app.acquire_token_silent(['s1'], self.account, post=mock_post)
self.assertEqual(result[self.app._TOKEN_SOURCE], self.app._TOKEN_SOURCE_CACHE)
self.assertEqual(old_at, result.get("access_token"))
self.assertRefreshOn(result, refresh_in)

def test_expired_token_and_unavailable_aad_should_return_error(self):
# a.k.a. Attempt refresh expired token when AAD unavailable
Expand All @@ -407,16 +422,18 @@ def test_expired_token_and_available_aad_should_return_new_token(self):
# a.k.a. Attempt refresh expired token when AAD available
self.populate_cache(access_token="expired at", expires_in=-1, refresh_in=-900)
new_access_token = "new AT"
new_refresh_in = 123
def mock_post(url, headers=None, *args, **kwargs):
self.assertEqual("4|84,3|", (headers or {}).get(CLIENT_CURRENT_TELEMETRY))
return MinimalResponse(status_code=200, text=json.dumps({
"access_token": new_access_token,
"refresh_in": 123,
"refresh_in": new_refresh_in,
}))
result = self.app.acquire_token_silent(['s1'], self.account, post=mock_post)
self.assertEqual(result[self.app._TOKEN_SOURCE], self.app._TOKEN_SOURCE_IDP)
self.assertEqual(new_access_token, result.get("access_token"))
self.assertNotIn("refresh_in", result, "Customers need not know refresh_in")
self.assertRefreshOn(result, new_refresh_in)


class TestTelemetryMaintainingOfflineState(unittest.TestCase):
Expand Down
69 changes: 50 additions & 19 deletions tests/test_mi.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
SERVICE_FABRIC,
DEFAULT_TO_VM,
)
from msal.token_cache import is_subdict_of


class ManagedIdentityTestCase(unittest.TestCase):
Expand Down Expand Up @@ -60,7 +61,7 @@ def setUp(self):
http_client=requests.Session(),
)

def _test_token_cache(self, app):
def assertCacheStatus(self, app):
cache = app._token_cache._cache
self.assertEqual(1, len(cache.get("AccessToken", [])), "Should have 1 AT")
at = list(cache["AccessToken"].values())[0]
Expand All @@ -70,30 +71,55 @@ def _test_token_cache(self, app):
"Should have expected client_id")
self.assertEqual("managed_identity", at["realm"], "Should have expected realm")

def _test_happy_path(self, app, mocked_http):
result = app.acquire_token_for_client(resource="R")
def _test_happy_path(self, app, mocked_http, expires_in, resource="R"):
result = app.acquire_token_for_client(resource=resource)
mocked_http.assert_called()
self.assertEqual({
call_count = mocked_http.call_count
expected_result = {
"access_token": "AT",
"expires_in": 1234,
"resource": "R",
"token_type": "Bearer",
}, result, "Should obtain a token response")
}
self.assertTrue(
is_subdict_of(expected_result, result), # We will test refresh_on later
"Should obtain a token response")
self.assertEqual(expires_in, result["expires_in"], "Should have expected expires_in")
if expires_in >= 7200:
expected_refresh_on = int(time.time() + expires_in / 2)
self.assertTrue(
expected_refresh_on - 1 <= result["refresh_on"] <= expected_refresh_on + 1,
"Should have a refresh_on time around the middle of the token's life")
self.assertEqual(
result["access_token"],
app.acquire_token_for_client(resource="R").get("access_token"),
app.acquire_token_for_client(resource=resource).get("access_token"),
"Should hit the same token from cache")
self._test_token_cache(app)

self.assertCacheStatus(app)

result = app.acquire_token_for_client(resource=resource)
self.assertEqual(
call_count, mocked_http.call_count,
"No new call to the mocked http should be made for a cache hit")
self.assertTrue(
is_subdict_of(expected_result, result), # We will test refresh_on later
"Should obtain a token response")
self.assertTrue(
expires_in - 5 < result["expires_in"] <= expires_in,
"Should have similar expires_in")
if expires_in >= 7200:
self.assertTrue(
expected_refresh_on - 5 < result["refresh_on"] <= expected_refresh_on,
"Should have a refresh_on time around the middle of the token's life")


class VmTestCase(ClientTestCase):

def test_happy_path(self):
expires_in = 7890 # We test a bigger than 7200 value here
with patch.object(self.app._http_client, "get", return_value=MinimalResponse(
status_code=200,
text='{"access_token": "AT", "expires_in": "1234", "resource": "R"}',
text='{"access_token": "AT", "expires_in": "%s", "resource": "R"}' % expires_in,
)) as mocked_method:
self._test_happy_path(self.app, mocked_method)
self._test_happy_path(self.app, mocked_method, expires_in)

def test_vm_error_should_be_returned_as_is(self):
raw_error = '{"raw": "error format is undefined"}'
Expand All @@ -110,12 +136,13 @@ def test_vm_error_should_be_returned_as_is(self):
class AppServiceTestCase(ClientTestCase):

def test_happy_path(self):
expires_in = 1234
with patch.object(self.app._http_client, "get", return_value=MinimalResponse(
status_code=200,
text='{"access_token": "AT", "expires_on": "%s", "resource": "R"}' % (
int(time.time()) + 1234),
int(time.time()) + expires_in),
)) as mocked_method:
self._test_happy_path(self.app, mocked_method)
self._test_happy_path(self.app, mocked_method, expires_in)

def test_app_service_error_should_be_normalized(self):
raw_error = '{"statusCode": 500, "message": "error content is undefined"}'
Expand All @@ -134,12 +161,13 @@ def test_app_service_error_should_be_normalized(self):
class MachineLearningTestCase(ClientTestCase):

def test_happy_path(self):
expires_in = 1234
with patch.object(self.app._http_client, "get", return_value=MinimalResponse(
status_code=200,
text='{"access_token": "AT", "expires_on": "%s", "resource": "R"}' % (
int(time.time()) + 1234),
int(time.time()) + expires_in),
)) as mocked_method:
self._test_happy_path(self.app, mocked_method)
self._test_happy_path(self.app, mocked_method, expires_in)

def test_machine_learning_error_should_be_normalized(self):
raw_error = '{"error": "placeholder", "message": "placeholder"}'
Expand All @@ -162,12 +190,14 @@ def test_machine_learning_error_should_be_normalized(self):
class ServiceFabricTestCase(ClientTestCase):

def _test_happy_path(self, app):
expires_in = 1234
with patch.object(app._http_client, "get", return_value=MinimalResponse(
status_code=200,
text='{"access_token": "AT", "expires_on": %s, "resource": "R", "token_type": "Bearer"}' % (
int(time.time()) + 1234),
int(time.time()) + expires_in),
)) as mocked_method:
super(ServiceFabricTestCase, self)._test_happy_path(app, mocked_method)
super(ServiceFabricTestCase, self)._test_happy_path(
app, mocked_method, expires_in)

def test_happy_path(self):
self._test_happy_path(self.app)
Expand Down Expand Up @@ -212,15 +242,16 @@ class ArcTestCase(ClientTestCase):
})

def test_happy_path(self, mocked_stat):
expires_in = 1234
with patch.object(self.app._http_client, "get", side_effect=[
self.challenge,
MinimalResponse(
status_code=200,
text='{"access_token": "AT", "expires_in": "1234", "resource": "R"}',
text='{"access_token": "AT", "expires_in": "%s", "resource": "R"}' % expires_in,
),
]) as mocked_method:
try:
super(ArcTestCase, self)._test_happy_path(self.app, mocked_method)
self._test_happy_path(self.app, mocked_method, expires_in)
mocked_stat.assert_called_with(os.path.join(
_supported_arc_platforms_and_their_prefixes[sys.platform],
"foo.key"))
Expand Down