Skip to content

Commit 60144d5

Browse files
committed
TokenCache.search() also wipes stale access tokens
1 parent c9dac6c commit 60144d5

File tree

3 files changed

+32
-11
lines changed

3 files changed

+32
-11
lines changed

msal/token_cache.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ def _is_matching(entry: dict, query: dict, target_set: set = None) -> bool:
130130
target_set <= set(entry.get("target", "").split())
131131
if target_set else True)
132132

133-
def search(self, credential_type, target=None, query=None): # O(n) generator
133+
def search(self, credential_type, target=None, query=None, *, now=None): # O(n) generator
134134
"""Returns a generator of matching entries.
135135
136136
It is O(1) for AT hits, and O(n) for other types.
@@ -157,18 +157,32 @@ def search(self, credential_type, target=None, query=None): # O(n) generator
157157
target_set = set(target)
158158
with self._lock:
159159
# O(n) search. The key is NOT used in search.
160+
now = int(time.time() if now is None else now)
161+
expired_access_tokens = [
162+
# Especially when/if we key ATs by ephemeral fields such as key_id,
163+
# stale ATs keyed by an old key_id would stay forever.
164+
# Here we collect them for their removal.
165+
]
160166
for entry in self._cache.get(credential_type, {}).values():
167+
if ( # Automatically delete expired access tokens
168+
credential_type == self.CredentialType.ACCESS_TOKEN
169+
and int(entry["expires_on"]) < now
170+
):
171+
expired_access_tokens.append(entry) # Can't delete them within current for-loop
172+
continue
161173
if (entry != preferred_result # Avoid yielding the same entry twice
162174
and self._is_matching(entry, query, target_set=target_set)
163175
):
164176
yield entry
177+
for at in expired_access_tokens:
178+
self.remove_at(at)
165179

166-
def find(self, credential_type, target=None, query=None):
180+
def find(self, credential_type, target=None, query=None, *, now=None):
167181
"""Equivalent to list(search(...))."""
168182
warnings.warn(
169183
"Use list(search(...)) instead to explicitly get a list.",
170184
DeprecationWarning)
171-
return list(self.search(credential_type, target=target, query=query))
185+
return list(self.search(credential_type, target=target, query=query, now=now))
172186

173187
def add(self, event, now=None):
174188
"""Handle a token obtaining event, and add tokens into cache."""

tests/test_application.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -340,6 +340,7 @@ class TestApplicationForRefreshInBehaviors(unittest.TestCase):
340340
account = {"home_account_id": "{}.{}".format(uid, utid)}
341341
rt = "this is a rt"
342342
client_id = "my_app"
343+
soon = 60 # application.py considers tokens within 5 minutes as expired
343344

344345
@classmethod
345346
def setUpClass(cls): # Initialization at runtime, not interpret-time
@@ -414,7 +415,8 @@ def mock_post(url, headers=None, *args, **kwargs):
414415

415416
def test_expired_token_and_unavailable_aad_should_return_error(self):
416417
# a.k.a. Attempt refresh expired token when AAD unavailable
417-
self.populate_cache(access_token="expired at", expires_in=-1, refresh_in=-900)
418+
self.populate_cache(
419+
access_token="expired at", expires_in=self.soon, refresh_in=-900)
418420
error = "something went wrong"
419421
def mock_post(url, headers=None, *args, **kwargs):
420422
self.assertEqual("4|84,3|", (headers or {}).get(CLIENT_CURRENT_TELEMETRY))
@@ -425,7 +427,8 @@ def mock_post(url, headers=None, *args, **kwargs):
425427

426428
def test_expired_token_and_available_aad_should_return_new_token(self):
427429
# a.k.a. Attempt refresh expired token when AAD available
428-
self.populate_cache(access_token="expired at", expires_in=-1, refresh_in=-900)
430+
self.populate_cache(
431+
access_token="expired at", expires_in=self.soon, refresh_in=-900)
429432
new_access_token = "new AT"
430433
new_refresh_in = 123
431434
def mock_post(url, headers=None, *args, **kwargs):

tests/test_token_cache.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ def testAddByAad(self):
5858
client_id = "my_client_id"
5959
id_token = build_id_token(
6060
oid="object1234", preferred_username="John Doe", aud=client_id)
61+
now = 1000
6162
self.cache.add({
6263
"client_id": client_id,
6364
"scope": ["s2", "s1", "s3"], # Not in particular order
@@ -66,7 +67,7 @@ def testAddByAad(self):
6667
uid="uid", utid="utid", # client_info
6768
expires_in=3600, access_token="an access token",
6869
id_token=id_token, refresh_token="a refresh token"),
69-
}, now=1000)
70+
}, now=now)
7071
access_token_entry = {
7172
'cached_at': "1000",
7273
'client_id': 'my_client_id',
@@ -84,7 +85,7 @@ def testAddByAad(self):
8485
self.at_key_maker(**access_token_entry)))
8586
self.assertIn(
8687
access_token_entry,
87-
self.cache.find(self.cache.CredentialType.ACCESS_TOKEN),
88+
self.cache.find(self.cache.CredentialType.ACCESS_TOKEN, now=now),
8889
"find(..., query=None) should not crash, even though MSAL does not use it")
8990
self.assertEqual(
9091
{
@@ -203,17 +204,20 @@ def testAddByAdfs(self):
203204
"appmetadata-fs.msidlab8.com-my_client_id")
204205
)
205206

206-
def assertFoundAccessToken(self, *, scopes, query, data=None):
207+
def assertFoundAccessToken(self, *, scopes, query, data=None, now=None):
207208
cached_at = None
208209
for cached_at in self.cache.search(
209-
TokenCache.CredentialType.ACCESS_TOKEN, target=scopes, query=query):
210+
TokenCache.CredentialType.ACCESS_TOKEN,
211+
target=scopes, query=query, now=now,
212+
):
210213
for k, v in (data or {}).items(): # The extra data, if any
211214
self.assertEqual(cached_at.get(k), v, f"AT should contain {k}={v}")
212215
self.assertTrue(cached_at, "AT should be cached and searchable")
213216
return cached_at
214217

215218
def _test_data_should_be_saved_and_searchable_in_access_token(self, data):
216219
scopes = ["s2", "s1", "s3"] # Not in particular order
220+
now = 1000
217221
self.cache.add({
218222
"data": data,
219223
"client_id": "my_client_id",
@@ -223,8 +227,8 @@ def _test_data_should_be_saved_and_searchable_in_access_token(self, data):
223227
uid="uid", utid="utid", # client_info
224228
expires_in=3600, access_token="an access token",
225229
refresh_token="a refresh token"),
226-
}, now=1000)
227-
self.assertFoundAccessToken(scopes=scopes, data=data, query=dict(
230+
}, now=now)
231+
self.assertFoundAccessToken(scopes=scopes, data=data, now=now, query=dict(
228232
data, # Also use the extra data as a query criteria
229233
client_id="my_client_id",
230234
environment="login.example.com",

0 commit comments

Comments
 (0)