Skip to content

Commit ede849e

Browse files
committed
Store extra data into access token in cache
1 parent 331c16f commit ede849e

File tree

2 files changed

+38
-3
lines changed

2 files changed

+38
-3
lines changed

msal/token_cache.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -249,8 +249,11 @@ def __add(self, event, now=None):
249249
"expires_on": str(now + expires_in), # Same here
250250
"extended_expires_on": str(now + ext_expires_in) # Same here
251251
}
252-
if data.get("key_id"): # It happens in SSH-cert or POP scenario
253-
at["key_id"] = data.get("key_id")
252+
at.update({k: data[k] for k in data if k in {
253+
# Also store extra data which we explicitly allow
254+
# So that we won't accidentally store a user's password etc.
255+
"key_id", # It happens in SSH-cert or POP scenario
256+
}})
254257
if "refresh_in" in response:
255258
refresh_in = response["refresh_in"] # It is an integer
256259
at["refresh_on"] = str(now + refresh_in) # Schema wants a string

tests/test_token_cache.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,38 @@ def testAddByAdfs(self):
206206
"appmetadata-fs.msidlab8.com-my_client_id")
207207
)
208208

209+
def assertFoundAccessToken(self, *, scopes, query, data=None):
210+
cached_at = None
211+
for cached_at in self.cache.search(
212+
TokenCache.CredentialType.ACCESS_TOKEN, target=scopes, query=query):
213+
for k, v in (data or {}).items(): # The extra data, if any
214+
self.assertEqual(cached_at.get(k), v, f"AT should contain {k}={v}")
215+
self.assertTrue(cached_at, "AT should be cached and searchable")
216+
return cached_at
217+
218+
def _test_data_should_be_saved_and_searchable_in_access_token(self, data):
219+
scopes = ["s2", "s1", "s3"] # Not in particular order
220+
self.cache.add({
221+
"data": data,
222+
"client_id": "my_client_id",
223+
"scope": scopes,
224+
"token_endpoint": "https://login.example.com/contoso/v2/token",
225+
"response": build_response(
226+
uid="uid", utid="utid", # client_info
227+
expires_in=3600, access_token="an access token",
228+
refresh_token="a refresh token"),
229+
}, now=1000)
230+
self.assertFoundAccessToken(scopes=scopes, data=data, query=dict(
231+
data, # Also use the extra data as a query criteria
232+
client_id="my_client_id",
233+
environment="login.example.com",
234+
realm="contoso",
235+
home_account_id="uid.utid",
236+
))
237+
238+
def test_extra_data_should_also_be_recorded_and_searchable_in_access_token(self):
239+
self._test_data_should_be_saved_and_searchable_in_access_token({"key_id": "1"})
240+
209241
def test_key_id_is_also_recorded(self):
210242
my_key_id = "some_key_id_123"
211243
self.cache.add({
@@ -258,7 +290,7 @@ def test_old_rt_data_with_wrong_key_should_still_be_salvaged_into_new_rt(self):
258290
)
259291

260292

261-
class SerializableTokenCacheTestCase(TokenCacheTestCase):
293+
class SerializableTokenCacheTestCase(unittest.TestCase):
262294
# Run all inherited test methods, and have extra check in tearDown()
263295

264296
def setUp(self):

0 commit comments

Comments
 (0)