|
| 1 | +from threading import Lock |
| 2 | +from hashlib import sha256 |
| 3 | + |
| 4 | +from .individual_cache import _IndividualCache as IndividualCache |
| 5 | +from .individual_cache import _ExpiringMapping as ExpiringMapping |
| 6 | + |
| 7 | + |
| 8 | +# https://datatracker.ietf.org/doc/html/rfc8628#section-3.4 |
| 9 | +DEVICE_AUTH_GRANT = "urn:ietf:params:oauth:grant-type:device_code" |
| 10 | + |
| 11 | + |
def _hash(raw):
    """Return the hexadecimal SHA-256 digest of ``repr(raw)``.

    The ``repr()`` text is encoded as UTF-8 before hashing, so any object
    with a ``repr()`` can be turned into a fixed-length cache key fragment.
    """
    serialized = repr(raw).encode("utf-8")
    digest = sha256(serialized)
    return digest.hexdigest()
| 14 | + |
| 15 | + |
def _handle_http_429_5xx_retry_after(result=None, **ignored):
    """Return how many seconds a response should be cached for throttling.

    :param result:
        The HTTP response. It must expose ``status_code``,
        and may expose a dict-like ``headers``.
    :param ignored:
        Any other keyword arguments are accepted and ignored.
    :return:
        0 when the response is neither HTTP 429 nor 5xx and carries no
        Retry-After header. Otherwise the Retry-After value in seconds,
        capped at 3600; when that header is absent, or is not a
        non-negative integer, a default of 600 is used.
    """
    assert result is not None, """
    The signature defines it with a default value None,
    only because the its shape is already decided by the
    IndividualCache's.__call__().
    In actual code path, the result parameter here won't be None.
    """
    response = result
    lowercase_headers = {  # MSAL's HttpClient may not have headers
        k.lower(): v
        for k, v in (getattr(response, "headers", None) or {}).items()}
    throttled = (
        response.status_code == 429 or response.status_code >= 500
        or "retry-after" in lowercase_headers)
    if not throttled:
        return 0
    default = 600
    try:
        # AAD's retry_after uses integer format only
        # https://stackoverflow.microsoft.com/questions/264931/264932
        retry_after = int(lowercase_headers.get("retry-after", default))
    except (TypeError, ValueError):  # e.g. an HTTP-date, or a None value
        retry_after = default
    if retry_after < 0:
        # Retry-After's delay-seconds is defined as a non-negative integer,
        # so a negative value is handled like any other malformed value
        # rather than being returned as a negative lifetime.
        retry_after = default
    return min(3600, retry_after)
| 37 | + |
| 38 | + |
def _extract_data(kwargs, key, default=None):
    """Return the ``key`` field of the request body found in ``kwargs["data"]``.

    :param kwargs: The keyword arguments of an http call; may lack "data".
    :param key: The name of the field to look up inside the data.
    :param default:
        The value to return when the data is not a dict,
        or when the data is a dict which does not contain ``key``.
    :return: The field's value, or ``default``.
    """
    data = kwargs.get("data", {})  # data is usually a dict, but occasionally a string
    if not isinstance(data, dict):
        return default
    # Honor `default` for dict payloads too, so that callers can chain
    # several lookups as fallbacks of each other.
    return data.get(key, default)
| 42 | + |
| 43 | + |
class ThrottledHttpClient(object):
    """An http_client wrapper whose post() and get() go through a shared cache.

    Both methods are created per instance by __init__().
    """

    def __init__(self, http_client, http_cache):
        """Throttle the given http_client by storing and retrieving data from cache.

        This wrapper exists so that patching post() and get() here would not
        cause a re-patching side effect when/if the same http_client is reused.

        :param http_client:
            The underlying client. Its post() and get() are wrapped;
            the http_client object itself is left intact.
        :param http_cache:
            A dict-like object used as the storage of the cache.
            When it is None, a new in-memory dict is used instead.
        """
        expiring_mapping = ExpiringMapping(  # It will automatically clean up
            mapping=http_cache if http_cache is not None else {},
            capacity=1024,  # To prevent cache blowing up especially for CCA
            lock=Lock(),  # TODO: This should ideally also allow customization
            )

        def throttling_key(func, args, kwargs):
            """Build the cache key for the 429/5xx/Retry-After layer."""
            # The followings are all approximations of the "account" concept
            # to support per-account throttling.
            # TODO: We may want to disable it for confidential client, though
            account = _extract_data(
                kwargs, "refresh_token",  # "account" during refresh
                _extract_data(
                    kwargs, "code",  # "account" of auth code grant
                    _extract_data(kwargs, "username")))  # "account" of ROPC
            return "POST {} client_id={} scope={} hash={} 429/5xx/Retry-After".format(
                args[0],  # It is the url, typically containing authority and tenant
                _extract_data(kwargs, "client_id"),  # Per internal specs
                _extract_data(kwargs, "scope"),  # Per internal specs
                _hash(account),
                )

        def ui_required_key(func, args, kwargs):
            """Build the cache key for the HTTP 400 ("UI required") layer."""
            # Here we use literally all parameters, even those short-lived
            # parameters containing timestamps (WS-Trust or POP assertion),
            # because they will automatically be cleaned up by ExpiringMapping.
            #
            # Furthermore, there is no need to implement
            # "interactive requests would reset the cache",
            # because acquire_token_silent() would be automatically unblocked
            # due to token cache layer operates on top of http cache layer.
            #
            # And, acquire_token_silent(..., force_refresh=True) will NOT
            # bypass http cache, because there is no real gain from that.
            # We won't bother implement it, nor do we want to encourage
            # acquire_token_silent(..., force_refresh=True) pattern.
            return "POST {} hash={} 400".format(
                args[0],  # It is the url, typically containing authority and tenant
                _hash(str(kwargs.get("params")) + str(kwargs.get("data"))),
                )

        def ui_required_expires_in(result=None, data=None, kwargs=None, **ignored):
            """Return 60 for a cacheable HTTP 400 response, otherwise 0."""
            # Here we choose to cache exact HTTP 400 errors only (rather than 4xx)
            # because they are the ones defined in OAuth2
            # (https://datatracker.ietf.org/doc/html/rfc6749#section-5.2)
            # Other 4xx errors might have different requirements e.g.
            # "407 Proxy auth required" would need a key including http headers.
            if result.status_code != 400:
                return 0
            if data is None and isinstance(kwargs, dict):
                # NOTE(review): this assumes the cache decorator forwards the
                # wrapped call's keyword arguments as `kwargs` (the key makers
                # above receive them) - confirm against
                # IndividualCache.__call__(). Without this fallback, the
                # Device Flow exclusion below relies solely on a `data` keyword.
                data = kwargs.get("data")
            if isinstance(data, dict) and data.get("grant_type") == DEVICE_AUTH_GRANT:
                return 0  # Exclude Device Flow cause its retry is expected and regulated
            if any(h.lower() == "retry-after"
                    for h in getattr(result, "headers", {}).keys()):
                return 0  # Leave it to the Retry-After decorator
            return 60

        def discovery_key(func, args, kwargs):
            """Build the cache key for the GET 2xx layer."""
            return "GET {} hash={} 2xx".format(
                args[0],  # It is the url, sometimes containing inline params
                _hash(kwargs.get("params", "")),
                )

        def discovery_expires_in(result=None, **ignored):
            """Return one day (in seconds) for a 2xx response, otherwise 0."""
            return 3600*24 if 200 <= result.status_code < 300 else 0

        # We patch a local _post, and keep http_client's original post() intact
        _post = IndividualCache(
            # Internal specs requires throttling on at least token endpoint,
            # here we have a generic patch for POST on all endpoints.
            mapping=expiring_mapping,
            key_maker=throttling_key,
            expires_in=_handle_http_429_5xx_retry_after,
            )(http_client.post)

        _post = IndividualCache(  # It covers the "UI required cache"
            mapping=expiring_mapping,
            key_maker=ui_required_key,
            expires_in=ui_required_expires_in,
            )(_post)

        self.post = _post

        self.get = IndividualCache(  # Typically those discovery GETs
            mapping=expiring_mapping,
            key_maker=discovery_key,
            expires_in=discovery_expires_in,
            )(http_client.get)

    # The following 2 methods have been defined dynamically by __init__()
    #def post(self, *args, **kwargs): pass
    #def get(self, *args, **kwargs): pass
| 130 | + |
0 commit comments