@@ -45,25 +45,42 @@ def _extract_data(kwargs, key, default=None):
45
45
return data .get (key ) if isinstance (data , dict ) else default
46
46
47
47
48
- class ThrottledHttpClient (object ):
49
- def __init__ (self , http_client , http_cache ):
50
- """Throttle the given http_client by storing and retrieving data from cache.
48
+ class ThrottledHttpClientBase (object ):
49
+ """Throttle the given http_client by storing and retrieving data from cache.
51
50
52
- This wrapper exists so that our patching post() and get() would prevent
53
- re-patching side effect when/if same http_client being reused.
54
- """
55
- expiring_mapping = ExpiringMapping ( # It will automatically clean up
51
+ This wrapper exists so that our patching post() and get() would prevent
52
+ re-patching side effect when/if same http_client being reused.
53
+
54
+ The subclass should implement post() and/or get()
55
+ """
56
+ def __init__ (self , http_client , http_cache ):
57
+ self .http_client = http_client
58
+ self ._expiring_mapping = ExpiringMapping ( # It will automatically clean up
56
59
mapping = http_cache if http_cache is not None else {},
57
60
capacity = 1024 , # To prevent cache blowing up especially for CCA
58
61
lock = Lock (), # TODO: This should ideally also allow customization
59
62
)
60
63
64
+ def post (self , * args , ** kwargs ):
65
+ return self .http_client .post (* args , ** kwargs )
66
+
67
+ def get (self , * args , ** kwargs ):
68
+ return self .http_client .get (* args , ** kwargs )
69
+
70
+ def close (self ):
71
+ return self .http_client .close ()
72
+
73
+
74
+ class ThrottledHttpClient (ThrottledHttpClientBase ):
75
+ def __init__ (self , http_client , http_cache ):
76
+ super (ThrottledHttpClient , self ).__init__ (http_client , http_cache )
77
+
61
78
_post = http_client .post # We'll patch _post, and keep original post() intact
62
79
63
80
_post = IndividualCache (
64
81
# Internal specs requires throttling on at least token endpoint,
65
82
# here we have a generic patch for POST on all endpoints.
66
- mapping = expiring_mapping ,
83
+ mapping = self . _expiring_mapping ,
67
84
key_maker = lambda func , args , kwargs :
68
85
"POST {} client_id={} scope={} hash={} 429/5xx/Retry-After" .format (
69
86
args [0 ], # It is the url, typically containing authority and tenant
@@ -81,7 +98,7 @@ def __init__(self, http_client, http_cache):
81
98
)(_post )
82
99
83
100
_post = IndividualCache ( # It covers the "UI required cache"
84
- mapping = expiring_mapping ,
101
+ mapping = self . _expiring_mapping ,
85
102
key_maker = lambda func , args , kwargs : "POST {} hash={} 400" .format (
86
103
args [0 ], # It is the url, typically containing authority and tenant
87
104
_hash (
@@ -120,7 +137,7 @@ def __init__(self, http_client, http_cache):
120
137
self .post = _post
121
138
122
139
self .get = IndividualCache ( # Typically those discovery GETs
123
- mapping = expiring_mapping ,
140
+ mapping = self . _expiring_mapping ,
124
141
key_maker = lambda func , args , kwargs : "GET {} hash={} 2xx" .format (
125
142
args [0 ], # It is the url, sometimes containing inline params
126
143
_hash (kwargs .get ("params" , "" )),
@@ -129,13 +146,7 @@ def __init__(self, http_client, http_cache):
129
146
3600 * 24 if 200 <= result .status_code < 300 else 0 ,
130
147
)(http_client .get )
131
148
132
- self ._http_client = http_client
133
-
134
149
# The following 2 methods have been defined dynamically by __init__()
135
150
#def post(self, *args, **kwargs): pass
136
151
#def get(self, *args, **kwargs): pass
137
152
138
- def close (self ):
139
- """MSAL won't need this. But we allow throttled_http_client.close() anyway"""
140
- return self ._http_client .close ()
141
-
0 commit comments