3
3
4
4
from .individual_cache import _IndividualCache as IndividualCache
5
5
from .individual_cache import _ExpiringMapping as ExpiringMapping
6
+ from .oauth2cli .http import Response
7
+ from .exceptions import MsalServiceError
6
8
7
9
8
10
# https://datatracker.ietf.org/doc/html/rfc8628#section-3.4
9
11
DEVICE_AUTH_GRANT = "urn:ietf:params:oauth:grant-type:device_code"
10
12
11
13
12
14
class RetryAfterParser (object ):
15
+ FIELD_NAME_LOWER = "Retry-After" .lower ()
13
16
def __init__ (self , default_value = None ):
14
17
self ._default_value = 5 if default_value is None else default_value
15
18
@@ -20,9 +23,9 @@ def parse(self, *, result, **ignored):
20
23
# Historically, MSAL's HttpResponse does not always have headers
21
24
response , "headers" , {}).items ()}
22
25
if not (response .status_code == 429 or response .status_code >= 500
23
- or "retry-after" in lowercase_headers ):
26
+ or self . FIELD_NAME_LOWER in lowercase_headers ):
24
27
return 0 # Quick exit
25
- retry_after = lowercase_headers .get ("retry-after" , self ._default_value )
28
+ retry_after = lowercase_headers .get (self . FIELD_NAME_LOWER , self ._default_value )
26
29
try :
27
30
# AAD's retry_after uses integer format only
28
31
# https://stackoverflow.microsoft.com/questions/264931/264932
@@ -37,27 +40,52 @@ def _extract_data(kwargs, key, default=None):
37
40
return data .get (key ) if isinstance (data , dict ) else default
38
41
39
42
43
+ class NormalizedResponse (Response ):
44
+ """A http response with the shape defined in Response,
45
+ but contains only the data we will store in cache.
46
+ """
47
+ def __init__ (self , raw_response ):
48
+ super ().__init__ ()
49
+ self .status_code = raw_response .status_code
50
+ self .text = raw_response .text
51
+ self .headers = { # Only keep the headers which ThrottledHttpClient cares about
52
+ k : v for k , v in raw_response .headers .items ()
53
+ if k .lower () == RetryAfterParser .FIELD_NAME_LOWER
54
+ }
55
+
56
+ ## Note: Don't use the following line,
57
+ ## because when being pickled, it will indirectly pickle the whole raw_response
58
+ # self.raise_for_status = raw_response.raise_for_status
59
+ def raise_for_status (self ):
60
+ if self .status_code >= 400 :
61
+ raise MsalServiceError ("HTTP Error: {}" .format (self .status_code ))
62
+
63
+
40
64
class ThrottledHttpClientBase (object ):
41
65
"""Throttle the given http_client by storing and retrieving data from cache.
42
66
43
- This wrapper exists so that our patching post() and get() would prevent
44
- re-patching side effect when/if same http_client being reused.
67
+ This base exists so that:
68
+ 1. These base post() and get() will return a NormalizedResponse
69
+ 2. The base __init__() will NOT re-throttle even if caller accidentally nested ThrottledHttpClient.
45
70
46
- The subclass should implement post() and/or get()
71
+ Subclasses shall only need to dynamically decorate their post() and get() methods
72
+ in their __init__() method.
47
73
"""
48
74
def __init__ (self , http_client , * , http_cache = None ):
49
- self .http_client = http_client
75
+ self .http_client = http_client .http_client if isinstance (
76
+ # If it is already a ThrottledHttpClientBase, we use its raw (unthrottled) http client
77
+ http_client , ThrottledHttpClientBase ) else http_client
50
78
self ._expiring_mapping = ExpiringMapping ( # It will automatically clean up
51
79
mapping = http_cache if http_cache is not None else {},
52
80
capacity = 1024 , # To prevent cache blowing up especially for CCA
53
81
lock = Lock (), # TODO: This should ideally also allow customization
54
82
)
55
83
56
84
def post (self , * args , ** kwargs ):
57
- return self .http_client .post (* args , ** kwargs )
85
+ return NormalizedResponse ( self .http_client .post (* args , ** kwargs ) )
58
86
59
87
def get (self , * args , ** kwargs ):
60
- return self .http_client .get (* args , ** kwargs )
88
+ return NormalizedResponse ( self .http_client .get (* args , ** kwargs ) )
61
89
62
90
def close (self ):
63
91
return self .http_client .close ()
@@ -68,12 +96,11 @@ def _hash(raw):
68
96
69
97
70
98
class ThrottledHttpClient (ThrottledHttpClientBase ):
71
- def __init__ (self , http_client , * , default_throttle_time = None , ** kwargs ):
72
- super (ThrottledHttpClient , self ).__init__ (http_client , ** kwargs )
73
-
74
- _post = http_client .post # We'll patch _post, and keep original post() intact
75
-
76
- _post = IndividualCache (
99
+ """A throttled http client that is used by MSAL's non-managed identity clients."""
100
+ def __init__ (self , * args , default_throttle_time = None , ** kwargs ):
101
+ """Decorate self.post() and self.get() dynamically"""
102
+ super (ThrottledHttpClient , self ).__init__ (* args , ** kwargs )
103
+ self .post = IndividualCache (
77
104
# Internal specs requires throttling on at least token endpoint,
78
105
# here we have a generic patch for POST on all endpoints.
79
106
mapping = self ._expiring_mapping ,
@@ -91,9 +118,9 @@ def __init__(self, http_client, *, default_throttle_time=None, **kwargs):
91
118
_extract_data (kwargs , "username" )))), # "account" of ROPC
92
119
),
93
120
expires_in = RetryAfterParser (default_throttle_time or 5 ).parse ,
94
- )(_post )
121
+ )(self . post )
95
122
96
- _post = IndividualCache ( # It covers the "UI required cache"
123
+ self . post = IndividualCache ( # It covers the "UI required cache"
97
124
mapping = self ._expiring_mapping ,
98
125
key_maker = lambda func , args , kwargs : "POST {} hash={} 400" .format (
99
126
args [0 ], # It is the url, typically containing authority and tenant
@@ -125,12 +152,10 @@ def __init__(self, http_client, *, default_throttle_time=None, **kwargs):
125
152
isinstance (kwargs .get ("data" ), dict )
126
153
and kwargs ["data" ].get ("grant_type" ) == DEVICE_AUTH_GRANT
127
154
)
128
- and "retry-after" not in set ( # Leave it to the Retry-After decorator
155
+ and RetryAfterParser . FIELD_NAME_LOWER not in set ( # Otherwise leave it to the Retry-After decorator
129
156
h .lower () for h in getattr (result , "headers" , {}).keys ())
130
157
else 0 ,
131
- )(_post )
132
-
133
- self .post = _post
158
+ )(self .post )
134
159
135
160
self .get = IndividualCache ( # Typically those discovery GETs
136
161
mapping = self ._expiring_mapping ,
@@ -140,9 +165,4 @@ def __init__(self, http_client, *, default_throttle_time=None, **kwargs):
140
165
),
141
166
expires_in = lambda result = None , ** ignored :
142
167
3600 * 24 if 200 <= result .status_code < 300 else 0 ,
143
- )(http_client .get )
144
-
145
- # The following 2 methods have been defined dynamically by __init__()
146
- #def post(self, *args, **kwargs): pass
147
- #def get(self, *args, **kwargs): pass
148
-
168
+ )(self .get )
0 commit comments