Skip to content

Commit 8e8113e

Browse files
committed
Use a short throttling threshold for MI (and CCA)
1 parent 5a9a262 commit 8e8113e

File tree

4 files changed

+52
-54
lines changed

4 files changed

+52
-54
lines changed

msal/application.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -537,7 +537,11 @@ def __init__(
537537
self.http_client.mount("https://", a)
538538
self.http_client = ThrottledHttpClient(
539539
self.http_client,
540-
{} if http_cache is None else http_cache, # Default to an in-memory dict
540+
http_cache=http_cache,
541+
default_throttle_time=60
542+
# The default value 60 was recommended mainly for PCA at the end of
543+
# https://identitydivision.visualstudio.com/devex/_git/AuthLibrariesApiReview?version=GBdev&path=%2FService%20protection%2FIntial%20set%20of%20protection%20measures.md&_a=preview
544+
if isinstance(self, PublicClientApplication) else 5,
541545
)
542546

543547
self.app_name = app_name

msal/managed_identity.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from typing import Union # Needed in Python 3.7 & 3.8
1313
from .token_cache import TokenCache
1414
from .individual_cache import _IndividualCache as IndividualCache
15-
from .throttled_http_client import ThrottledHttpClientBase, _parse_http_429_5xx_retry_after
15+
from .throttled_http_client import ThrottledHttpClientBase, RetryAfterParser
1616

1717

1818
logger = logging.getLogger(__name__)
@@ -109,18 +109,18 @@ def __init__(self, *, client_id=None, resource_id=None, object_id=None):
109109

110110

111111
class _ThrottledHttpClient(ThrottledHttpClientBase):
112-
def __init__(self, http_client, http_cache):
113-
super(_ThrottledHttpClient, self).__init__(http_client, http_cache)
112+
def __init__(self, http_client, **kwargs):
113+
super(_ThrottledHttpClient, self).__init__(http_client, **kwargs)
114114
self.get = IndividualCache( # All MIs (except Cloud Shell) use GETs
115115
mapping=self._expiring_mapping,
116-
key_maker=lambda func, args, kwargs: "POST {} hash={} 429/5xx/Retry-After".format(
116+
key_maker=lambda func, args, kwargs: "REQ {} hash={} 429/5xx/Retry-After".format(
117117
args[0], # It is the endpoint, typically a constant per MI type
118-
_hash(
118+
self._hash(
119119
# Managed Identity flavors have inconsistent parameters.
120120
# We simply choose to hash them all.
121121
str(kwargs.get("params")) + str(kwargs.get("data"))),
122122
),
123-
expires_in=_parse_http_429_5xx_retry_after,
123+
expires_in=RetryAfterParser(5).parse, # 5 seconds default for non-PCA
124124
)(http_client.get)
125125

126126

@@ -226,7 +226,7 @@ def __init__(
226226
# ( https://learn.microsoft.com/en-us/azure/service-fabric/how-to-managed-cluster-managed-identity-service-fabric-app-code#error-handling )
227227
http_client.http_client # Patch the raw (unpatched) http client
228228
if isinstance(http_client, ThrottledHttpClientBase) else http_client,
229-
{} if http_cache is None else http_cache, # Default to an in-memory dict
229+
http_cache=http_cache,
230230
)
231231
self._token_cache = token_cache or TokenCache()
232232

msal/throttled_http_client.py

Lines changed: 32 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -9,35 +9,27 @@
99
DEVICE_AUTH_GRANT = "urn:ietf:params:oauth:grant-type:device_code"
1010

1111

12-
def _hash(raw):
13-
return sha256(repr(raw).encode("utf-8")).hexdigest()
14-
15-
16-
def _parse_http_429_5xx_retry_after(result=None, **ignored):
17-
"""Return seconds to throttle"""
18-
assert result is not None, """
19-
The signature defines it with a default value None,
20-
only because the its shape is already decided by the
21-
IndividualCache's.__call__().
22-
In actual code path, the result parameter here won't be None.
23-
"""
24-
response = result
25-
lowercase_headers = {k.lower(): v for k, v in getattr(
26-
# Historically, MSAL's HttpResponse does not always have headers
27-
response, "headers", {}).items()}
28-
if not (response.status_code == 429 or response.status_code >= 500
29-
or "retry-after" in lowercase_headers):
30-
return 0 # Quick exit
31-
default = 60 # Recommended at the end of
32-
# https://identitydivision.visualstudio.com/devex/_git/AuthLibrariesApiReview?version=GBdev&path=%2FService%20protection%2FIntial%20set%20of%20protection%20measures.md&_a=preview
33-
retry_after = lowercase_headers.get("retry-after", default)
34-
try:
35-
# AAD's retry_after uses integer format only
36-
# https://stackoverflow.microsoft.com/questions/264931/264932
37-
delay_seconds = int(retry_after)
38-
except ValueError:
39-
delay_seconds = default
40-
return min(3600, delay_seconds)
12+
class RetryAfterParser(object):
13+
def __init__(self, default_value=None):
14+
self._default_value = 5 if default_value is None else default_value
15+
16+
def parse(self, *, result, **ignored):
17+
"""Return seconds to throttle"""
18+
response = result
19+
lowercase_headers = {k.lower(): v for k, v in getattr(
20+
# Historically, MSAL's HttpResponse does not always have headers
21+
response, "headers", {}).items()}
22+
if not (response.status_code == 429 or response.status_code >= 500
23+
or "retry-after" in lowercase_headers):
24+
return 0 # Quick exit
25+
retry_after = lowercase_headers.get("retry-after", self._default_value)
26+
try:
27+
# AAD's retry_after uses integer format only
28+
# https://stackoverflow.microsoft.com/questions/264931/264932
29+
delay_seconds = int(retry_after)
30+
except ValueError:
31+
delay_seconds = self._default_value
32+
return min(3600, delay_seconds)
4133

4234

4335
def _extract_data(kwargs, key, default=None):
@@ -53,7 +45,7 @@ class ThrottledHttpClientBase(object):
5345
5446
The subclass should implement post() and/or get()
5547
"""
56-
def __init__(self, http_client, http_cache):
48+
def __init__(self, http_client, *, http_cache=None):
5749
self.http_client = http_client
5850
self._expiring_mapping = ExpiringMapping( # It will automatically clean up
5951
mapping=http_cache if http_cache is not None else {},
@@ -70,10 +62,14 @@ def get(self, *args, **kwargs):
7062
def close(self):
7163
return self.http_client.close()
7264

65+
@staticmethod
66+
def _hash(raw):
67+
return sha256(repr(raw).encode("utf-8")).hexdigest()
68+
7369

7470
class ThrottledHttpClient(ThrottledHttpClientBase):
75-
def __init__(self, http_client, http_cache):
76-
super(ThrottledHttpClient, self).__init__(http_client, http_cache)
71+
def __init__(self, http_client, *, default_throttle_time=None, **kwargs):
72+
super(ThrottledHttpClient, self).__init__(http_client, **kwargs)
7773

7874
_post = http_client.post # We'll patch _post, and keep original post() intact
7975

@@ -86,22 +82,22 @@ def __init__(self, http_client, http_cache):
8682
args[0], # It is the url, typically containing authority and tenant
8783
_extract_data(kwargs, "client_id"), # Per internal specs
8884
_extract_data(kwargs, "scope"), # Per internal specs
89-
_hash(
85+
self._hash(
9086
# The followings are all approximations of the "account" concept
9187
# to support per-account throttling.
9288
# TODO: We may want to disable it for confidential client, though
9389
_extract_data(kwargs, "refresh_token", # "account" during refresh
9490
_extract_data(kwargs, "code", # "account" of auth code grant
9591
_extract_data(kwargs, "username")))), # "account" of ROPC
9692
),
97-
expires_in=_parse_http_429_5xx_retry_after,
93+
expires_in=RetryAfterParser(default_throttle_time or 5).parse,
9894
)(_post)
9995

10096
_post = IndividualCache( # It covers the "UI required cache"
10197
mapping=self._expiring_mapping,
10298
key_maker=lambda func, args, kwargs: "POST {} hash={} 400".format(
10399
args[0], # It is the url, typically containing authority and tenant
104-
_hash(
100+
self._hash(
105101
# Here we use literally all parameters, even those short-lived
106102
# parameters containing timestamps (WS-Trust or POP assertion),
107103
# because they will automatically be cleaned up by ExpiringMapping.
@@ -140,7 +136,7 @@ def __init__(self, http_client, http_cache):
140136
mapping=self._expiring_mapping,
141137
key_maker=lambda func, args, kwargs: "GET {} hash={} 2xx".format(
142138
args[0], # It is the url, sometimes containing inline params
143-
_hash(kwargs.get("params", "")),
139+
self._hash(kwargs.get("params", "")),
144140
),
145141
expires_in=lambda result=None, **ignored:
146142
3600*24 if 200 <= result.status_code < 300 else 0,

tests/test_throttled_http_client.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,10 @@ class CloseMethodCalled(Exception):
4040
class TestHttpDecoration(unittest.TestCase):
4141

4242
def test_throttled_http_client_should_not_alter_original_http_client(self):
43-
http_cache = {}
4443
original_http_client = DummyHttpClient()
4544
original_get = original_http_client.get
4645
original_post = original_http_client.post
47-
throttled_http_client = ThrottledHttpClient(original_http_client, http_cache)
46+
throttled_http_client = ThrottledHttpClient(original_http_client)
4847
goal = """The implementation should wrap original http_client
4948
and keep it intact, instead of monkey-patching it"""
5049
self.assertNotEqual(throttled_http_client, original_http_client, goal)
@@ -54,7 +53,7 @@ def test_throttled_http_client_should_not_alter_original_http_client(self):
5453
def _test_RetryAfter_N_seconds_should_keep_entry_for_N_seconds(
5554
self, http_client, retry_after):
5655
http_cache = {}
57-
http_client = ThrottledHttpClient(http_client, http_cache)
56+
http_client = ThrottledHttpClient(http_client, http_cache=http_cache)
5857
resp1 = http_client.post("https://example.com") # We implemented POST only
5958
resp2 = http_client.post("https://example.com") # We implemented POST only
6059
logger.debug(http_cache)
@@ -90,7 +89,7 @@ def test_one_RetryAfter_request_should_block_a_similar_request(self):
9089
http_cache = {}
9190
http_client = DummyHttpClient(
9291
status_code=429, response_headers={"Retry-After": 2})
93-
http_client = ThrottledHttpClient(http_client, http_cache)
92+
http_client = ThrottledHttpClient(http_client, http_cache=http_cache)
9493
resp1 = http_client.post("https://example.com", data={
9594
"scope": "one", "claims": "bar", "grant_type": "authorization_code"})
9695
resp2 = http_client.post("https://example.com", data={
@@ -102,7 +101,7 @@ def test_one_RetryAfter_request_should_not_block_a_different_request(self):
102101
http_cache = {}
103102
http_client = DummyHttpClient(
104103
status_code=429, response_headers={"Retry-After": 2})
105-
http_client = ThrottledHttpClient(http_client, http_cache)
104+
http_client = ThrottledHttpClient(http_client, http_cache=http_cache)
106105
resp1 = http_client.post("https://example.com", data={"scope": "one"})
107106
resp2 = http_client.post("https://example.com", data={"scope": "two"})
108107
logger.debug(http_cache)
@@ -112,7 +111,7 @@ def test_one_invalid_grant_should_block_a_similar_request(self):
112111
http_cache = {}
113112
http_client = DummyHttpClient(
114113
status_code=400) # It covers invalid_grant and interaction_required
115-
http_client = ThrottledHttpClient(http_client, http_cache)
114+
http_client = ThrottledHttpClient(http_client, http_cache=http_cache)
116115
resp1 = http_client.post("https://example.com", data={"claims": "foo"})
117116
logger.debug(http_cache)
118117
resp1_again = http_client.post("https://example.com", data={"claims": "foo"})
@@ -146,7 +145,7 @@ def test_http_get_200_should_be_cached(self):
146145
http_cache = {}
147146
http_client = DummyHttpClient(
148147
status_code=200) # It covers UserRealm discovery and OIDC discovery
149-
http_client = ThrottledHttpClient(http_client, http_cache)
148+
http_client = ThrottledHttpClient(http_client, http_cache=http_cache)
150149
resp1 = http_client.get("https://example.com?foo=bar")
151150
resp2 = http_client.get("https://example.com?foo=bar")
152151
logger.debug(http_cache)
@@ -156,7 +155,7 @@ def test_device_flow_retry_should_not_be_cached(self):
156155
DEVICE_AUTH_GRANT = "urn:ietf:params:oauth:grant-type:device_code"
157156
http_cache = {}
158157
http_client = DummyHttpClient(status_code=400)
159-
http_client = ThrottledHttpClient(http_client, http_cache)
158+
http_client = ThrottledHttpClient(http_client, http_cache=http_cache)
160159
resp1 = http_client.post(
161160
"https://example.com", data={"grant_type": DEVICE_AUTH_GRANT})
162161
resp2 = http_client.post(
@@ -165,9 +164,8 @@ def test_device_flow_retry_should_not_be_cached(self):
165164
self.assertNotEqual(resp1.text, resp2.text, "Should return a new response")
166165

167166
def test_throttled_http_client_should_provide_close(self):
168-
http_cache = {}
169167
http_client = DummyHttpClient(status_code=200)
170-
http_client = ThrottledHttpClient(http_client, http_cache)
168+
http_client = ThrottledHttpClient(http_client)
171169
with self.assertRaises(CloseMethodCalled):
172170
http_client.close()
173171

0 commit comments

Comments
 (0)