Skip to content

Commit 804d529

Browse files
authored
Merge pull request #644 from AzureAD/order-scopes
Order scopes on save, and optimize the happy path for access token read
2 parents 866ba2b + 5272fbd commit 804d529

File tree

3 files changed, with 76 additions and 28 deletions.

msal/application.py

Lines changed: 8 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -1357,13 +1357,14 @@ def _acquire_token_silent_from_cache_and_possibly_refresh_it(
13571357
key_id = kwargs.get("data", {}).get("key_id")
13581358
if key_id: # Some token types (SSH-certs, POP) are bound to a key
13591359
query["key_id"] = key_id
1360-
matches = self.token_cache.find(
1361-
self.token_cache.CredentialType.ACCESS_TOKEN,
1362-
target=scopes,
1363-
query=query)
13641360
now = time.time()
13651361
refresh_reason = msal.telemetry.AT_ABSENT
1366-
for entry in matches:
1362+
for entry in self.token_cache._find( # It returns a generator
1363+
self.token_cache.CredentialType.ACCESS_TOKEN,
1364+
target=scopes,
1365+
query=query,
1366+
): # Note that _find() holds a lock during this for loop;
1367+
# that is fine because this loop is fast
13671368
expires_in = int(entry["expires_on"]) - now
13681369
if expires_in < 5*60: # Then consider it expired
13691370
refresh_reason = msal.telemetry.AT_EXPIRED
@@ -1492,10 +1493,8 @@ def _acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family(
14921493
**kwargs) or last_resp
14931494

14941495
def _get_app_metadata(self, environment):
1495-
apps = self.token_cache.find( # Use find(), rather than token_cache.get(...)
1496-
TokenCache.CredentialType.APP_METADATA, query={
1497-
"environment": environment, "client_id": self.client_id})
1498-
return apps[0] if apps else {}
1496+
return self.token_cache._get_app_metadata(
1497+
environment=environment, client_id=self.client_id, default={})
14991498

15001499
def _acquire_token_silent_by_finding_specific_refresh_token(
15011500
self, authority, scopes, query,

msal/token_cache.py

Lines changed: 58 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -88,20 +88,69 @@ def __init__(self):
8888
"appmetadata-{}-{}".format(environment or "", client_id or ""),
8989
}
9090

91-
def find(self, credential_type, target=None, query=None):
92-
target = target or []
91+
def _get_access_token(
92+
self,
93+
home_account_id, environment, client_id, realm, target, # Together they form a compound key
94+
default=None,
95+
): # O(1)
96+
return self._get(
97+
self.CredentialType.ACCESS_TOKEN,
98+
self.key_makers[TokenCache.CredentialType.ACCESS_TOKEN](
99+
home_account_id=home_account_id,
100+
environment=environment,
101+
client_id=client_id,
102+
realm=realm,
103+
target=" ".join(target),
104+
),
105+
default=default)
106+
107+
def _get_app_metadata(self, environment, client_id, default=None): # O(1)
108+
return self._get(
109+
self.CredentialType.APP_METADATA,
110+
self.key_makers[TokenCache.CredentialType.APP_METADATA](
111+
environment=environment,
112+
client_id=client_id,
113+
),
114+
default=default)
115+
116+
def _get(self, credential_type, key, default=None): # O(1)
117+
with self._lock:
118+
return self._cache.get(credential_type, {}).get(key, default)
119+
120+
def _find(self, credential_type, target=None, query=None): # O(n) generator
121+
"""Returns a generator of matching entries.
122+
123+
It is O(1) for AT hits, and O(n) for other types.
124+
Note that it holds a lock during the entire search.
125+
"""
126+
target = sorted(target or []) # Match the order sorted by add()
93127
assert isinstance(target, list), "Invalid parameter type"
128+
129+
preferred_result = None
130+
if (credential_type == self.CredentialType.ACCESS_TOKEN
131+
and "home_account_id" in query and "environment" in query
132+
and "client_id" in query and "realm" in query and target
133+
): # Special case for O(1) AT lookup
134+
preferred_result = self._get_access_token(
135+
query["home_account_id"], query["environment"],
136+
query["client_id"], query["realm"], target)
137+
if preferred_result:
138+
yield preferred_result
139+
94140
target_set = set(target)
95141
with self._lock:
96142
# The O(1) lookup above only finds an entry keyed by exactly this sorted target.
97143
# Entries whose target is a superset of the requested scopes need a scan instead,
98144
# so we still do an O(n) in-memory search over every entry of this type.
99-
return [entry
100-
for entry in self._cache.get(credential_type, {}).values()
101-
if is_subdict_of(query or {}, entry)
102-
and (target_set <= set(entry.get("target", "").split())
103-
if target else True)
104-
]
145+
for entry in self._cache.get(credential_type, {}).values():
146+
if is_subdict_of(query or {}, entry) and (
147+
target_set <= set(entry.get("target", "").split())
148+
if target else True):
149+
if entry != preferred_result: # Avoid yielding the same entry twice
150+
yield entry
151+
152+
def find(self, credential_type, target=None, query=None): # Obsolete. Use _find() instead.
153+
return list(self._find(credential_type, target=target, query=query))
105154

106155
def add(self, event, now=None):
107156
"""Handle a token obtaining event, and add tokens into cache."""
@@ -160,7 +209,7 @@ def __add(self, event, now=None):
160209
decode_id_token(id_token, client_id=event["client_id"]) if id_token else {})
161210
client_info, home_account_id = self.__parse_account(response, id_token_claims)
162211

163-
target = ' '.join(event.get("scope") or []) # Per schema, we don't sort it
212+
target = ' '.join(sorted(event.get("scope") or [])) # Schema should have required sorting
164213

165214
with self._lock:
166215
now = int(time.time() if now is None else now)

tests/test_token_cache.py

Lines changed: 10 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -76,11 +76,11 @@ def testAddByAad(self):
7676
'home_account_id': "uid.utid",
7777
'realm': 'contoso',
7878
'secret': 'an access token',
79-
'target': 's2 s1 s3',
79+
'target': 's1 s2 s3', # Sorted
8080
'token_type': 'some type',
8181
},
8282
self.cache._cache["AccessToken"].get(
83-
'uid.utid-login.example.com-accesstoken-my_client_id-contoso-s2 s1 s3')
83+
'uid.utid-login.example.com-accesstoken-my_client_id-contoso-s1 s2 s3')
8484
)
8585
self.assertEqual(
8686
{
@@ -90,10 +90,10 @@ def testAddByAad(self):
9090
'home_account_id': "uid.utid",
9191
'last_modification_time': '1000',
9292
'secret': 'a refresh token',
93-
'target': 's2 s1 s3',
93+
'target': 's1 s2 s3', # Sorted
9494
},
9595
self.cache._cache["RefreshToken"].get(
96-
'uid.utid-login.example.com-refreshtoken-my_client_id--s2 s1 s3')
96+
'uid.utid-login.example.com-refreshtoken-my_client_id--s1 s2 s3')
9797
)
9898
self.assertEqual(
9999
{
@@ -150,11 +150,11 @@ def testAddByAdfs(self):
150150
'home_account_id': "subject",
151151
'realm': 'adfs',
152152
'secret': 'an access token',
153-
'target': 's2 s1 s3',
153+
'target': 's1 s2 s3', # Sorted
154154
'token_type': 'some type',
155155
},
156156
self.cache._cache["AccessToken"].get(
157-
'subject-fs.msidlab8.com-accesstoken-my_client_id-adfs-s2 s1 s3')
157+
'subject-fs.msidlab8.com-accesstoken-my_client_id-adfs-s1 s2 s3')
158158
)
159159
self.assertEqual(
160160
{
@@ -164,10 +164,10 @@ def testAddByAdfs(self):
164164
'home_account_id': "subject",
165165
'last_modification_time': "1000",
166166
'secret': 'a refresh token',
167-
'target': 's2 s1 s3',
167+
'target': 's1 s2 s3', # Sorted
168168
},
169169
self.cache._cache["RefreshToken"].get(
170-
'subject-fs.msidlab8.com-refreshtoken-my_client_id--s2 s1 s3')
170+
'subject-fs.msidlab8.com-refreshtoken-my_client_id--s1 s2 s3')
171171
)
172172
self.assertEqual(
173173
{
@@ -214,7 +214,7 @@ def test_key_id_is_also_recorded(self):
214214
refresh_token="a refresh token"),
215215
}, now=1000)
216216
cached_key_id = self.cache._cache["AccessToken"].get(
217-
'uid.utid-login.example.com-accesstoken-my_client_id-contoso-s2 s1 s3',
217+
'uid.utid-login.example.com-accesstoken-my_client_id-contoso-s1 s2 s3',
218218
{}).get("key_id")
219219
self.assertEqual(my_key_id, cached_key_id, "AT should be bound to the key")
220220

@@ -229,7 +229,7 @@ def test_refresh_in_should_be_recorded_as_refresh_on(self): # Sounds weird. Yep
229229
), #refresh_token="a refresh token"),
230230
}, now=1000)
231231
refresh_on = self.cache._cache["AccessToken"].get(
232-
'uid.utid-login.example.com-accesstoken-my_client_id-contoso-s2 s1 s3',
232+
'uid.utid-login.example.com-accesstoken-my_client_id-contoso-s1 s2 s3',
233233
{}).get("refresh_on")
234234
self.assertEqual("2800", refresh_on, "Should save refresh_on")
235235

0 commit comments

Comments
 (0)