Skip to content

Commit 87e8644

Browse files
committed
Decorate the http_client for http_cache behavior
1 parent 8f406c8 commit 87e8644

File tree

2 files changed

+269
-0
lines changed

2 files changed

+269
-0
lines changed

msal/http_decorate.py

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
from threading import Lock
2+
from hashlib import sha256
3+
4+
from .individual_cache import _IndividualCache as IndividualCache
5+
from .individual_cache import _ExpiringMapping as ExpiringMapping
6+
7+
8+
# https://datatracker.ietf.org/doc/html/rfc8628#section-3.4
9+
DEVICE_AUTH_GRANT = "urn:ietf:params:oauth:grant-type:device_code"
10+
11+
12+
def _hash(raw):
13+
return sha256(repr(raw).encode("utf-8")).hexdigest()
14+
15+
16+
def _handle_http_429_5xx_retry_after(result=None, **ignored):
17+
assert result is not None, """
18+
The signature defines it with a default value None,
19+
only because the its shape is already decided by the
20+
IndividualCache's.__call__().
21+
In actual code path, the result parameter here won't be None.
22+
"""
23+
response = result
24+
lowercase_headers = { # MSAL's HttpClient may not have headers
25+
k.lower(): v for k, v in getattr(response, "headers", {}).items()}
26+
default = 600
27+
try:
28+
# AAD's retry_after uses integer format only
29+
# https://stackoverflow.microsoft.com/questions/264931/264932
30+
retry_after = int(lowercase_headers.get("retry-after", default))
31+
except ValueError:
32+
retry_after = default
33+
return min(3600, retry_after) if (
34+
response.status_code == 429 or response.status_code >= 500
35+
or "retry-after" in lowercase_headers
36+
) else 0
37+
38+
39+
def _extract_data(kwargs, key, default=None):
40+
data = kwargs.get("data", {}) # data is usually a dict, but occasionally a string
41+
return data.get(key) if isinstance(data, dict) else default
42+
43+
44+
def _decorate(http_client, http_cache):
45+
"""Throttle the given http_client by storing and retrieving data from cache"""
46+
expiring_mapping = ExpiringMapping( # It will automatically clean up
47+
mapping=http_cache if http_cache is not None else {},
48+
capacity=1024, # To prevent cache blowing up especially for CCA
49+
lock=Lock(), # TODO: This should ideally also allow customization
50+
)
51+
52+
http_client.post = IndividualCache(
53+
# Internal specs requires throttling on at least token endpoint,
54+
# here we have a generic patch for POST on all endpoints.
55+
mapping=expiring_mapping,
56+
key_maker=lambda func, args, kwargs:
57+
"POST {} client_id={} scope={} hash={} 429/5xx/Retry-After".format(
58+
args[0], # It is the url, typically containing authority and tenant
59+
_extract_data(kwargs, "client_id"), # Per internal specs
60+
_extract_data(kwargs, "scope"), # Per internal specs
61+
_hash(
62+
# The followings are all approximations of the "account" concept
63+
# to support per-account throttling.
64+
# TODO: We may want to disable it for confidential client, though
65+
_extract_data(kwargs, "refresh_token", # "account" during refresh
66+
_extract_data(kwargs, "code", # "account" of auth code grant
67+
_extract_data(kwargs, "username")))), # "account" of ROPC
68+
),
69+
expires_in=_handle_http_429_5xx_retry_after,
70+
)(http_client.post)
71+
72+
http_client.post = IndividualCache( # It covers the "UI required cache"
73+
mapping=expiring_mapping,
74+
key_maker=lambda func, args, kwargs: "POST {} hash={} 400".format(
75+
args[0], # It is the url, typically containing authority and tenant
76+
_hash(
77+
# Here we use literally all parameters, even those short-lived
78+
# parameters containing timestamps (WS-Trust or POP assertion),
79+
# because they will automatically be cleaned up by ExpiringMapping.
80+
#
81+
# Furthermore, there is no need to implement
82+
# "interactive requests would reset the cache",
83+
# because acquire_token_silent()'s would be automatically unblocked
84+
# due to token cache layer operates on top of http cache layer.
85+
#
86+
# And, acquire_token_silent(..., force_refresh=True) will NOT
87+
# bypass http cache, because there is no real gain from that.
88+
# We won't bother implement it, nor do we want to encourage
89+
# acquire_token_silent(..., force_refresh=True) pattern.
90+
str(kwargs.get("params")) + str(kwargs.get("data"))),
91+
),
92+
expires_in=lambda result=None, data=None, **ignored:
93+
60
94+
if result.status_code == 400
95+
# Here we choose to cache exact HTTP 400 errors only (rather than 4xx)
96+
# because they are the ones defined in OAuth2
97+
# (https://datatracker.ietf.org/doc/html/rfc6749#section-5.2)
98+
# Other 4xx errors might have different requirements e.g.
99+
# "407 Proxy auth required" would need a key including http headers.
100+
and not( # Exclude Device Flow cause its retry is expected and regulated
101+
isinstance(data, dict) and data.get("grant_type") == DEVICE_AUTH_GRANT
102+
)
103+
and "retry-after" not in set( # Leave it to the Retry-After decorator
104+
h.lower() for h in getattr(result, "headers", {}).keys())
105+
else 0,
106+
)(http_client.post)
107+
108+
http_client.get = IndividualCache( # Typically those discovery GETs
109+
mapping=expiring_mapping,
110+
key_maker=lambda func, args, kwargs: "GET {} hash={} 2xx".format(
111+
args[0], # It is the url, sometimes containing inline params
112+
_hash(kwargs.get("params", "")),
113+
),
114+
expires_in=lambda result=None, **ignored:
115+
3600*24 if 200 <= result.status_code < 300 else 0,
116+
)(http_client.get)
117+

tests/test_http_decorate.py

Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
# Test cases for https://identitydivision.visualstudio.com/devex/_git/AuthLibrariesApiReview?version=GBdev&path=%2FService%20protection%2FIntial%20set%20of%20protection%20measures.md&_a=preview&anchor=common-test-cases
2+
from time import sleep
3+
from random import random
4+
import logging
5+
from msal.http_decorate import _decorate
6+
from tests import unittest
7+
from tests.http_client import MinimalResponse
8+
9+
10+
logger = logging.getLogger(__name__)
11+
logging.basicConfig(level=logging.DEBUG)
12+
13+
14+
class DummyHttpResponse(MinimalResponse):
15+
def __init__(self, headers=None, **kwargs):
16+
self.headers = {} if headers is None else headers
17+
super(DummyHttpResponse, self).__init__(**kwargs)
18+
19+
20+
class DummyHttpClient(object):
21+
def __init__(self, status_code=None, response_headers=None):
22+
self._status_code = status_code
23+
self._response_headers = response_headers
24+
25+
def _build_dummy_response(self):
26+
return DummyHttpResponse(
27+
status_code=self._status_code,
28+
headers=self._response_headers,
29+
text=random(), # So that we'd know whether a new response is received
30+
)
31+
32+
def post(self, url, params=None, data=None, headers=None, **kwargs):
33+
return self._build_dummy_response()
34+
35+
def get(self, url, params=None, headers=None, **kwargs):
36+
return self._build_dummy_response()
37+
38+
39+
class TestHttpDecoration(unittest.TestCase):
40+
def _test_RetryAfter_N_seconds_should_keep_entry_for_N_seconds(
41+
self, http_client, retry_after):
42+
http_cache = {}
43+
_decorate(http_client, http_cache)
44+
resp1 = http_client.post("https://example.com") # We implemented POST only
45+
resp2 = http_client.post("https://example.com") # We implemented POST only
46+
logger.debug(http_cache)
47+
self.assertEqual(resp1.text, resp2.text, "Should return a cached response")
48+
sleep(retry_after + 1)
49+
resp3 = http_client.post("https://example.com") # We implemented POST only
50+
self.assertNotEqual(resp1.text, resp3.text, "Should return a new response")
51+
52+
def test_429_with_RetryAfter_N_seconds_should_keep_entry_for_N_seconds(self):
53+
retry_after = 1
54+
self._test_RetryAfter_N_seconds_should_keep_entry_for_N_seconds(
55+
DummyHttpClient(
56+
status_code=429, response_headers={"Retry-After": retry_after}),
57+
retry_after)
58+
59+
def test_5xx_with_RetryAfter_N_seconds_should_keep_entry_for_N_seconds(self):
60+
retry_after = 1
61+
self._test_RetryAfter_N_seconds_should_keep_entry_for_N_seconds(
62+
DummyHttpClient(
63+
status_code=503, response_headers={"Retry-After": retry_after}),
64+
retry_after)
65+
66+
def test_400_with_RetryAfter_N_seconds_should_keep_entry_for_N_seconds(self):
67+
"""Retry-After is supposed to only shown in http 429/5xx,
68+
but we choose to support Retry-After for arbitrary http response."""
69+
retry_after = 1
70+
self._test_RetryAfter_N_seconds_should_keep_entry_for_N_seconds(
71+
DummyHttpClient(
72+
status_code=400, response_headers={"Retry-After": retry_after}),
73+
retry_after)
74+
75+
def test_one_RetryAfter_request_should_block_a_similar_request(self):
76+
http_cache = {}
77+
http_client = DummyHttpClient(
78+
status_code=429, response_headers={"Retry-After": 2})
79+
_decorate(http_client, http_cache)
80+
resp1 = http_client.post("https://example.com", data={
81+
"scope": "one", "claims": "bar", "grant_type": "authorization_code"})
82+
resp2 = http_client.post("https://example.com", data={
83+
"scope": "one", "claims": "foo", "grant_type": "password"})
84+
logger.debug(http_cache)
85+
self.assertEqual(resp1.text, resp2.text, "Should return a cached response")
86+
87+
def test_one_RetryAfter_request_should_not_block_a_different_request(self):
88+
http_cache = {}
89+
http_client = DummyHttpClient(
90+
status_code=429, response_headers={"Retry-After": 2})
91+
_decorate(http_client, http_cache)
92+
resp1 = http_client.post("https://example.com", data={"scope": "one"})
93+
resp2 = http_client.post("https://example.com", data={"scope": "two"})
94+
logger.debug(http_cache)
95+
self.assertNotEqual(resp1.text, resp2.text, "Should return a new response")
96+
97+
def test_one_invalid_grant_should_block_a_similar_request(self):
98+
http_cache = {}
99+
http_client = DummyHttpClient(
100+
status_code=400) # It covers invalid_grant and interaction_required
101+
_decorate(http_client, http_cache)
102+
resp1 = http_client.post("https://example.com", data={"claims": "foo"})
103+
logger.debug(http_cache)
104+
resp1_again = http_client.post("https://example.com", data={"claims": "foo"})
105+
self.assertEqual(resp1.text, resp1_again.text, "Should return a cached response")
106+
resp2 = http_client.post("https://example.com", data={"claims": "bar"})
107+
self.assertNotEqual(resp1.text, resp2.text, "Should return a new response")
108+
resp2_again = http_client.post("https://example.com", data={"claims": "bar"})
109+
self.assertEqual(resp2.text, resp2_again.text, "Should return a cached response")
110+
111+
def test_one_foci_app_recovering_from_invalid_grant_should_also_unblock_another(self):
112+
"""
113+
Need not test multiple FOCI app's acquire_token_silent() here. By design,
114+
one FOCI app's successful populating token cache would result in another
115+
FOCI app's acquire_token_silent() to hit a token without invoking http request.
116+
"""
117+
118+
def test_forcefresh_behavior(self):
119+
"""
120+
The implementation let token cache and http cache operate in different
121+
layers. They do not couple with each other.
122+
Therefore, acquire_token_silent(..., force_refresh=True)
123+
would bypass the token cache yet technically still hit the http cache.
124+
125+
But that is OK, cause the customer need no force_refresh in the first place.
126+
After a successful AT/RT acquisition, AT/RT will be in the token cache,
127+
and a normal acquire_token_silent(...) without force_refresh would just work.
128+
This was discussed in https://identitydivision.visualstudio.com/DevEx/_git/AuthLibrariesApiReview/pullrequest/3618?_a=files
129+
"""
130+
131+
def test_http_get_200_should_be_cached(self):
132+
http_cache = {}
133+
http_client = DummyHttpClient(
134+
status_code=200) # It covers UserRealm discovery and OIDC discovery
135+
_decorate(http_client, http_cache)
136+
resp1 = http_client.get("https://example.com")
137+
resp2 = http_client.get("https://example.com")
138+
logger.debug(http_cache)
139+
self.assertEqual(resp1.text, resp2.text, "Should return a cached response")
140+
141+
def test_device_flow_retry_should_not_be_cached(self):
142+
DEVICE_AUTH_GRANT = "urn:ietf:params:oauth:grant-type:device_code"
143+
http_cache = {}
144+
http_client = DummyHttpClient(status_code=400)
145+
_decorate(http_client, http_cache)
146+
resp1 = http_client.get(
147+
"https://example.com", data={"grant_type": DEVICE_AUTH_GRANT})
148+
resp2 = http_client.get(
149+
"https://example.com", data={"grant_type": DEVICE_AUTH_GRANT})
150+
logger.debug(http_cache)
151+
self.assertNotEqual(resp1.text, resp2.text, "Should return a new response")
152+

0 commit comments

Comments
 (0)