9
9
DEVICE_AUTH_GRANT = "urn:ietf:params:oauth:grant-type:device_code"
10
10
11
11
12
- def _hash (raw ):
13
- return sha256 (repr (raw ).encode ("utf-8" )).hexdigest ()
14
-
15
-
16
- def _parse_http_429_5xx_retry_after (result = None , ** ignored ):
17
- """Return seconds to throttle"""
18
- assert result is not None , """
19
- The signature defines it with a default value None,
20
- only because the its shape is already decided by the
21
- IndividualCache's.__call__().
22
- In actual code path, the result parameter here won't be None.
23
- """
24
- response = result
25
- lowercase_headers = {k .lower (): v for k , v in getattr (
26
- # Historically, MSAL's HttpResponse does not always have headers
27
- response , "headers" , {}).items ()}
28
- if not (response .status_code == 429 or response .status_code >= 500
29
- or "retry-after" in lowercase_headers ):
30
- return 0 # Quick exit
31
- default = 60 # Recommended at the end of
32
- # https://identitydivision.visualstudio.com/devex/_git/AuthLibrariesApiReview?version=GBdev&path=%2FService%20protection%2FIntial%20set%20of%20protection%20measures.md&_a=preview
33
- retry_after = lowercase_headers .get ("retry-after" , default )
34
- try :
35
- # AAD's retry_after uses integer format only
36
- # https://stackoverflow.microsoft.com/questions/264931/264932
37
- delay_seconds = int (retry_after )
38
- except ValueError :
39
- delay_seconds = default
40
- return min (3600 , delay_seconds )
12
+ class RetryAfterParser (object ):
13
+ def __init__ (self , default_value = None ):
14
+ self ._default_value = 5 if default_value is None else default_value
15
+
16
+ def parse (self , * , result , ** ignored ):
17
+ """Return seconds to throttle"""
18
+ response = result
19
+ lowercase_headers = {k .lower (): v for k , v in getattr (
20
+ # Historically, MSAL's HttpResponse does not always have headers
21
+ response , "headers" , {}).items ()}
22
+ if not (response .status_code == 429 or response .status_code >= 500
23
+ or "retry-after" in lowercase_headers ):
24
+ return 0 # Quick exit
25
+ retry_after = lowercase_headers .get ("retry-after" , self ._default_value )
26
+ try :
27
+ # AAD's retry_after uses integer format only
28
+ # https://stackoverflow.microsoft.com/questions/264931/264932
29
+ delay_seconds = int (retry_after )
30
+ except ValueError :
31
+ delay_seconds = self ._default_value
32
+ return min (3600 , delay_seconds )
41
33
42
34
43
35
def _extract_data (kwargs , key , default = None ):
@@ -53,7 +45,7 @@ class ThrottledHttpClientBase(object):
53
45
54
46
The subclass should implement post() and/or get()
55
47
"""
56
- def __init__ (self , http_client , http_cache ):
48
+ def __init__ (self , http_client , * , http_cache = None ):
57
49
self .http_client = http_client
58
50
self ._expiring_mapping = ExpiringMapping ( # It will automatically clean up
59
51
mapping = http_cache if http_cache is not None else {},
@@ -70,10 +62,14 @@ def get(self, *args, **kwargs):
70
62
def close (self ):
71
63
return self .http_client .close ()
72
64
65
+ @staticmethod
66
+ def _hash (raw ):
67
+ return sha256 (repr (raw ).encode ("utf-8" )).hexdigest ()
68
+
73
69
74
70
class ThrottledHttpClient (ThrottledHttpClientBase ):
75
- def __init__ (self , http_client , http_cache ):
76
- super (ThrottledHttpClient , self ).__init__ (http_client , http_cache )
71
+ def __init__ (self , http_client , * , default_throttle_time = None , ** kwargs ):
72
+ super (ThrottledHttpClient , self ).__init__ (http_client , ** kwargs )
77
73
78
74
_post = http_client .post # We'll patch _post, and keep original post() intact
79
75
@@ -86,22 +82,22 @@ def __init__(self, http_client, http_cache):
86
82
args [0 ], # It is the url, typically containing authority and tenant
87
83
_extract_data (kwargs , "client_id" ), # Per internal specs
88
84
_extract_data (kwargs , "scope" ), # Per internal specs
89
- _hash (
85
+ self . _hash (
90
86
# The followings are all approximations of the "account" concept
91
87
# to support per-account throttling.
92
88
# TODO: We may want to disable it for confidential client, though
93
89
_extract_data (kwargs , "refresh_token" , # "account" during refresh
94
90
_extract_data (kwargs , "code" , # "account" of auth code grant
95
91
_extract_data (kwargs , "username" )))), # "account" of ROPC
96
92
),
97
- expires_in = _parse_http_429_5xx_retry_after ,
93
+ expires_in = RetryAfterParser ( default_throttle_time or 5 ). parse ,
98
94
)(_post )
99
95
100
96
_post = IndividualCache ( # It covers the "UI required cache"
101
97
mapping = self ._expiring_mapping ,
102
98
key_maker = lambda func , args , kwargs : "POST {} hash={} 400" .format (
103
99
args [0 ], # It is the url, typically containing authority and tenant
104
- _hash (
100
+ self . _hash (
105
101
# Here we use literally all parameters, even those short-lived
106
102
# parameters containing timestamps (WS-Trust or POP assertion),
107
103
# because they will automatically be cleaned up by ExpiringMapping.
@@ -140,7 +136,7 @@ def __init__(self, http_client, http_cache):
140
136
mapping = self ._expiring_mapping ,
141
137
key_maker = lambda func , args , kwargs : "GET {} hash={} 2xx" .format (
142
138
args [0 ], # It is the url, sometimes containing inline params
143
- _hash (kwargs .get ("params" , "" )),
139
+ self . _hash (kwargs .get ("params" , "" )),
144
140
),
145
141
expires_in = lambda result = None , ** ignored :
146
142
3600 * 24 if 200 <= result .status_code < 300 else 0 ,
0 commit comments