Skip to content

Commit e80b58f

Browse files
committed
Add the missing token query check
1 parent 84bdfab commit e80b58f

File tree

2 files changed

+51
-13
lines changed

2 files changed

+51
-13
lines changed

msal/token_cache.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,12 @@ def _get(self, credential_type, key, default=None): # O(1)
117117
with self._lock:
118118
return self._cache.get(credential_type, {}).get(key, default)
119119

120+
@staticmethod
121+
def _is_matching(entry: dict, query: dict, target_set: set):
122+
return is_subdict_of(query or {}, entry) and (
123+
target_set <= set(entry.get("target", "").split())
124+
if target_set else True)
125+
120126
def _find(self, credential_type, target=None, query=None): # O(n) generator
121127
"""Returns a generator of matching entries.
122128
@@ -125,6 +131,7 @@ def _find(self, credential_type, target=None, query=None): # O(n) generator
125131
"""
126132
target = sorted(target or []) # Match the order sorted by add()
127133
assert isinstance(target, list), "Invalid parameter type"
134+
target_set = set(target)
128135

129136
preferred_result = None
130137
if (credential_type == self.CredentialType.ACCESS_TOKEN
@@ -135,20 +142,20 @@ def _find(self, credential_type, target=None, query=None): # O(n) generator
135142
preferred_result = self._get_access_token(
136143
query["home_account_id"], query["environment"],
137144
query["client_id"], query["realm"], target)
138-
if preferred_result:
145+
if preferred_result and self._is_matching(
146+
preferred_result, query, target_set,
147+
):
139148
yield preferred_result
140149

141-
target_set = set(target)
142150
with self._lock:
143151
# Since the target inside token cache key is (per schema) unsorted,
144152
# there is no point to attempt an O(1) key-value search here.
145153
# So we always do an O(n) in-memory search.
146154
for entry in self._cache.get(credential_type, {}).values():
147-
if is_subdict_of(query or {}, entry) and (
148-
target_set <= set(entry.get("target", "").split())
149-
if target else True):
150-
if entry != preferred_result: # Avoid yielding the same entry twice
151-
yield entry
155+
if (entry != preferred_result # Avoid yielding the same entry twice
156+
and self._is_matching(entry, query, target_set)
157+
):
158+
yield entry
152159

153160
def find(self, credential_type, target=None, query=None): # Obsolete. Use _find() instead.
154161
return list(self._find(credential_type, target=target, query=query))

tests/test_e2e.py

Lines changed: 37 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -679,11 +679,28 @@ def _test_acquire_token_by_client_secret(
679679

680680
class PopWithExternalKeyTestCase(LabBasedTestCase):
681681
def _test_service_principal(self):
682-
# Any SP can obtain an ssh-cert. Here we use the lab app.
683-
result = get_lab_app().acquire_token_for_client(self.SCOPE, data=self.DATA1)
682+
app = get_lab_app() # Any SP can obtain an ssh-cert. Here we use the lab app.
683+
result = app.acquire_token_for_client(self.SCOPE, data=self.DATA1)
684684
self.assertIsNotNone(result.get("access_token"), "Encountered {}: {}".format(
685685
result.get("error"), result.get("error_description")))
686686
self.assertEqual(self.EXPECTED_TOKEN_TYPE, result["token_type"])
687+
self.assertEqual(result["token_source"], "identity_provider")
688+
689+
# Test cache hit
690+
cached_result = app.acquire_token_for_client(self.SCOPE, data=self.DATA1)
691+
self.assertIsNotNone(
692+
cached_result.get("access_token"), "Encountered {}: {}".format(
693+
cached_result.get("error"), cached_result.get("error_description")))
694+
self.assertEqual(self.EXPECTED_TOKEN_TYPE, cached_result["token_type"])
695+
self.assertEqual(cached_result["token_source"], "cache")
696+
697+
# refresh_token grant can fetch an ssh-cert bound to a different key
698+
refreshed_result = app.acquire_token_for_client(self.SCOPE, data=self.DATA2)
699+
self.assertIsNotNone(
700+
refreshed_result.get("access_token"), "Encountered {}: {}".format(
701+
refreshed_result.get("error"), refreshed_result.get("error_description")))
702+
self.assertEqual(self.EXPECTED_TOKEN_TYPE, refreshed_result["token_type"])
703+
self.assertEqual(refreshed_result["token_source"], "identity_provider")
687704

688705
def _test_user_account(self):
689706
lab_user = self.get_lab_user(usertype="cloud")
@@ -701,16 +718,30 @@ def _test_user_account(self):
701718
self.assertIsNotNone(result.get("access_token"), "Encountered {}: {}".format(
702719
result.get("error"), result.get("error_description")))
703720
self.assertEqual(self.EXPECTED_TOKEN_TYPE, result["token_type"])
721+
self.assertEqual(result["token_source"], "identity_provider")
704722
logger.debug("%s.cache = %s",
705723
self.id(), json.dumps(self.app.token_cache._cache, indent=4))
706724

725+
# refresh_token grant can hit an ssh-cert bound to the same key
726+
account = self.app.get_accounts()[0]
727+
cached_result = self.app.acquire_token_silent(
728+
self.SCOPE, account=account, data=self.DATA1)
729+
self.assertIsNotNone(cached_result)
730+
self.assertEqual(self.EXPECTED_TOKEN_TYPE, cached_result["token_type"])
731+
## Actually, the self._test_acquire_token_interactive() already contained
732+
## a built-in refresh test, so the token in cache has been refreshed already.
733+
## Therefore, the following line won't pass, which is expected.
734+
#self.assertEqual(result["access_token"], cached_result['access_token'])
735+
self.assertEqual(cached_result["token_source"], "cache")
736+
707737
# refresh_token grant can fetch an ssh-cert bound to a different key
708738
account = self.app.get_accounts()[0]
709-
refreshed_ssh_cert = self.app.acquire_token_silent(
739+
refreshed_result = self.app.acquire_token_silent(
710740
self.SCOPE, account=account, data=self.DATA2)
711-
self.assertIsNotNone(refreshed_ssh_cert)
712-
self.assertEqual(self.EXPECTED_TOKEN_TYPE, refreshed_ssh_cert["token_type"])
713-
self.assertNotEqual(result["access_token"], refreshed_ssh_cert['access_token'])
741+
self.assertIsNotNone(refreshed_result)
742+
self.assertEqual(self.EXPECTED_TOKEN_TYPE, refreshed_result["token_type"])
743+
self.assertNotEqual(result["access_token"], refreshed_result['access_token'])
744+
self.assertEqual(refreshed_result["token_source"], "identity_provider")
714745

715746

716747
class SshCertTestCase(PopWithExternalKeyTestCase):

0 commit comments

Comments
 (0)