Skip to content

Commit e969e64

Browse files
committed
Merge branch 'refactor-token-cache-test-cases' into dev
2 parents deb7900 + 8eb5c18 commit e969e64

File tree

2 files changed

+52
-49
lines changed

2 files changed

+52
-49
lines changed

tests/test_application.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import msal
66
from msal.application import _merge_claims_challenge_and_capabilities
77
from tests import unittest
8-
from tests.test_token_cache import TokenCacheTestCase
8+
from tests.test_token_cache import build_id_token, build_response
99
from tests.http_client import MinimalHttpClient, MinimalResponse
1010
from msal.telemetry import CLIENT_CURRENT_TELEMETRY, CLIENT_LAST_TELEMETRY
1111

@@ -66,7 +66,7 @@ def setUp(self):
6666
"client_id": self.client_id,
6767
"scope": self.scopes,
6868
"token_endpoint": "{}/oauth2/v2.0/token".format(self.authority_url),
69-
"response": TokenCacheTestCase.build_response(
69+
"response": build_response(
7070
access_token="an expired AT to trigger refresh", expires_in=-99,
7171
uid=self.uid, utid=self.utid, refresh_token=self.rt),
7272
}) # The add(...) helper populates correct home_account_id for future searching
@@ -125,9 +125,9 @@ def setUp(self):
125125
"client_id": self.preexisting_family_app_id,
126126
"scope": self.scopes,
127127
"token_endpoint": "{}/oauth2/v2.0/token".format(self.authority_url),
128-
"response": TokenCacheTestCase.build_response(
128+
"response": build_response(
129129
access_token="Siblings won't share AT. test_remove_account() will.",
130-
id_token=TokenCacheTestCase.build_id_token(aud=self.preexisting_family_app_id),
130+
id_token=build_id_token(aud=self.preexisting_family_app_id),
131131
uid=self.uid, utid=self.utid, refresh_token=self.frt, foci="1"),
132132
}) # The add(...) helper populates correct home_account_id for future searching
133133

@@ -153,8 +153,7 @@ def test_known_orphan_app_will_skip_frt_and_only_use_its_own_rt(self):
153153
"client_id": app.client_id,
154154
"scope": self.scopes,
155155
"token_endpoint": "{}/oauth2/v2.0/token".format(self.authority_url),
156-
"response": TokenCacheTestCase.build_response(
157-
uid=self.uid, utid=self.utid, refresh_token=rt),
156+
"response": build_response(uid=self.uid, utid=self.utid, refresh_token=rt),
158157
})
159158
logger.debug("%s.cache = %s", self.id(), self.cache.serialize())
160159
def tester(url, data=None, **kwargs):
@@ -168,7 +167,7 @@ def tester(url, data=None, **kwargs):
168167
self.assertEqual(
169168
self.frt, data.get("refresh_token"), "Should attempt the FRT")
170169
return MinimalResponse(
171-
status_code=200, text=json.dumps(TokenCacheTestCase.build_response(
170+
status_code=200, text=json.dumps(build_response(
172171
uid=self.uid, utid=self.utid, foci="1", access_token="at")))
173172
app = ClientApplication(
174173
"unknown_family_app", authority=self.authority_url, token_cache=self.cache)
@@ -246,7 +245,7 @@ def setUp(self):
246245
"scope": self.scopes,
247246
"token_endpoint": "https://{}/common/oauth2/v2.0/token".format(
248247
self.environment_in_cache),
249-
"response": TokenCacheTestCase.build_response(
248+
"response": build_response(
250249
uid=uid, utid=utid,
251250
access_token=self.access_token, refresh_token="some refresh token"),
252251
}) # The add(...) helper populates correct home_account_id for future searching
@@ -342,7 +341,7 @@ def populate_cache(self, access_token="at", expires_in=86400, refresh_in=43200):
342341
"client_id": self.client_id,
343342
"scope": self.scopes,
344343
"token_endpoint": "{}/oauth2/v2.0/token".format(self.authority_url),
345-
"response": TokenCacheTestCase.build_response(
344+
"response": build_response(
346345
access_token=access_token,
347346
expires_in=expires_in, refresh_in=refresh_in,
348347
uid=self.uid, utid=self.utid, refresh_token=self.rt),
@@ -424,7 +423,7 @@ def populate_cache(self, cache, access_token="at"):
424423
"client_id": self.client_id,
425424
"scope": self.scopes,
426425
"token_endpoint": "{}/oauth2/v2.0/token".format(self.authority_url),
427-
"response": TokenCacheTestCase.build_response(
426+
"response": build_response(
428427
access_token=access_token,
429428
uid=self.uid, utid=self.utid, refresh_token=self.rt),
430429
})
@@ -571,9 +570,9 @@ def test_get_accounts(self):
571570
"scope": scopes,
572571
"token_endpoint":
573572
"https://{}/{}/oauth2/v2.0/token".format(environment, tenant),
574-
"response": TokenCacheTestCase.build_response(
573+
"response": build_response(
575574
uid=uid, utid=utid, access_token="at", refresh_token="rt",
576-
id_token=TokenCacheTestCase.build_id_token(
575+
id_token=build_id_token(
577576
aud=client_id,
578577
sub="oid_in_" + tenant,
579578
preferred_username=username,

tests/test_token_cache.py

Lines changed: 41 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -11,52 +11,56 @@
1111
logging.basicConfig(level=logging.DEBUG)
1212

1313

14-
class TokenCacheTestCase(unittest.TestCase):
14+
# NOTE: These helpers were once implemented as static methods in TokenCacheTestCase.
15+
# That would cause other test files' "from ... import TokenCacheTestCase"
16+
# to re-run all test cases in this file.
17+
# Now we avoid that, by defining these helpers in module level.
18+
def build_id_token(
19+
iss="issuer", sub="subject", aud="my_client_id", exp=None, iat=None,
20+
**claims): # AAD issues "preferred_username", ADFS issues "upn"
21+
return "header.%s.signature" % base64.b64encode(json.dumps(dict({
22+
"iss": iss,
23+
"sub": sub,
24+
"aud": aud,
25+
"exp": exp or (time.time() + 100),
26+
"iat": iat or time.time(),
27+
}, **claims)).encode()).decode('utf-8')
28+
1529

16-
@staticmethod
17-
def build_id_token(
18-
iss="issuer", sub="subject", aud="my_client_id", exp=None, iat=None,
19-
**claims): # AAD issues "preferred_username", ADFS issues "upn"
20-
return "header.%s.signature" % base64.b64encode(json.dumps(dict({
21-
"iss": iss,
22-
"sub": sub,
23-
"aud": aud,
24-
"exp": exp or (time.time() + 100),
25-
"iat": iat or time.time(),
26-
}, **claims)).encode()).decode('utf-8')
30+
def build_response( # simulate a response from AAD
31+
uid=None, utid=None, # If present, they will form client_info
32+
access_token=None, expires_in=3600, token_type="some type",
33+
**kwargs # Pass-through: refresh_token, foci, id_token, error, refresh_in, ...
34+
):
35+
response = {}
36+
if uid and utid: # Mimic the AAD behavior for "client_info=1" request
37+
response["client_info"] = base64.b64encode(json.dumps({
38+
"uid": uid, "utid": utid,
39+
}).encode()).decode('utf-8')
40+
if access_token:
41+
response.update({
42+
"access_token": access_token,
43+
"expires_in": expires_in,
44+
"token_type": token_type,
45+
})
46+
response.update(kwargs) # Pass-through key-value pairs as top-level fields
47+
return response
2748

28-
@staticmethod
29-
def build_response( # simulate a response from AAD
30-
uid=None, utid=None, # If present, they will form client_info
31-
access_token=None, expires_in=3600, token_type="some type",
32-
**kwargs # Pass-through: refresh_token, foci, id_token, error, refresh_in, ...
33-
):
34-
response = {}
35-
if uid and utid: # Mimic the AAD behavior for "client_info=1" request
36-
response["client_info"] = base64.b64encode(json.dumps({
37-
"uid": uid, "utid": utid,
38-
}).encode()).decode('utf-8')
39-
if access_token:
40-
response.update({
41-
"access_token": access_token,
42-
"expires_in": expires_in,
43-
"token_type": token_type,
44-
})
45-
response.update(kwargs) # Pass-through key-value pairs as top-level fields
46-
return response
49+
50+
class TokenCacheTestCase(unittest.TestCase):
4751

4852
def setUp(self):
4953
self.cache = TokenCache()
5054

5155
def testAddByAad(self):
5256
client_id = "my_client_id"
53-
id_token = self.build_id_token(
57+
id_token = build_id_token(
5458
oid="object1234", preferred_username="John Doe", aud=client_id)
5559
self.cache.add({
5660
"client_id": client_id,
5761
"scope": ["s2", "s1", "s3"], # Not in particular order
5862
"token_endpoint": "https://login.example.com/contoso/v2/token",
59-
"response": self.build_response(
63+
"response": build_response(
6064
uid="uid", utid="utid", # client_info
6165
expires_in=3600, access_token="an access token",
6266
id_token=id_token, refresh_token="a refresh token"),
@@ -125,12 +129,12 @@ def testAddByAad(self):
125129

126130
def testAddByAdfs(self):
127131
client_id = "my_client_id"
128-
id_token = self.build_id_token(aud=client_id, upn="[email protected]")
132+
id_token = build_id_token(aud=client_id, upn="[email protected]")
129133
self.cache.add({
130134
"client_id": client_id,
131135
"scope": ["s2", "s1", "s3"], # Not in particular order
132136
"token_endpoint": "https://fs.msidlab8.com/adfs/oauth2/token",
133-
"response": self.build_response(
137+
"response": build_response(
134138
uid=None, utid=None, # ADFS will provide no client_info
135139
expires_in=3600, access_token="an access token",
136140
id_token=id_token, refresh_token="a refresh token"),
@@ -204,7 +208,7 @@ def test_key_id_is_also_recorded(self):
204208
"client_id": "my_client_id",
205209
"scope": ["s2", "s1", "s3"], # Not in particular order
206210
"token_endpoint": "https://login.example.com/contoso/v2/token",
207-
"response": self.build_response(
211+
"response": build_response(
208212
uid="uid", utid="utid", # client_info
209213
expires_in=3600, access_token="an access token",
210214
refresh_token="a refresh token"),
@@ -219,7 +223,7 @@ def test_refresh_in_should_be_recorded_as_refresh_on(self): # Sounds weird. Yep
219223
"client_id": "my_client_id",
220224
"scope": ["s2", "s1", "s3"], # Not in particular order
221225
"token_endpoint": "https://login.example.com/contoso/v2/token",
222-
"response": self.build_response(
226+
"response": build_response(
223227
uid="uid", utid="utid", # client_info
224228
expires_in=3600, refresh_in=1800, access_token="an access token",
225229
), #refresh_token="a refresh token"),

0 commit comments

Comments
 (0)