Skip to content

Commit 29d1ac1

Browse files
committed
Only cache desirable data in http cache
1 parent b92b4f1 commit 29d1ac1

File tree

6 files changed

+180
-40
lines changed

6 files changed

+180
-40
lines changed

msal/application.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -499,6 +499,7 @@ def __init__(
499499
except (
500500
FileNotFoundError, # Or IOError in Python 2
501501
pickle.UnpicklingError, # A corrupted http cache file
502+
AttributeError, # Cache created by a different version of MSAL
502503
):
503504
persisted_http_cache = {} # Recover by starting afresh
504505
atexit.register(lambda: pickle.dump(

msal/managed_identity.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -112,8 +112,8 @@ def __init__(self, *, client_id=None, resource_id=None, object_id=None):
112112

113113

114114
class _ThrottledHttpClient(ThrottledHttpClientBase):
115-
def __init__(self, http_client, **kwargs):
116-
super(_ThrottledHttpClient, self).__init__(http_client, **kwargs)
115+
def __init__(self, *args, **kwargs):
116+
super(_ThrottledHttpClient, self).__init__(*args, **kwargs)
117117
self.get = IndividualCache( # All MIs (except Cloud Shell) use GETs
118118
mapping=self._expiring_mapping,
119119
key_maker=lambda func, args, kwargs: "REQ {} hash={} 429/5xx/Retry-After".format(
@@ -124,7 +124,7 @@ def __init__(self, http_client, **kwargs):
124124
str(kwargs.get("params")) + str(kwargs.get("data"))),
125125
),
126126
expires_in=RetryAfterParser(5).parse, # 5 seconds default for non-PCA
127-
)(http_client.get)
127+
)(self.get) # Note: Decorate the parent get(), not the http_client.get()
128128

129129

130130
class ManagedIdentityClient(object):
@@ -233,8 +233,7 @@ def __init__(
233233
# (especially for 410 which was supposed to be a permanent failure).
234234
# 2. MI on Service Fabric specifically suggests to not retry on 404.
235235
# ( https://learn.microsoft.com/en-us/azure/service-fabric/how-to-managed-cluster-managed-identity-service-fabric-app-code#error-handling )
236-
http_client.http_client # Patch the raw (unpatched) http client
237-
if isinstance(http_client, ThrottledHttpClientBase) else http_client,
236+
http_client,
238237
http_cache=http_cache,
239238
)
240239
self._token_cache = token_cache or TokenCache()

msal/throttled_http_client.py

Lines changed: 46 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,16 @@
33

44
from .individual_cache import _IndividualCache as IndividualCache
55
from .individual_cache import _ExpiringMapping as ExpiringMapping
6+
from .oauth2cli.http import Response
7+
from .exceptions import MsalServiceError
68

79

810
# https://datatracker.ietf.org/doc/html/rfc8628#section-3.4
911
DEVICE_AUTH_GRANT = "urn:ietf:params:oauth:grant-type:device_code"
1012

1113

1214
class RetryAfterParser(object):
15+
FIELD_NAME_LOWER = "Retry-After".lower()
1316
def __init__(self, default_value=None):
1417
self._default_value = 5 if default_value is None else default_value
1518

@@ -20,9 +23,9 @@ def parse(self, *, result, **ignored):
2023
# Historically, MSAL's HttpResponse does not always have headers
2124
response, "headers", {}).items()}
2225
if not (response.status_code == 429 or response.status_code >= 500
23-
or "retry-after" in lowercase_headers):
26+
or self.FIELD_NAME_LOWER in lowercase_headers):
2427
return 0 # Quick exit
25-
retry_after = lowercase_headers.get("retry-after", self._default_value)
28+
retry_after = lowercase_headers.get(self.FIELD_NAME_LOWER, self._default_value)
2629
try:
2730
# AAD's retry_after uses integer format only
2831
# https://stackoverflow.microsoft.com/questions/264931/264932
@@ -37,27 +40,52 @@ def _extract_data(kwargs, key, default=None):
3740
return data.get(key) if isinstance(data, dict) else default
3841

3942

43+
class NormalizedResponse(Response):
44+
"""An HTTP response with the shape defined in Response,
45+
but containing only the data we will store in the cache.
46+
"""
47+
def __init__(self, raw_response):
48+
super().__init__()
49+
self.status_code = raw_response.status_code
50+
self.text = raw_response.text
51+
self.headers = { # Only keep the headers which ThrottledHttpClient cares about
52+
k: v for k, v in raw_response.headers.items()
53+
if k.lower() == RetryAfterParser.FIELD_NAME_LOWER
54+
}
55+
56+
## Note: Don't use the following line,
57+
## because when being pickled, it will indirectly pickle the whole raw_response
58+
# self.raise_for_status = raw_response.raise_for_status
59+
def raise_for_status(self):
60+
if self.status_code >= 400:
61+
raise MsalServiceError("HTTP Error: {}".format(self.status_code))
62+
63+
4064
class ThrottledHttpClientBase(object):
4165
"""Throttle the given http_client by storing and retrieving data from cache.
4266
43-
This wrapper exists so that our patching post() and get() would prevent
44-
re-patching side effect when/if same http_client being reused.
67+
This base exists so that:
68+
1. The base post() and get() return a NormalizedResponse
69+
2. The base __init__() will NOT re-throttle even if the caller accidentally nests one ThrottledHttpClient inside another.
4570
46-
The subclass should implement post() and/or get()
71+
Subclasses shall only need to dynamically decorate their post() and get() methods
72+
in their __init__() method.
4773
"""
4874
def __init__(self, http_client, *, http_cache=None):
49-
self.http_client = http_client
75+
self.http_client = http_client.http_client if isinstance(
76+
# If it is already a ThrottledHttpClientBase, we use its raw (unthrottled) http client
77+
http_client, ThrottledHttpClientBase) else http_client
5078
self._expiring_mapping = ExpiringMapping( # It will automatically clean up
5179
mapping=http_cache if http_cache is not None else {},
5280
capacity=1024, # To prevent cache blowing up especially for CCA
5381
lock=Lock(), # TODO: This should ideally also allow customization
5482
)
5583

5684
def post(self, *args, **kwargs):
57-
return self.http_client.post(*args, **kwargs)
85+
return NormalizedResponse(self.http_client.post(*args, **kwargs))
5886

5987
def get(self, *args, **kwargs):
60-
return self.http_client.get(*args, **kwargs)
88+
return NormalizedResponse(self.http_client.get(*args, **kwargs))
6189

6290
def close(self):
6391
return self.http_client.close()
@@ -68,12 +96,11 @@ def _hash(raw):
6896

6997

7098
class ThrottledHttpClient(ThrottledHttpClientBase):
71-
def __init__(self, http_client, *, default_throttle_time=None, **kwargs):
72-
super(ThrottledHttpClient, self).__init__(http_client, **kwargs)
73-
74-
_post = http_client.post # We'll patch _post, and keep original post() intact
75-
76-
_post = IndividualCache(
99+
"""A throttled http client that is used by MSAL's non-managed identity clients."""
100+
def __init__(self, *args, default_throttle_time=None, **kwargs):
101+
"""Decorate self.post() and self.get() dynamically"""
102+
super(ThrottledHttpClient, self).__init__(*args, **kwargs)
103+
self.post = IndividualCache(
77104
# Internal specs requires throttling on at least token endpoint,
78105
# here we have a generic patch for POST on all endpoints.
79106
mapping=self._expiring_mapping,
@@ -91,9 +118,9 @@ def __init__(self, http_client, *, default_throttle_time=None, **kwargs):
91118
_extract_data(kwargs, "username")))), # "account" of ROPC
92119
),
93120
expires_in=RetryAfterParser(default_throttle_time or 5).parse,
94-
)(_post)
121+
)(self.post)
95122

96-
_post = IndividualCache( # It covers the "UI required cache"
123+
self.post = IndividualCache( # It covers the "UI required cache"
97124
mapping=self._expiring_mapping,
98125
key_maker=lambda func, args, kwargs: "POST {} hash={} 400".format(
99126
args[0], # It is the url, typically containing authority and tenant
@@ -125,12 +152,10 @@ def __init__(self, http_client, *, default_throttle_time=None, **kwargs):
125152
isinstance(kwargs.get("data"), dict)
126153
and kwargs["data"].get("grant_type") == DEVICE_AUTH_GRANT
127154
)
128-
and "retry-after" not in set( # Leave it to the Retry-After decorator
155+
and RetryAfterParser.FIELD_NAME_LOWER not in set( # Otherwise leave it to the Retry-After decorator
129156
h.lower() for h in getattr(result, "headers", {}).keys())
130157
else 0,
131-
)(_post)
132-
133-
self.post = _post
158+
)(self.post)
134159

135160
self.get = IndividualCache( # Typically those discovery GETs
136161
mapping=self._expiring_mapping,
@@ -140,9 +165,4 @@ def __init__(self, http_client, *, default_throttle_time=None, **kwargs):
140165
),
141166
expires_in=lambda result=None, **ignored:
142167
3600*24 if 200 <= result.status_code < 300 else 0,
143-
)(http_client.get)
144-
145-
# The following 2 methods have been defined dynamically by __init__()
146-
#def post(self, *args, **kwargs): pass
147-
#def get(self, *args, **kwargs): pass
148-
168+
)(self.get)

tests/test_mi.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,16 @@
99
from mock import patch, ANY, mock_open, Mock
1010
import requests
1111

12-
from tests.http_client import MinimalResponse
12+
from tests.test_throttled_http_client import (
13+
MinimalResponse, ThrottledHttpClientBaseTestCase, DummyHttpClient)
1314
from msal import (
1415
SystemAssignedManagedIdentity, UserAssignedManagedIdentity,
1516
ManagedIdentityClient,
1617
ManagedIdentityError,
1718
ArcPlatformNotSupportedError,
1819
)
1920
from msal.managed_identity import (
21+
_ThrottledHttpClient,
2022
_supported_arc_platforms_and_their_prefixes,
2123
get_managed_identity_source,
2224
APP_SERVICE,
@@ -49,6 +51,37 @@ def test_helper_class_should_be_interchangable_with_dict_which_could_be_loaded_f
4951
{"ManagedIdentityIdType": "SystemAssigned", "Id": None})
5052

5153

54+
class ThrottledHttpClientTestCase(ThrottledHttpClientBaseTestCase):
55+
def test_throttled_http_client_should_not_alter_original_http_client(self):
56+
self.assertNotAlteringOriginalHttpClient(_ThrottledHttpClient)
57+
58+
def test_throttled_http_client_should_not_cache_successful_http_response(self):
59+
http_cache = {}
60+
http_client=DummyHttpClient(
61+
status_code=200,
62+
response_text='{"access_token": "AT", "expires_in": "1234", "resource": "R"}',
63+
)
64+
app = ManagedIdentityClient(
65+
SystemAssignedManagedIdentity(), http_client=http_client, http_cache=http_cache)
66+
result = app.acquire_token_for_client(resource="R")
67+
self.assertEqual("AT", result["access_token"])
68+
self.assertEqual({}, http_cache, "Should not cache successful http response")
69+
70+
def test_throttled_http_client_should_cache_unsuccessful_http_response(self):
71+
http_cache = {}
72+
http_client=DummyHttpClient(
73+
status_code=400,
74+
response_headers={"Retry-After": "1"},
75+
response_text='{"error": "invalid_request"}',
76+
)
77+
app = ManagedIdentityClient(
78+
SystemAssignedManagedIdentity(), http_client=http_client, http_cache=http_cache)
79+
result = app.acquire_token_for_client(resource="R")
80+
self.assertEqual("invalid_request", result["error"])
81+
self.assertNotEqual({}, http_cache, "Should cache unsuccessful http response")
82+
self.assertCleanPickle(http_cache)
83+
84+
5285
class ClientTestCase(unittest.TestCase):
5386
maxDiff = None
5487

tests/test_throttled_http_client.py

Lines changed: 67 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,43 @@
11
# Test cases for https://identitydivision.visualstudio.com/devex/_git/AuthLibrariesApiReview?version=GBdev&path=%2FService%20protection%2FIntial%20set%20of%20protection%20measures.md&_a=preview&anchor=common-test-cases
2+
import pickle
23
from time import sleep
34
from random import random
45
import logging
5-
from msal.throttled_http_client import ThrottledHttpClient
6+
7+
from msal.throttled_http_client import (
8+
ThrottledHttpClientBase, ThrottledHttpClient, NormalizedResponse)
9+
610
from tests import unittest
7-
from tests.http_client import MinimalResponse
11+
from tests.http_client import MinimalResponse as _MinimalResponse
812

913

1014
logger = logging.getLogger(__name__)
1115
logging.basicConfig(level=logging.DEBUG)
1216

1317

18+
class MinimalResponse(_MinimalResponse):
19+
SIGNATURE = str(random()).encode("utf-8")
20+
21+
def __init__(self, *args, **kwargs):
22+
super().__init__(*args, **kwargs)
23+
self._ = ( # Only an instance attribute will be stored in pickled instance
24+
self.__class__.SIGNATURE) # Useful for testing its presence in pickled instance
25+
26+
1427
class DummyHttpClient(object):
15-
def __init__(self, status_code=None, response_headers=None):
28+
def __init__(self, status_code=None, response_headers=None, response_text=None):
1629
self._status_code = status_code
1730
self._response_headers = response_headers
31+
self._response_text = response_text
1832

1933
def _build_dummy_response(self):
2034
return MinimalResponse(
2135
status_code=self._status_code,
2236
headers=self._response_headers,
23-
text=random(), # So that we'd know whether a new response is received
24-
)
37+
text=self._response_text if self._response_text is not None else str(
38+
random() # So that we'd know whether a new response is received
39+
),
40+
)
2541

2642
def post(self, url, params=None, data=None, headers=None, **kwargs):
2743
return self._build_dummy_response()
@@ -37,19 +53,54 @@ class CloseMethodCalled(Exception):
3753
pass
3854

3955

40-
class TestHttpDecoration(unittest.TestCase):
56+
class ThrottledHttpClientBaseTestCase(unittest.TestCase):
4157

42-
def test_throttled_http_client_should_not_alter_original_http_client(self):
58+
def assertCleanPickle(self, obj):
59+
self.assertTrue(bool(obj), "The object should not be empty")
60+
self.assertNotIn(
61+
MinimalResponse.SIGNATURE, pickle.dumps(obj),
62+
"A pickled object should not contain undesirable data")
63+
64+
def assertValidResponse(self, response):
65+
self.assertIsInstance(response, NormalizedResponse)
66+
self.assertCleanPickle(response)
67+
68+
def test_pickled_minimal_response_should_contain_signature(self):
69+
self.assertIn(MinimalResponse.SIGNATURE, pickle.dumps(MinimalResponse(
70+
status_code=200, headers={}, text="foo")))
71+
72+
def test_throttled_http_client_base_response_should_not_contain_signature(self):
73+
http_client = ThrottledHttpClientBase(DummyHttpClient(status_code=200))
74+
response = http_client.post("https://example.com")
75+
self.assertValidResponse(response)
76+
77+
def assertNotAlteringOriginalHttpClient(self, ThrottledHttpClientClass):
4378
original_http_client = DummyHttpClient()
4479
original_get = original_http_client.get
4580
original_post = original_http_client.post
46-
throttled_http_client = ThrottledHttpClient(original_http_client)
81+
throttled_http_client = ThrottledHttpClientClass(original_http_client)
4782
goal = """The implementation should wrap original http_client
4883
and keep it intact, instead of monkey-patching it"""
4984
self.assertNotEqual(throttled_http_client, original_http_client, goal)
5085
self.assertEqual(original_post, original_http_client.post)
5186
self.assertEqual(original_get, original_http_client.get)
5287

88+
def test_throttled_http_client_base_should_not_alter_original_http_client(self):
89+
self.assertNotAlteringOriginalHttpClient(ThrottledHttpClientBase)
90+
91+
def test_throttled_http_client_base_should_not_nest_http_client(self):
92+
original_http_client = DummyHttpClient()
93+
throttled_http_client = ThrottledHttpClientBase(original_http_client)
94+
self.assertIs(original_http_client, throttled_http_client.http_client)
95+
nested_throttled_http_client = ThrottledHttpClientBase(throttled_http_client)
96+
self.assertIs(original_http_client, nested_throttled_http_client.http_client)
97+
98+
99+
class ThrottledHttpClientTestCase(ThrottledHttpClientBaseTestCase):
100+
101+
def test_throttled_http_client_should_not_alter_original_http_client(self):
102+
self.assertNotAlteringOriginalHttpClient(ThrottledHttpClient)
103+
53104
def _test_RetryAfter_N_seconds_should_keep_entry_for_N_seconds(
54105
self, http_client, retry_after):
55106
http_cache = {}
@@ -112,15 +163,23 @@ def test_one_invalid_grant_should_block_a_similar_request(self):
112163
http_client = DummyHttpClient(
113164
status_code=400) # It covers invalid_grant and interaction_required
114165
http_client = ThrottledHttpClient(http_client, http_cache=http_cache)
166+
115167
resp1 = http_client.post("https://example.com", data={"claims": "foo"})
116168
logger.debug(http_cache)
169+
self.assertValidResponse(resp1)
117170
resp1_again = http_client.post("https://example.com", data={"claims": "foo"})
171+
self.assertValidResponse(resp1_again)
118172
self.assertEqual(resp1.text, resp1_again.text, "Should return a cached response")
173+
119174
resp2 = http_client.post("https://example.com", data={"claims": "bar"})
175+
self.assertValidResponse(resp2)
120176
self.assertNotEqual(resp1.text, resp2.text, "Should return a new response")
121177
resp2_again = http_client.post("https://example.com", data={"claims": "bar"})
178+
self.assertValidResponse(resp2_again)
122179
self.assertEqual(resp2.text, resp2_again.text, "Should return a cached response")
123180

181+
self.assertCleanPickle(http_cache)
182+
124183
def test_one_foci_app_recovering_from_invalid_grant_should_also_unblock_another(self):
125184
"""
126185
Need not test multiple FOCI app's acquire_token_silent() here. By design,

tox.ini

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
[tox]
2+
env_list =
3+
py3
4+
minversion = 4.21.2
5+
6+
[testenv]
7+
description = run the tests with pytest
8+
package = wheel
9+
wheel_build_env = .pkg
10+
passenv =
11+
# This allows tox environment on a DevBox to trigger host browser
12+
DISPLAY
13+
deps =
14+
pytest>=6
15+
-r requirements.txt
16+
commands =
17+
pip list
18+
{posargs:pytest --color=yes}
19+
20+
[testenv:azcli]
21+
deps =
22+
azure-cli
23+
commands_pre =
24+
# It will unfortunately be run every time but luckily subsequent runs are fast.
25+
pip install -e .
26+
commands =
27+
pip list
28+
{posargs:az --version}

0 commit comments

Comments
 (0)