Skip to content

Commit c9dac6c

Browse files
committed
Change AccessToken key_maker algorithm
1 parent ede849e commit c9dac6c

File tree

2 files changed

+34
-38
lines changed

2 files changed

+34
-38
lines changed

msal/token_cache.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ def __init__(self):
4343
self._lock = threading.RLock()
4444
self._cache = {}
4545
self.key_makers = {
46+
# Note: We have changed the token key format before, when ordering scopes;
47+
# changing the token key format won't result in a cache miss.
4648
self.CredentialType.REFRESH_TOKEN:
4749
lambda home_account_id=None, environment=None, client_id=None,
4850
target=None, **ignored_payload_from_a_real_token:
@@ -56,14 +58,18 @@ def __init__(self):
5658
]).lower(),
5759
self.CredentialType.ACCESS_TOKEN:
5860
lambda home_account_id=None, environment=None, client_id=None,
59-
realm=None, target=None, **ignored_payload_from_a_real_token:
60-
"-".join([
61+
realm=None, target=None,
62+
# Note: New field(s) can be added here
63+
#key_id=None,
64+
**ignored_payload_from_a_real_token:
65+
"-".join([ # Note: Could use a hash here to shorten key length
6166
home_account_id or "",
6267
environment or "",
6368
self.CredentialType.ACCESS_TOKEN,
6469
client_id or "",
6570
realm or "",
6671
target or "",
72+
#key_id or "", # So ATs of different key_id can coexist
6773
]).lower(),
6874
self.CredentialType.ID_TOKEN:
6975
lambda home_account_id=None, environment=None, client_id=None,
@@ -150,9 +156,7 @@ def search(self, credential_type, target=None, query=None): # O(n) generator
150156

151157
target_set = set(target)
152158
with self._lock:
153-
# Since the target inside token cache key is (per schema) unsorted,
154-
# there is no point to attempt an O(1) key-value search here.
155-
# So we always do an O(n) in-memory search.
159+
# O(n) search. The key is NOT used in search.
156160
for entry in self._cache.get(credential_type, {}).values():
157161
if (entry != preferred_result # Avoid yielding the same entry twice
158162
and self._is_matching(entry, query, target_set=target_set)

tests/test_token_cache.py

Lines changed: 25 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import json
44
import time
55

6-
from msal.token_cache import *
6+
from msal.token_cache import TokenCache, SerializableTokenCache
77
from tests import unittest
88

99

@@ -51,6 +51,8 @@ class TokenCacheTestCase(unittest.TestCase):
5151

5252
def setUp(self):
5353
self.cache = TokenCache()
54+
self.at_key_maker = self.cache.key_makers[
55+
TokenCache.CredentialType.ACCESS_TOKEN]
5456

5557
def testAddByAad(self):
5658
client_id = "my_client_id"
@@ -78,11 +80,8 @@ def testAddByAad(self):
7880
'target': 's1 s2 s3', # Sorted
7981
'token_type': 'some type',
8082
}
81-
self.assertEqual(
82-
access_token_entry,
83-
self.cache._cache["AccessToken"].get(
84-
'uid.utid-login.example.com-accesstoken-my_client_id-contoso-s1 s2 s3')
85-
)
83+
self.assertEqual(access_token_entry, self.cache._cache["AccessToken"].get(
84+
self.at_key_maker(**access_token_entry)))
8685
self.assertIn(
8786
access_token_entry,
8887
self.cache.find(self.cache.CredentialType.ACCESS_TOKEN),
@@ -144,8 +143,7 @@ def testAddByAdfs(self):
144143
expires_in=3600, access_token="an access token",
145144
id_token=id_token, refresh_token="a refresh token"),
146145
}, now=1000)
147-
self.assertEqual(
148-
{
146+
access_token_entry = {
149147
'cached_at': "1000",
150148
'client_id': 'my_client_id',
151149
'credential_type': 'AccessToken',
@@ -157,10 +155,9 @@ def testAddByAdfs(self):
157155
'secret': 'an access token',
158156
'target': 's1 s2 s3', # Sorted
159157
'token_type': 'some type',
160-
},
161-
self.cache._cache["AccessToken"].get(
162-
'subject-fs.msidlab8.com-accesstoken-my_client_id-adfs-s1 s2 s3')
163-
)
158+
}
159+
self.assertEqual(access_token_entry, self.cache._cache["AccessToken"].get(
160+
self.at_key_maker(**access_token_entry)))
164161
self.assertEqual(
165162
{
166163
'client_id': 'my_client_id',
@@ -238,37 +235,32 @@ def _test_data_should_be_saved_and_searchable_in_access_token(self, data):
238235
def test_extra_data_should_also_be_recorded_and_searchable_in_access_token(self):
239236
self._test_data_should_be_saved_and_searchable_in_access_token({"key_id": "1"})
240237

241-
def test_key_id_is_also_recorded(self):
242-
my_key_id = "some_key_id_123"
243-
self.cache.add({
244-
"data": {"key_id": my_key_id},
245-
"client_id": "my_client_id",
246-
"scope": ["s2", "s1", "s3"], # Not in particular order
247-
"token_endpoint": "https://login.example.com/contoso/v2/token",
248-
"response": build_response(
249-
uid="uid", utid="utid", # client_info
250-
expires_in=3600, access_token="an access token",
251-
refresh_token="a refresh token"),
252-
}, now=1000)
253-
cached_key_id = self.cache._cache["AccessToken"].get(
254-
'uid.utid-login.example.com-accesstoken-my_client_id-contoso-s1 s2 s3',
255-
{}).get("key_id")
256-
self.assertEqual(my_key_id, cached_key_id, "AT should be bound to the key")
238+
def test_access_tokens_with_different_key_id(self):
239+
self._test_data_should_be_saved_and_searchable_in_access_token({"key_id": "1"})
240+
self._test_data_should_be_saved_and_searchable_in_access_token({"key_id": "2"})
241+
self.assertEqual(
242+
len(self.cache._cache["AccessToken"]),
243+
1, """Historically, tokens are not keyed by key_id,
244+
so a new token overwrites the old one, and we would end up with 1 token in cache""")
257245

258246
def test_refresh_in_should_be_recorded_as_refresh_on(self): # Sounds weird. Yep.
247+
scopes = ["s2", "s1", "s3"] # Not in particular order
259248
self.cache.add({
260249
"client_id": "my_client_id",
261-
"scope": ["s2", "s1", "s3"], # Not in particular order
250+
"scope": scopes,
262251
"token_endpoint": "https://login.example.com/contoso/v2/token",
263252
"response": build_response(
264253
uid="uid", utid="utid", # client_info
265254
expires_in=3600, refresh_in=1800, access_token="an access token",
266255
), #refresh_token="a refresh token"),
267256
}, now=1000)
268-
refresh_on = self.cache._cache["AccessToken"].get(
269-
'uid.utid-login.example.com-accesstoken-my_client_id-contoso-s1 s2 s3',
270-
{}).get("refresh_on")
271-
self.assertEqual("2800", refresh_on, "Should save refresh_on")
257+
at = self.assertFoundAccessToken(scopes=scopes, query=dict(
258+
client_id="my_client_id",
259+
environment="login.example.com",
260+
realm="contoso",
261+
home_account_id="uid.utid",
262+
))
263+
self.assertEqual("2800", at.get("refresh_on"), "Should save refresh_on")
272264

273265
def test_old_rt_data_with_wrong_key_should_still_be_salvaged_into_new_rt(self):
274266
sample = {

0 commit comments

Comments
 (0)