3
3
import json
4
4
import time
5
5
6
- from msal .token_cache import *
6
+ from msal .token_cache import TokenCache , SerializableTokenCache
7
7
from tests import unittest
8
8
9
9
@@ -51,6 +51,8 @@ class TokenCacheTestCase(unittest.TestCase):
51
51
52
52
def setUp (self ):
53
53
self .cache = TokenCache ()
54
+ self .at_key_maker = self .cache .key_makers [
55
+ TokenCache .CredentialType .ACCESS_TOKEN ]
54
56
55
57
def testAddByAad (self ):
56
58
client_id = "my_client_id"
@@ -78,11 +80,8 @@ def testAddByAad(self):
78
80
'target' : 's1 s2 s3' , # Sorted
79
81
'token_type' : 'some type' ,
80
82
}
81
- self .assertEqual (
82
- access_token_entry ,
83
- self .cache ._cache ["AccessToken" ].get (
84
- 'uid.utid-login.example.com-accesstoken-my_client_id-contoso-s1 s2 s3' )
85
- )
83
+ self .assertEqual (access_token_entry , self .cache ._cache ["AccessToken" ].get (
84
+ self .at_key_maker (** access_token_entry )))
86
85
self .assertIn (
87
86
access_token_entry ,
88
87
self .cache .find (self .cache .CredentialType .ACCESS_TOKEN ),
@@ -144,8 +143,7 @@ def testAddByAdfs(self):
144
143
expires_in = 3600 , access_token = "an access token" ,
145
144
id_token = id_token , refresh_token = "a refresh token" ),
146
145
}, now = 1000 )
147
- self .assertEqual (
148
- {
146
+ access_token_entry = {
149
147
'cached_at' : "1000" ,
150
148
'client_id' : 'my_client_id' ,
151
149
'credential_type' : 'AccessToken' ,
@@ -157,10 +155,9 @@ def testAddByAdfs(self):
157
155
'secret' : 'an access token' ,
158
156
'target' : 's1 s2 s3' , # Sorted
159
157
'token_type' : 'some type' ,
160
- },
161
- self .cache ._cache ["AccessToken" ].get (
162
- 'subject-fs.msidlab8.com-accesstoken-my_client_id-adfs-s1 s2 s3' )
163
- )
158
+ }
159
+ self .assertEqual (access_token_entry , self .cache ._cache ["AccessToken" ].get (
160
+ self .at_key_maker (** access_token_entry )))
164
161
self .assertEqual (
165
162
{
166
163
'client_id' : 'my_client_id' ,
@@ -238,37 +235,32 @@ def _test_data_should_be_saved_and_searchable_in_access_token(self, data):
238
235
def test_extra_data_should_also_be_recorded_and_searchable_in_access_token (self ):
239
236
self ._test_data_should_be_saved_and_searchable_in_access_token ({"key_id" : "1" })
240
237
241
- def test_key_id_is_also_recorded (self ):
242
- my_key_id = "some_key_id_123"
243
- self .cache .add ({
244
- "data" : {"key_id" : my_key_id },
245
- "client_id" : "my_client_id" ,
246
- "scope" : ["s2" , "s1" , "s3" ], # Not in particular order
247
- "token_endpoint" : "https://login.example.com/contoso/v2/token" ,
248
- "response" : build_response (
249
- uid = "uid" , utid = "utid" , # client_info
250
- expires_in = 3600 , access_token = "an access token" ,
251
- refresh_token = "a refresh token" ),
252
- }, now = 1000 )
253
- cached_key_id = self .cache ._cache ["AccessToken" ].get (
254
- 'uid.utid-login.example.com-accesstoken-my_client_id-contoso-s1 s2 s3' ,
255
- {}).get ("key_id" )
256
- self .assertEqual (my_key_id , cached_key_id , "AT should be bound to the key" )
238
+ def test_access_tokens_with_different_key_id (self ):
239
+ self ._test_data_should_be_saved_and_searchable_in_access_token ({"key_id" : "1" })
240
+ self ._test_data_should_be_saved_and_searchable_in_access_token ({"key_id" : "2" })
241
+ self .assertEqual (
242
+ len (self .cache ._cache ["AccessToken" ]),
243
+ 1 , """Historically, tokens are not keyed by key_id,
244
+ so a new token overwrites the old one, and we would end up with 1 token in cache""" )
257
245
258
246
def test_refresh_in_should_be_recorded_as_refresh_on (self ): # Sounds weird. Yep.
247
+ scopes = ["s2" , "s1" , "s3" ] # Not in particular order
259
248
self .cache .add ({
260
249
"client_id" : "my_client_id" ,
261
- "scope" : [ "s2" , "s1" , "s3" ], # Not in particular order
250
+ "scope" : scopes ,
262
251
"token_endpoint" : "https://login.example.com/contoso/v2/token" ,
263
252
"response" : build_response (
264
253
uid = "uid" , utid = "utid" , # client_info
265
254
expires_in = 3600 , refresh_in = 1800 , access_token = "an access token" ,
266
255
), #refresh_token="a refresh token"),
267
256
}, now = 1000 )
268
- refresh_on = self .cache ._cache ["AccessToken" ].get (
269
- 'uid.utid-login.example.com-accesstoken-my_client_id-contoso-s1 s2 s3' ,
270
- {}).get ("refresh_on" )
271
- self .assertEqual ("2800" , refresh_on , "Should save refresh_on" )
257
+ at = self .assertFoundAccessToken (scopes = scopes , query = dict (
258
+ client_id = "my_client_id" ,
259
+ environment = "login.example.com" ,
260
+ realm = "contoso" ,
261
+ home_account_id = "uid.utid" ,
262
+ ))
263
+ self .assertEqual ("2800" , at .get ("refresh_on" ), "Should save refresh_on" )
272
264
273
265
def test_old_rt_data_with_wrong_key_should_still_be_salvaged_into_new_rt (self ):
274
266
sample = {
0 commit comments