Skip to content

Commit 0247953

Browse files
committed
Change the API based on recent team discussion
1 parent 75601db commit 0247953

File tree

5 files changed

+210
-92
lines changed

5 files changed

+210
-92
lines changed

docs/index.rst

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,3 +121,24 @@ See `SerializableTokenCache` for example.
121121

122122
.. autoclass:: msal.SerializableTokenCache
123123
:members:
124+
125+
126+
Managed Identity
127+
----------------
128+
MSAL supports
129+
`Managed Identity <https://learn.microsoft.com/en-us/azure/active-directory/managed-identities-azure-resources/overview>`_.
130+
131+
You can create one of these two kinds of managed identity configuration objects:
132+
133+
.. autoclass:: msal.SystemAssignedManagedIdentity
134+
:members:
135+
136+
.. autoclass:: msal.UserAssignedManagedIdentity
137+
:members:
138+
139+
And then feed the configuration object into a :class:`ManagedIdentityClient` object.
140+
141+
.. autoclass:: msal.ManagedIdentityClient
142+
:members:
143+
144+
.. automethod:: __init__

msal/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,5 +33,9 @@
3333
)
3434
from .oauth2cli.oidc import Prompt
3535
from .token_cache import TokenCache, SerializableTokenCache
36-
from .imds import ManagedIdentity
36+
from .imds import (
37+
SystemAssignedManagedIdentity,
38+
UserAssignedManagedIdentity,
39+
ManagedIdentityClient,
40+
)
3741

msal/imds.py

Lines changed: 119 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
# TODO: Change the module name from imds to managed_identity
12
# Copyright (c) Microsoft Corporation.
23
# All rights reserved.
34
#
@@ -11,53 +12,135 @@
1112
from urlparse import urlparse
1213
except: # Python 3
1314
from urllib.parse import urlparse
15+
try: # Python 3
16+
from collections import UserDict
17+
except:
18+
UserDict = dict # The real UserDict is an old-style class which fails super()
19+
1420

1521
logger = logging.getLogger(__name__)
1622

23+
class ManagedIdentity(UserDict):
24+
# The key names used in config dict
25+
ID_TYPE = "ManagedIdentityIdType"
26+
ID = "Id"
27+
def __init__(self, identifier=None, id_type=None):
28+
super(ManagedIdentity, self).__init__({
29+
self.ID_TYPE: id_type,
30+
self.ID: identifier,
31+
})
32+
33+
34+
class UserAssignedManagedIdentity(ManagedIdentity):
35+
"""Feed an instance of this class to :class:`msal.ManagedIdentityClient`
36+
to acquire token for user-assigned managed identity.
37+
38+
By design, an instance of this class is equivalent to a dict in
39+
one of these shapes::
40+
41+
{"ManagedIdentityIdType": "ClientId", "Id": "foo"}
42+
43+
{"ManagedIdentityIdType": "ResourceId", "Id": "foo"}
44+
45+
{"ManagedIdentityIdType": "ObjectId", "Id": "foo"}
46+
47+
so that you may load it from a json configuration file or an env var,
48+
and feed it to :class:`Client`.
49+
"""
50+
CLIENT_ID = "ClientId"
51+
RESOURCE_ID = "ResourceId"
52+
OBJECT_ID = "ObjectId"
53+
_types_mapping = { # Maps type name in configuration to type name on wire
54+
CLIENT_ID: "client_id",
55+
RESOURCE_ID: "mi_res_id",
56+
OBJECT_ID: "object_id",
57+
}
58+
def __init__(self, identifier, id_type):
59+
"""Construct a UserAssignedManagedIdentity instance.
60+
61+
:param string identifier: The id.
62+
:param string id_type: It shall be one of these three::
63+
64+
UserAssignedManagedIdentity.CLIENT_ID
65+
UserAssignedManagedIdentity.RESOURCE_ID
66+
UserAssignedManagedIdentity.OBJECT_ID
67+
"""
68+
if id_type not in self._types_mapping:
69+
raise ValueError("id_type only accepts one of: {}".format(
70+
list(self._types_mapping)))
71+
super(UserAssignedManagedIdentity, self).__init__(
72+
identifier=identifier,
73+
id_type=id_type,
74+
)
75+
76+
77+
class SystemAssignedManagedIdentity(ManagedIdentity):
78+
"""Feed an instance of this class to :class:`msal.ManagedIdentityClient`
79+
to acquire token for system-assigned managed identity.
80+
81+
By design, an instance of this class is equivalent to::
82+
83+
{"ManagedIdentityIdType": "SystemAssignedManagedIdentity", "Id": None}
84+
85+
so that you may load it from a json configuration file or an env var,
86+
and feed it to :class:`Client`.
87+
"""
88+
def __init__(self):
89+
super(SystemAssignedManagedIdentity, self).__init__(
90+
id_type="SystemAssignedManagedIdentity", # As of this writing,
91+
# It can be any value other than
92+
# UserAssignedManagedIdentity._types_mapping's key names
93+
)
94+
95+
1796
def _scope_to_resource(scope): # This is an experimental reasonable-effort approach
1897
u = urlparse(scope)
1998
if u.scheme:
2099
return "{}://{}".format(u.scheme, u.netloc)
21100
return scope # There is no much else we can do here
22101

23102

24-
def _obtain_token(http_client, resource, client_id=None, object_id=None, mi_res_id=None):
103+
def _obtain_token(http_client, managed_identity, resource):
25104
if ("IDENTITY_ENDPOINT" in os.environ and "IDENTITY_HEADER" in os.environ
26105
and "IDENTITY_SERVER_THUMBPRINT" in os.environ
27106
):
28-
if client_id or object_id or mi_res_id:
29-
logger.debug(
30-
"Ignoring client_id/object_id/mi_res_id. "
31-
"Managed Identity in Service Fabric is configured in the cluster, "
32-
"not during runtime. See also "
33-
"https://learn.microsoft.com/en-us/azure/service-fabric/configure-existing-cluster-enable-managed-identity-token-service")
107+
logger.debug(
108+
"Ignoring client_id/object_id/mi_res_id. "
109+
"Managed Identity in Service Fabric is configured in the cluster, "
110+
"not during runtime. See also "
111+
"https://learn.microsoft.com/en-us/azure/service-fabric/configure-existing-cluster-enable-managed-identity-token-service")
34112
return _obtain_token_on_service_fabric(
35-
http_client, os.environ["IDENTITY_ENDPOINT"], os.environ["IDENTITY_HEADER"],
36-
os.environ["IDENTITY_SERVER_THUMBPRINT"], resource)
113+
http_client,
114+
os.environ["IDENTITY_ENDPOINT"],
115+
os.environ["IDENTITY_HEADER"],
116+
os.environ["IDENTITY_SERVER_THUMBPRINT"],
117+
resource,
118+
)
37119
if "IDENTITY_ENDPOINT" in os.environ and "IDENTITY_HEADER" in os.environ:
38120
return _obtain_token_on_app_service(
39-
http_client, os.environ["IDENTITY_ENDPOINT"], os.environ["IDENTITY_HEADER"],
40-
resource, client_id=client_id, object_id=object_id, mi_res_id=mi_res_id)
41-
return _obtain_token_on_azure_vm(
42-
http_client,
43-
resource, client_id=client_id, object_id=object_id, mi_res_id=mi_res_id)
121+
http_client,
122+
os.environ["IDENTITY_ENDPOINT"],
123+
os.environ["IDENTITY_HEADER"],
124+
managed_identity,
125+
resource,
126+
)
127+
return _obtain_token_on_azure_vm(http_client, managed_identity, resource)
44128

45129

46-
def _obtain_token_on_azure_vm(http_client, resource,
47-
client_id=None, object_id=None, mi_res_id=None,
48-
):
130+
def _adjust_param(params, managed_identity):
131+
id_name = UserAssignedManagedIdentity._types_mapping.get(
132+
managed_identity.get(ManagedIdentity.ID_TYPE))
133+
if id_name:
134+
params[id_name] = managed_identity[ManagedIdentity.ID]
135+
136+
def _obtain_token_on_azure_vm(http_client, managed_identity, resource):
49137
# Based on https://docs.microsoft.com/en-us/azure/active-directory/managed-identities-azure-resources/how-to-use-vm-token#get-a-token-using-http
50138
logger.debug("Obtaining token via managed identity on Azure VM")
51139
params = {
52140
"api-version": "2018-02-01",
53141
"resource": resource,
54142
}
55-
if client_id:
56-
params["client_id"] = client_id
57-
if object_id:
58-
params["object_id"] = object_id
59-
if mi_res_id:
60-
params["mi_res_id"] = mi_res_id
143+
_adjust_param(params, managed_identity)
61144
resp = http_client.get(
62145
"http://169.254.169.254/metadata/identity/oauth2/token",
63146
params=params,
@@ -77,8 +160,8 @@ def _obtain_token_on_azure_vm(http_client, resource,
77160
logger.debug("IMDS emits unexpected payload: %s", resp.text)
78161
raise
79162

80-
def _obtain_token_on_app_service(http_client, endpoint, identity_header, resource,
81-
client_id=None, object_id=None, mi_res_id=None,
163+
def _obtain_token_on_app_service(
164+
http_client, endpoint, identity_header, managed_identity, resource,
82165
):
83166
"""Obtains token for
84167
`App Service <https://learn.microsoft.com/en-us/azure/app-service/overview-managed-identity?tabs=portal%2Chttp#rest-endpoint-reference>`_,
@@ -92,12 +175,7 @@ def _obtain_token_on_app_service(http_client, endpoint, identity_header, resourc
92175
"api-version": "2019-08-01",
93176
"resource": resource,
94177
}
95-
if client_id:
96-
params["client_id"] = client_id
97-
if object_id:
98-
params["object_id"] = object_id
99-
if mi_res_id:
100-
params["mi_res_id"] = mi_res_id
178+
_adjust_param(params, managed_identity)
101179
resp = http_client.get(
102180
endpoint,
103181
params=params,
@@ -167,32 +245,24 @@ def _obtain_token_on_service_fabric(
167245

168246

169247

170-
class ManagedIdentity(object):
248+
class ManagedIdentityClient(object):
171249
_instance, _tenant = socket.getfqdn(), "managed_identity" # Placeholders
172250

173-
def __init__(self, http_client,
174-
client_id=None, object_id=None, mi_res_id=None,
175-
token_cache=None,
176-
):
177-
"""Create a managed identity object.
251+
def __init__(self, http_client, managed_identity, token_cache=None):
252+
"""Create a managed identity client.
178253
179254
:param http_client:
180255
An http client object. For example, you can use `requests.Session()`.
181256
182-
:param str client_id:
183-
Optional.
184-
It accepts the Client ID (NOT the Object ID) of your user-assigned managed identity.
185-
If it is None, it means to use a system-assigned managed identity.
257+
:param dict managed_identity:
258+
It accepts an instance of :class:`SystemAssignedManagedIdentity`
259+
or :class:`UserAssignedManagedIdentity`, or their equivalent dict.
186260
187261
:param token_cache:
188262
Optional. It accepts a :class:`msal.TokenCache` instance to store tokens.
189263
"""
190-
if len(list(filter(bool, [client_id, object_id, mi_res_id]))) > 1:
191-
raise ValueError("You can use up to one of these: client_id, object_id, mi_res_id")
192264
self._http_client = http_client
193-
self._client_id = client_id
194-
self._object_id = object_id
195-
self._mi_res_id = mi_res_id
265+
self._managed_identity = managed_identity
196266
self._token_cache = token_cache
197267

198268
def acquire_token(self, resource=None):
@@ -202,9 +272,8 @@ def acquire_token(self, resource=None):
202272
"It is only declared as optional in method signature, "
203273
"in case we want to support scope parameter in the future.")
204274
access_token_from_cache = None
205-
client_id_in_cache = (
206-
self._client_id or self._object_id or self._mi_res_id
207-
or "SYSTEM_ASSIGNED_MANAGED_IDENTITY")
275+
client_id_in_cache = self._managed_identity.get(
276+
ManagedIdentity.ID, "SYSTEM_ASSIGNED_MANAGED_IDENTITY")
208277
if self._token_cache:
209278
matches = self._token_cache.find(
210279
self._token_cache.CredentialType.ACCESS_TOKEN,
@@ -230,13 +299,7 @@ def acquire_token(self, resource=None):
230299
if "refresh_on" in entry and int(entry["refresh_on"]) < now: # aging
231300
break # With a fallback in hand, we break here to go refresh
232301
return access_token_from_cache # It is still good as new
233-
result = _obtain_token(
234-
self._http_client,
235-
resource,
236-
client_id=self._client_id,
237-
object_id=self._object_id,
238-
mi_res_id=self._mi_res_id,
239-
)
302+
result = _obtain_token(self._http_client, self._managed_identity, resource)
240303
if self._token_cache and "access_token" in result:
241304
self._token_cache.add(dict(
242305
client_id=client_id_in_cache,

tests/msaltest.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -161,17 +161,19 @@ def exit(app):
161161
sys.exit()
162162

163163
def _managed_identity():
164-
client_id = _select_options([
165-
{"client_id": None, "name": "System-assigned managed identity"},
166-
],
164+
mi = _select_options([
165+
{
166+
'ManagedIdentityIdType': 'SystemAssignedManagedIdentity',
167+
"name": "System-assigned managed identity",
168+
}],
167169
option_renderer=lambda a: a["name"],
168170
header="Choose the system-assigned managed identity "
169-
"(or type in your user-assigned managed identity)",
171+
"(or type in your user-assigned managed identity's client id)",
170172
accept_nonempty_string=True)
171-
return msal.ManagedIdentity(
173+
return msal.ManagedIdentityClient(
172174
requests.Session(),
173-
client_id=client_id["client_id"]
174-
if isinstance(client_id, dict) else client_id,
175+
mi if isinstance(mi, dict) else msal.UserAssignedManagedIdentity(
176+
identifier=mi, id_type=msal.UserAssignedManagedIdentity.CLIENT_ID),
175177
token_cache=msal.TokenCache(),
176178
)
177179

@@ -218,7 +220,7 @@ def main():
218220
acquire_token_by_username_password,
219221
remove_account,
220222
],
221-
msal.ManagedIdentity: [acquire_token_for_managed_identity],
223+
msal.ManagedIdentityClient: [acquire_token_for_managed_identity],
222224
}.items() if isinstance(app, app_type)])
223225
while True:
224226
func = _select_options(

0 commit comments

Comments
 (0)