Skip to content

Commit e9a601a

Browse files
committed
Wrap http_client instead of decorate it
Rename to throttled_http_client.py
1 parent 20b65d9 commit e9a601a

File tree

3 files changed

+150
-124
lines changed

3 files changed

+150
-124
lines changed

msal/http_decorate.py

Lines changed: 0 additions & 117 deletions
This file was deleted.

msal/throttled_http_client.py

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
from threading import Lock
2+
from hashlib import sha256
3+
4+
from .individual_cache import _IndividualCache as IndividualCache
5+
from .individual_cache import _ExpiringMapping as ExpiringMapping
6+
7+
8+
# https://datatracker.ietf.org/doc/html/rfc8628#section-3.4
9+
DEVICE_AUTH_GRANT = "urn:ietf:params:oauth:grant-type:device_code"
10+
11+
12+
def _hash(raw):
13+
return sha256(repr(raw).encode("utf-8")).hexdigest()
14+
15+
16+
def _handle_http_429_5xx_retry_after(result=None, **ignored):
17+
assert result is not None, """
18+
The signature defines it with a default value None,
19+
only because the its shape is already decided by the
20+
IndividualCache's.__call__().
21+
In actual code path, the result parameter here won't be None.
22+
"""
23+
response = result
24+
lowercase_headers = { # MSAL's HttpClient may not have headers
25+
k.lower(): v for k, v in getattr(response, "headers", {}).items()}
26+
default = 600
27+
try:
28+
# AAD's retry_after uses integer format only
29+
# https://stackoverflow.microsoft.com/questions/264931/264932
30+
retry_after = int(lowercase_headers.get("retry-after", default))
31+
except ValueError:
32+
retry_after = default
33+
return min(3600, retry_after) if (
34+
response.status_code == 429 or response.status_code >= 500
35+
or "retry-after" in lowercase_headers
36+
) else 0
37+
38+
39+
def _extract_data(kwargs, key, default=None):
40+
data = kwargs.get("data", {}) # data is usually a dict, but occasionally a string
41+
return data.get(key) if isinstance(data, dict) else default
42+
43+
44+
class ThrottledHttpClient(object):
45+
def __init__(self, http_client, http_cache):
46+
"""Throttle the given http_client by storing and retrieving data from cache.
47+
48+
This wrapper exists so that our patching post() and get() would prevent
49+
re-patching side effect when/if same http_client being reused.
50+
"""
51+
expiring_mapping = ExpiringMapping( # It will automatically clean up
52+
mapping=http_cache if http_cache is not None else {},
53+
capacity=1024, # To prevent cache blowing up especially for CCA
54+
lock=Lock(), # TODO: This should ideally also allow customization
55+
)
56+
57+
_post = http_client.post # We'll patch _post, and keep original post() intact
58+
59+
_post = IndividualCache(
60+
# Internal specs requires throttling on at least token endpoint,
61+
# here we have a generic patch for POST on all endpoints.
62+
mapping=expiring_mapping,
63+
key_maker=lambda func, args, kwargs:
64+
"POST {} client_id={} scope={} hash={} 429/5xx/Retry-After".format(
65+
args[0], # It is the url, typically containing authority and tenant
66+
_extract_data(kwargs, "client_id"), # Per internal specs
67+
_extract_data(kwargs, "scope"), # Per internal specs
68+
_hash(
69+
# The followings are all approximations of the "account" concept
70+
# to support per-account throttling.
71+
# TODO: We may want to disable it for confidential client, though
72+
_extract_data(kwargs, "refresh_token", # "account" during refresh
73+
_extract_data(kwargs, "code", # "account" of auth code grant
74+
_extract_data(kwargs, "username")))), # "account" of ROPC
75+
),
76+
expires_in=_handle_http_429_5xx_retry_after,
77+
)(_post)
78+
79+
_post = IndividualCache( # It covers the "UI required cache"
80+
mapping=expiring_mapping,
81+
key_maker=lambda func, args, kwargs: "POST {} hash={} 400".format(
82+
args[0], # It is the url, typically containing authority and tenant
83+
_hash(
84+
# Here we use literally all parameters, even those short-lived
85+
# parameters containing timestamps (WS-Trust or POP assertion),
86+
# because they will automatically be cleaned up by ExpiringMapping.
87+
#
88+
# Furthermore, there is no need to implement
89+
# "interactive requests would reset the cache",
90+
# because acquire_token_silent()'s would be automatically unblocked
91+
# due to token cache layer operates on top of http cache layer.
92+
#
93+
# And, acquire_token_silent(..., force_refresh=True) will NOT
94+
# bypass http cache, because there is no real gain from that.
95+
# We won't bother implement it, nor do we want to encourage
96+
# acquire_token_silent(..., force_refresh=True) pattern.
97+
str(kwargs.get("params")) + str(kwargs.get("data"))),
98+
),
99+
expires_in=lambda result=None, data=None, **ignored:
100+
60
101+
if result.status_code == 400
102+
# Here we choose to cache exact HTTP 400 errors only (rather than 4xx)
103+
# because they are the ones defined in OAuth2
104+
# (https://datatracker.ietf.org/doc/html/rfc6749#section-5.2)
105+
# Other 4xx errors might have different requirements e.g.
106+
# "407 Proxy auth required" would need a key including http headers.
107+
and not( # Exclude Device Flow cause its retry is expected and regulated
108+
isinstance(data, dict) and data.get("grant_type") == DEVICE_AUTH_GRANT
109+
)
110+
and "retry-after" not in set( # Leave it to the Retry-After decorator
111+
h.lower() for h in getattr(result, "headers", {}).keys())
112+
else 0,
113+
)(_post)
114+
115+
self.post = _post
116+
117+
self.get = IndividualCache( # Typically those discovery GETs
118+
mapping=expiring_mapping,
119+
key_maker=lambda func, args, kwargs: "GET {} hash={} 2xx".format(
120+
args[0], # It is the url, sometimes containing inline params
121+
_hash(kwargs.get("params", "")),
122+
),
123+
expires_in=lambda result=None, **ignored:
124+
3600*24 if 200 <= result.status_code < 300 else 0,
125+
)(http_client.get)
126+
127+
# The following 2 methods have been defined dynamically by __init__()
128+
#def post(self, *args, **kwargs): pass
129+
#def get(self, *args, **kwargs): pass
130+

tests/test_http_decorate.py renamed to tests/test_throttled_http_client.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from time import sleep
33
from random import random
44
import logging
5-
from msal.http_decorate import _decorate
5+
from msal.throttled_http_client import ThrottledHttpClient
66
from tests import unittest
77
from tests.http_client import MinimalResponse
88

@@ -37,10 +37,23 @@ def get(self, url, params=None, headers=None, **kwargs):
3737

3838

3939
class TestHttpDecoration(unittest.TestCase):
40+
41+
def test_throttled_http_client_should_not_alter_original_http_client(self):
42+
http_cache = {}
43+
original_http_client = DummyHttpClient()
44+
original_get = original_http_client.get
45+
original_post = original_http_client.post
46+
throttled_http_client = ThrottledHttpClient(original_http_client, http_cache)
47+
goal = """The implementation should wrap original http_client
48+
and keep it intact, instead of monkey-patching it"""
49+
self.assertNotEqual(throttled_http_client, original_http_client, goal)
50+
self.assertEqual(original_post, original_http_client.post)
51+
self.assertEqual(original_get, original_http_client.get)
52+
4053
def _test_RetryAfter_N_seconds_should_keep_entry_for_N_seconds(
4154
self, http_client, retry_after):
4255
http_cache = {}
43-
_decorate(http_client, http_cache)
56+
http_client = ThrottledHttpClient(http_client, http_cache)
4457
resp1 = http_client.post("https://example.com") # We implemented POST only
4558
resp2 = http_client.post("https://example.com") # We implemented POST only
4659
logger.debug(http_cache)
@@ -76,7 +89,7 @@ def test_one_RetryAfter_request_should_block_a_similar_request(self):
7689
http_cache = {}
7790
http_client = DummyHttpClient(
7891
status_code=429, response_headers={"Retry-After": 2})
79-
_decorate(http_client, http_cache)
92+
http_client = ThrottledHttpClient(http_client, http_cache)
8093
resp1 = http_client.post("https://example.com", data={
8194
"scope": "one", "claims": "bar", "grant_type": "authorization_code"})
8295
resp2 = http_client.post("https://example.com", data={
@@ -88,7 +101,7 @@ def test_one_RetryAfter_request_should_not_block_a_different_request(self):
88101
http_cache = {}
89102
http_client = DummyHttpClient(
90103
status_code=429, response_headers={"Retry-After": 2})
91-
_decorate(http_client, http_cache)
104+
http_client = ThrottledHttpClient(http_client, http_cache)
92105
resp1 = http_client.post("https://example.com", data={"scope": "one"})
93106
resp2 = http_client.post("https://example.com", data={"scope": "two"})
94107
logger.debug(http_cache)
@@ -98,7 +111,7 @@ def test_one_invalid_grant_should_block_a_similar_request(self):
98111
http_cache = {}
99112
http_client = DummyHttpClient(
100113
status_code=400) # It covers invalid_grant and interaction_required
101-
_decorate(http_client, http_cache)
114+
http_client = ThrottledHttpClient(http_client, http_cache)
102115
resp1 = http_client.post("https://example.com", data={"claims": "foo"})
103116
logger.debug(http_cache)
104117
resp1_again = http_client.post("https://example.com", data={"claims": "foo"})
@@ -132,7 +145,7 @@ def test_http_get_200_should_be_cached(self):
132145
http_cache = {}
133146
http_client = DummyHttpClient(
134147
status_code=200) # It covers UserRealm discovery and OIDC discovery
135-
_decorate(http_client, http_cache)
148+
http_client = ThrottledHttpClient(http_client, http_cache)
136149
resp1 = http_client.get("https://example.com")
137150
resp2 = http_client.get("https://example.com")
138151
logger.debug(http_cache)
@@ -142,7 +155,7 @@ def test_device_flow_retry_should_not_be_cached(self):
142155
DEVICE_AUTH_GRANT = "urn:ietf:params:oauth:grant-type:device_code"
143156
http_cache = {}
144157
http_client = DummyHttpClient(status_code=400)
145-
_decorate(http_client, http_cache)
158+
http_client = ThrottledHttpClient(http_client, http_cache)
146159
resp1 = http_client.get(
147160
"https://example.com", data={"grant_type": DEVICE_AUTH_GRANT})
148161
resp2 = http_client.get(

0 commit comments

Comments
 (0)