Skip to content

Commit 00c3d42

Browse files
committed
Refactor throttling and add it to Managed Identity
1 parent 20975f2 commit 00c3d42

File tree

2 files changed

+66
-24
lines changed

2 files changed

+66
-24
lines changed

msal/managed_identity.py

Lines changed: 39 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@
1010
from urllib.parse import urlparse # Python 3+
1111
from collections import UserDict # Python 3+
1212
from .token_cache import TokenCache
13-
from .throttled_http_client import ThrottledHttpClient
13+
from .individual_cache import _IndividualCache as IndividualCache
14+
from .throttled_http_client import ThrottledHttpClientBase, _parse_http_429_5xx_retry_after
1415

1516

1617
logger = logging.getLogger(__name__)
@@ -106,6 +107,22 @@ def __init__(self, *, client_id=None, resource_id=None, object_id=None):
106107
"client_id, resource_id, object_id")
107108

108109

110+
class _ThrottledHttpClient(ThrottledHttpClientBase):
    """HTTP client wrapper that throttles Managed Identity token requests.

    Only ``get()`` is replaced with a cached version here, because
    all MIs (except Cloud Shell) use GETs.
    """

    def __init__(self, http_client, http_cache):
        """Wrap ``http_client.get`` with a throttling cache.

        :param http_client:
            The underlying HTTP client. Its ``get()`` is the callable that
            gets wrapped.
        :param http_cache:
            Forwarded unchanged to ``ThrottledHttpClientBase.__init__()``.
        """
        super(_ThrottledHttpClient, self).__init__(http_client, http_cache)

        # _hash is used unqualified by throttled_http_client.py, so it is
        # imported from there. Without this import, the key maker below
        # would raise NameError on the first throttled get().
        from .throttled_http_client import _hash

        def _make_key(func, args, kwargs):
            # Managed Identity flavors have inconsistent parameters.
            # We simply choose to hash them all.
            fingerprint = _hash(
                str(kwargs.get("params")) + str(kwargs.get("data")))
            # The key is labelled "GET" because this cache wraps get().
            return "GET {} hash={} 429/5xx/Retry-After".format(
                args[0],  # It is the endpoint, typically a constant per MI type
                fingerprint,
            )

        self.get = IndividualCache(  # All MIs (except Cloud Shell) use GETs
            mapping=self._expiring_mapping,
            key_maker=_make_key,
            # How long an entry lives is decided by this helper
            expires_in=_parse_http_429_5xx_retry_after,
        )(http_client.get)
124+
125+
109126
class ManagedIdentityClient(object):
110127
"""This API encapsulates multiple managed identity backends:
111128
VM, App Service, Azure Automation (Runbooks), Azure Function, Service Fabric,
@@ -115,7 +132,8 @@ class ManagedIdentityClient(object):
115132
"""
116133
_instance, _tenant = socket.getfqdn(), "managed_identity" # Placeholders
117134

118-
def __init__(self, managed_identity, *, http_client, token_cache=None):
135+
def __init__(
136+
self, managed_identity, *, http_client, token_cache=None, http_cache=None):
119137
"""Create a managed identity client.
120138
121139
:param dict managed_identity:
@@ -141,6 +159,10 @@ def __init__(self, managed_identity, *, http_client, token_cache=None):
141159
Optional. It accepts a :class:`msal.TokenCache` instance to store tokens.
142160
It will use an in-memory token cache by default.
143161
162+
:param http_cache:
163+
Optional. It has the same characteristics as the
164+
:paramref:`msal.ClientApplication.http_cache`.
165+
144166
Recipe 1: Hard code a managed identity for your app::
145167
146168
import msal, requests
@@ -168,12 +190,21 @@ def __init__(self, managed_identity, *, http_client, token_cache=None):
168190
token = client.acquire_token_for_client("resource")
169191
"""
170192
self._managed_identity = managed_identity
171-
if isinstance(http_client, ThrottledHttpClient):
172-
raise ValueError(
173-
# It is a precaution to reject application.py's throttled http_client,
174-
# whose cache life on HTTP GET 200 is too long for Managed Identity.
175-
"This class does not currently accept a ThrottledHttpClient.")
176-
self._http_client = http_client
193+
self._http_client = _ThrottledHttpClient(
194+
# This class only throttles excess token acquisition requests.
195+
# It does not provide retry.
196+
# Retry is the http_client or caller's responsibility, not MSAL's.
197+
#
198+
# FWIW, here is the inconsistent retry recommendation.
199+
# 1. Only MI on VM defines exotic 404 and 410 retry recommendations
200+
# ( https://learn.microsoft.com/en-us/entra/identity/managed-identities-azure-resources/how-to-use-vm-token#error-handling )
201+
# (especially for 410 which was supposed to be a permanent failure).
202+
# 2. MI on Service Fabric specifically suggests to not retry on 404.
203+
# ( https://learn.microsoft.com/en-us/azure/service-fabric/how-to-managed-cluster-managed-identity-service-fabric-app-code#error-handling )
204+
http_client.http_client # Patch the raw (unpatched) http client
205+
if isinstance(http_client, ThrottledHttpClientBase) else http_client,
206+
{} if http_cache is None else http_cache, # Default to an in-memory dict
207+
)
177208
self._token_cache = token_cache or TokenCache()
178209

179210
def acquire_token_for_client(self, *, resource): # We may support scope in the future

msal/throttled_http_client.py

Lines changed: 27 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -45,25 +45,42 @@ def _extract_data(kwargs, key, default=None):
4545
return data.get(key) if isinstance(data, dict) else default
4646

4747

48-
class ThrottledHttpClient(object):
49-
def __init__(self, http_client, http_cache):
50-
"""Throttle the given http_client by storing and retrieving data from cache.
48+
class ThrottledHttpClientBase(object):
    """Throttle the given http_client by storing and retrieving data from cache.

    The http_client is wrapped rather than patched in place, so that
    patching post() and get() causes no re-patching side effect
    when/if the same http_client is reused.

    The subclass should implement post() and/or get().
    The versions defined here simply delegate to the wrapped http_client.
    """

    def __init__(self, http_client, http_cache):
        """Keep the http_client and build the expiring mapping.

        :param http_client: The client that post(), get() and close() delegate to.
        :param http_cache:
            A dict-like object used as the backing store.
            ``None`` means a fresh in-memory dict.
        """
        if http_cache is None:
            http_cache = {}
        self.http_client = http_client
        self._expiring_mapping = ExpiringMapping(  # It will automatically clean up
            mapping=http_cache,
            capacity=1024,  # To prevent cache blowing up especially for CCA
            lock=Lock(),  # TODO: This should ideally also allow customization
        )

    def post(self, *args, **kwargs):
        """Forward the call to the wrapped http_client's post()."""
        wrapped = self.http_client
        return wrapped.post(*args, **kwargs)

    def get(self, *args, **kwargs):
        """Forward the call to the wrapped http_client's get()."""
        wrapped = self.http_client
        return wrapped.get(*args, **kwargs)

    def close(self):
        """Close the wrapped http_client and return whatever it returns."""
        wrapped = self.http_client
        return wrapped.close()
72+
73+
74+
class ThrottledHttpClient(ThrottledHttpClientBase):
75+
def __init__(self, http_client, http_cache):
76+
super(ThrottledHttpClient, self).__init__(http_client, http_cache)
77+
6178
_post = http_client.post # We'll patch _post, and keep original post() intact
6279

6380
_post = IndividualCache(
6481
# Internal specs requires throttling on at least token endpoint,
6582
# here we have a generic patch for POST on all endpoints.
66-
mapping=expiring_mapping,
83+
mapping=self._expiring_mapping,
6784
key_maker=lambda func, args, kwargs:
6885
"POST {} client_id={} scope={} hash={} 429/5xx/Retry-After".format(
6986
args[0], # It is the url, typically containing authority and tenant
@@ -81,7 +98,7 @@ def __init__(self, http_client, http_cache):
8198
)(_post)
8299

83100
_post = IndividualCache( # It covers the "UI required cache"
84-
mapping=expiring_mapping,
101+
mapping=self._expiring_mapping,
85102
key_maker=lambda func, args, kwargs: "POST {} hash={} 400".format(
86103
args[0], # It is the url, typically containing authority and tenant
87104
_hash(
@@ -120,7 +137,7 @@ def __init__(self, http_client, http_cache):
120137
self.post = _post
121138

122139
self.get = IndividualCache( # Typically those discovery GETs
123-
mapping=expiring_mapping,
140+
mapping=self._expiring_mapping,
124141
key_maker=lambda func, args, kwargs: "GET {} hash={} 2xx".format(
125142
args[0], # It is the url, sometimes containing inline params
126143
_hash(kwargs.get("params", "")),
@@ -129,13 +146,7 @@ def __init__(self, http_client, http_cache):
129146
3600*24 if 200 <= result.status_code < 300 else 0,
130147
)(http_client.get)
131148

132-
self._http_client = http_client
133-
134149
# The following 2 methods have been defined dynamically by __init__()
135150
#def post(self, *args, **kwargs): pass
136151
#def get(self, *args, **kwargs): pass
137152

138-
def close(self):
139-
"""MSAL won't need this. But we allow throttled_http_client.close() anyway"""
140-
return self._http_client.close()
141-

0 commit comments

Comments
 (0)