|
1 | 1 | # Note: Since Aug 2019 we move all e2e tests into test_e2e.py,
|
2 | 2 | # so this test_application file contains only unit tests without dependency.
|
| 3 | +import hashlib |
3 | 4 | import json
|
4 | 5 | import logging
|
5 | 6 | import sys
|
@@ -56,6 +57,35 @@ def test_bytes_to_bytes(self):
|
56 | 57 | self.assertEqual(type(_str2bytes(b"some bytes")), type(b"bytes"))
|
57 | 58 |
|
58 | 59 |
|
| 60 | +def fake_token_getter( |
| 61 | + *, |
| 62 | + access_token: str = "an access token", |
| 63 | + status_code: int = 200, |
| 64 | + expires_in: int = 3600, |
| 65 | + token_type: str = "Bearer", |
| 66 | + payload: dict = None, |
| 67 | + headers: dict = None, |
| 68 | +): |
| 69 | + """A helper to create a fake token getter, |
| 70 | + which will be consumed by ClientApplication's acquire methods' post parameter. |
| 71 | +
|
| 72 | + Generic mock.patch() is inconvenient because: |
| 73 | + 1. If you patch it at or above oauth2.py _obtain_token(), token cache is not populated. |
| 74 | + 2. If you patch it at request.post(), your test cases become fragile because |
| 75 | + more http round-trips may be added for future flows, |
| 76 | + then your existing test case would break until you mock new round-trips. |
| 77 | + """ |
| 78 | + return lambda url, *args, **kwargs: MinimalResponse( |
| 79 | + status_code=status_code, |
| 80 | + text=json.dumps(payload or { |
| 81 | + "access_token": access_token, |
| 82 | + "expires_in": expires_in, |
| 83 | + "token_type": token_type, |
| 84 | + }), |
| 85 | + headers=headers, |
| 86 | + ) |
| 87 | + |
| 88 | + |
59 | 89 | class TestClientApplicationAcquireTokenSilentErrorBehaviors(unittest.TestCase):
|
60 | 90 |
|
61 | 91 | def setUp(self):
|
@@ -856,3 +886,45 @@ def test_app_did_not_register_redirect_uri_should_error_out(self):
|
856 | 886 | )
|
857 | 887 | self.assertEqual(result.get("error"), "broker_error")
|
858 | 888 |
|
| 889 | + |
| 890 | +@patch("msal.authority.tenant_discovery", new=Mock(return_value={ |
| 891 | + "authorization_endpoint": "https://contoso.com/placeholder", |
| 892 | + "token_endpoint": "https://contoso.com/placeholder", |
| 893 | + })) |
| 894 | +class AccessTokenToRefreshTestCase(unittest.TestCase): |
| 895 | + def test_mismatching_hash_should_not_trigger_refresh(self): |
| 896 | + scopes = ["scope"] |
| 897 | + token1 = "AT one" |
| 898 | + token1_hash = hashlib.sha256(token1.encode()).hexdigest() |
| 899 | + token2 = "AT two" |
| 900 | + app = msal.ConfidentialClientApplication("foo", client_credential="bar") |
| 901 | + |
| 902 | + # Prepopulate cache |
| 903 | + app.acquire_token_for_client(scopes, post=fake_token_getter(access_token=token1)) |
| 904 | + self.assertNotEqual(app.token_cache._cache, {}, "Cache should have been populated") |
| 905 | + |
| 906 | + # Test mismatching hash should not trigger refresh |
| 907 | + result = app.acquire_token_for_client( |
| 908 | + scopes, |
| 909 | + access_token_sha256_to_refresh="mismatching hash", |
| 910 | + post=fake_token_getter(access_token=token2)) |
| 911 | + self.assertEqual(result.get("access_token"), token1, "Should hit old token") |
| 912 | + self.assertEqual(result.get("token_source"), app._TOKEN_SOURCE_CACHE) |
| 913 | + |
| 914 | + # Test matching hash should trigger refresh |
| 915 | + result = app.acquire_token_for_client( |
| 916 | + scopes, |
| 917 | + access_token_sha256_to_refresh=token1_hash, |
| 918 | + post=fake_token_getter(access_token=token2)) |
| 919 | + self.assertEqual(result.get("access_token"), token2, "Should obtain new token") |
| 920 | + self.assertEqual(result.get("token_source"), app._TOKEN_SOURCE_IDP) |
| 921 | + |
| 922 | + # A client using old token1, even with claims challenge, |
| 923 | + # should not trigger refresh, because we have token2 in cache. |
| 924 | + result = app.acquire_token_for_client( |
| 925 | + scopes, |
| 926 | + access_token_sha256_to_refresh=token1_hash, |
| 927 | + claims_challenge='{"access_token": {"xms_cc": {"values": ["challenge for token1"]}}}', |
| 928 | + post=fake_token_getter(access_token="AT three")) |
| 929 | + self.assertEqual(result.get("access_token"), token2, "Token 2 should be returned") |
| 930 | + self.assertEqual(result.get("token_source"), app._TOKEN_SOURCE_CACHE) |
0 commit comments