1
+ # TODO: Change the module name from imds to managed_identity
1
2
# Copyright (c) Microsoft Corporation.
2
3
# All rights reserved.
3
4
#
11
12
from urlparse import urlparse
12
13
except : # Python 3
13
14
from urllib .parse import urlparse
15
+ try : # Python 3
16
+ from collections import UserDict
17
+ except :
18
+ UserDict = dict # The real UserDict is an old-style class which fails super()
19
+
14
20
15
21
logger = logging .getLogger (__name__ )
16
22
23
+ class ManagedIdentity (UserDict ):
24
+ # The key names used in config dict
25
+ ID_TYPE = "ManagedIdentityIdType"
26
+ ID = "Id"
27
+ def __init__ (self , identifier = None , id_type = None ):
28
+ super (ManagedIdentity , self ).__init__ ({
29
+ self .ID_TYPE : id_type ,
30
+ self .ID : identifier ,
31
+ })
32
+
33
+
34
+ class UserAssignedManagedIdentity (ManagedIdentity ):
35
+ """Feed an instance of this class to :class:`msal.ManagedIdentityClient`
36
+ to acquire token for user-assigned managed identity.
37
+
38
+ By design, an instance of this class is equivalent to a dict in
39
+ one of these shapes::
40
+
41
+ {"ManagedIdentityIdType": "ClientId", "Id": "foo"}
42
+
43
+ {"ManagedIdentityIdType": "ResourceId", "Id": "foo"}
44
+
45
+ {"ManagedIdentityIdType": "ObjectId", "Id": "foo"}
46
+
47
+ so that you may load it from a json configuration file or an env var,
48
+ and feed it to :class:`Client`.
49
+ """
50
+ CLIENT_ID = "ClientId"
51
+ RESOURCE_ID = "ResourceId"
52
+ OBJECT_ID = "ObjectId"
53
+ _types_mapping = { # Maps type name in configuration to type name on wire
54
+ CLIENT_ID : "client_id" ,
55
+ RESOURCE_ID : "mi_res_id" ,
56
+ OBJECT_ID : "object_id" ,
57
+ }
58
+ def __init__ (self , identifier , id_type ):
59
+ """Construct a UserAssignedManagedIdentity instance.
60
+
61
+ :param string identifier: The id.
62
+ :param string id_type: It shall be one of these three::
63
+
64
+ UserAssignedManagedIdentity.CLIENT_ID
65
+ UserAssignedManagedIdentity.RESOURCE_ID
66
+ UserAssignedManagedIdentity.OBJECT_ID
67
+ """
68
+ if id_type not in self ._types_mapping :
69
+ raise ValueError ("id_type only accepts one of: {}" .format (
70
+ list (self ._types_mapping )))
71
+ super (UserAssignedManagedIdentity , self ).__init__ (
72
+ identifier = identifier ,
73
+ id_type = id_type ,
74
+ )
75
+
76
+
77
+ class SystemAssignedManagedIdentity (ManagedIdentity ):
78
+ """Feed an instance of this class to :class:`msal.ManagedIdentityClient`
79
+ to acquire token for system-assigned managed identity.
80
+
81
+ By design, an instance of this class is equivalent to::
82
+
83
+ {"ManagedIdentityIdType": "SystemAssignedManagedIdentity", "Id": None}
84
+
85
+ so that you may load it from a json configuration file or an env var,
86
+ and feed it to :class:`Client`.
87
+ """
88
+ def __init__ (self ):
89
+ super (SystemAssignedManagedIdentity , self ).__init__ (
90
+ id_type = "SystemAssignedManagedIdentity" , # As of this writing,
91
+ # It can be any value other than
92
+ # UserAssignedManagedIdentity._types_mapping's key names
93
+ )
94
+
95
+
17
96
def _scope_to_resource (scope ): # This is an experimental reasonable-effort approach
18
97
u = urlparse (scope )
19
98
if u .scheme :
20
99
return "{}://{}" .format (u .scheme , u .netloc )
21
100
return scope # There is no much else we can do here
22
101
23
102
24
- def _obtain_token (http_client , resource , client_id = None , object_id = None , mi_res_id = None ):
103
+ def _obtain_token (http_client , managed_identity , resource ):
25
104
if ("IDENTITY_ENDPOINT" in os .environ and "IDENTITY_HEADER" in os .environ
26
105
and "IDENTITY_SERVER_THUMBPRINT" in os .environ
27
106
):
28
- if client_id or object_id or mi_res_id :
29
- logger .debug (
30
- "Ignoring client_id/object_id/mi_res_id. "
31
- "Managed Identity in Service Fabric is configured in the cluster, "
32
- "not during runtime. See also "
33
- "https://learn.microsoft.com/en-us/azure/service-fabric/configure-existing-cluster-enable-managed-identity-token-service" )
107
+ logger .debug (
108
+ "Ignoring client_id/object_id/mi_res_id. "
109
+ "Managed Identity in Service Fabric is configured in the cluster, "
110
+ "not during runtime. See also "
111
+ "https://learn.microsoft.com/en-us/azure/service-fabric/configure-existing-cluster-enable-managed-identity-token-service" )
34
112
return _obtain_token_on_service_fabric (
35
- http_client , os .environ ["IDENTITY_ENDPOINT" ], os .environ ["IDENTITY_HEADER" ],
36
- os .environ ["IDENTITY_SERVER_THUMBPRINT" ], resource )
113
+ http_client ,
114
+ os .environ ["IDENTITY_ENDPOINT" ],
115
+ os .environ ["IDENTITY_HEADER" ],
116
+ os .environ ["IDENTITY_SERVER_THUMBPRINT" ],
117
+ resource ,
118
+ )
37
119
if "IDENTITY_ENDPOINT" in os .environ and "IDENTITY_HEADER" in os .environ :
38
120
return _obtain_token_on_app_service (
39
- http_client , os .environ ["IDENTITY_ENDPOINT" ], os .environ ["IDENTITY_HEADER" ],
40
- resource , client_id = client_id , object_id = object_id , mi_res_id = mi_res_id )
41
- return _obtain_token_on_azure_vm (
42
- http_client ,
43
- resource , client_id = client_id , object_id = object_id , mi_res_id = mi_res_id )
121
+ http_client ,
122
+ os .environ ["IDENTITY_ENDPOINT" ],
123
+ os .environ ["IDENTITY_HEADER" ],
124
+ managed_identity ,
125
+ resource ,
126
+ )
127
+ return _obtain_token_on_azure_vm (http_client , managed_identity , resource )
44
128
45
129
46
- def _obtain_token_on_azure_vm (http_client , resource ,
47
- client_id = None , object_id = None , mi_res_id = None ,
48
- ):
130
+ def _adjust_param (params , managed_identity ):
131
+ id_name = UserAssignedManagedIdentity ._types_mapping .get (
132
+ managed_identity .get (ManagedIdentity .ID_TYPE ))
133
+ if id_name :
134
+ params [id_name ] = managed_identity [ManagedIdentity .ID ]
135
+
136
+ def _obtain_token_on_azure_vm (http_client , managed_identity , resource ):
49
137
# Based on https://docs.microsoft.com/en-us/azure/active-directory/managed-identities-azure-resources/how-to-use-vm-token#get-a-token-using-http
50
138
logger .debug ("Obtaining token via managed identity on Azure VM" )
51
139
params = {
52
140
"api-version" : "2018-02-01" ,
53
141
"resource" : resource ,
54
142
}
55
- if client_id :
56
- params ["client_id" ] = client_id
57
- if object_id :
58
- params ["object_id" ] = object_id
59
- if mi_res_id :
60
- params ["mi_res_id" ] = mi_res_id
143
+ _adjust_param (params , managed_identity )
61
144
resp = http_client .get (
62
145
"http://169.254.169.254/metadata/identity/oauth2/token" ,
63
146
params = params ,
@@ -77,8 +160,8 @@ def _obtain_token_on_azure_vm(http_client, resource,
77
160
logger .debug ("IMDS emits unexpected payload: %s" , resp .text )
78
161
raise
79
162
80
- def _obtain_token_on_app_service (http_client , endpoint , identity_header , resource ,
81
- client_id = None , object_id = None , mi_res_id = None ,
163
+ def _obtain_token_on_app_service (
164
+ http_client , endpoint , identity_header , managed_identity , resource ,
82
165
):
83
166
"""Obtains token for
84
167
`App Service <https://learn.microsoft.com/en-us/azure/app-service/overview-managed-identity?tabs=portal%2Chttp#rest-endpoint-reference>`_,
@@ -92,12 +175,7 @@ def _obtain_token_on_app_service(http_client, endpoint, identity_header, resourc
92
175
"api-version" : "2019-08-01" ,
93
176
"resource" : resource ,
94
177
}
95
- if client_id :
96
- params ["client_id" ] = client_id
97
- if object_id :
98
- params ["object_id" ] = object_id
99
- if mi_res_id :
100
- params ["mi_res_id" ] = mi_res_id
178
+ _adjust_param (params , managed_identity )
101
179
resp = http_client .get (
102
180
endpoint ,
103
181
params = params ,
@@ -167,32 +245,24 @@ def _obtain_token_on_service_fabric(
167
245
168
246
169
247
170
- class ManagedIdentity (object ):
248
+ class ManagedIdentityClient (object ):
171
249
_instance , _tenant = socket .getfqdn (), "managed_identity" # Placeholders
172
250
173
- def __init__ (self , http_client ,
174
- client_id = None , object_id = None , mi_res_id = None ,
175
- token_cache = None ,
176
- ):
177
- """Create a managed identity object.
251
+ def __init__ (self , http_client , managed_identity , token_cache = None ):
252
+ """Create a managed identity client.
178
253
179
254
:param http_client:
180
255
An http client object. For example, you can use `requests.Session()`.
181
256
182
- :param str client_id:
183
- Optional.
184
- It accepts the Client ID (NOT the Object ID) of your user-assigned managed identity.
185
- If it is None, it means to use a system-assigned managed identity.
257
+ :param dict managed_identity:
258
+ It accepts an instance of :class:`SystemAssignedManagedIdentity`
259
+ or :class:`UserAssignedManagedIdentity`, or their equivalent dict.
186
260
187
261
:param token_cache:
188
262
Optional. It accepts a :class:`msal.TokenCache` instance to store tokens.
189
263
"""
190
- if len (list (filter (bool , [client_id , object_id , mi_res_id ]))) > 1 :
191
- raise ValueError ("You can use up to one of these: client_id, object_id, mi_res_id" )
192
264
self ._http_client = http_client
193
- self ._client_id = client_id
194
- self ._object_id = object_id
195
- self ._mi_res_id = mi_res_id
265
+ self ._managed_identity = managed_identity
196
266
self ._token_cache = token_cache
197
267
198
268
def acquire_token (self , resource = None ):
@@ -202,9 +272,8 @@ def acquire_token(self, resource=None):
202
272
"It is only declared as optional in method signature, "
203
273
"in case we want to support scope parameter in the future." )
204
274
access_token_from_cache = None
205
- client_id_in_cache = (
206
- self ._client_id or self ._object_id or self ._mi_res_id
207
- or "SYSTEM_ASSIGNED_MANAGED_IDENTITY" )
275
+ client_id_in_cache = self ._managed_identity .get (
276
+ ManagedIdentity .ID , "SYSTEM_ASSIGNED_MANAGED_IDENTITY" )
208
277
if self ._token_cache :
209
278
matches = self ._token_cache .find (
210
279
self ._token_cache .CredentialType .ACCESS_TOKEN ,
@@ -230,13 +299,7 @@ def acquire_token(self, resource=None):
230
299
if "refresh_on" in entry and int (entry ["refresh_on" ]) < now : # aging
231
300
break # With a fallback in hand, we break here to go refresh
232
301
return access_token_from_cache # It is still good as new
233
- result = _obtain_token (
234
- self ._http_client ,
235
- resource ,
236
- client_id = self ._client_id ,
237
- object_id = self ._object_id ,
238
- mi_res_id = self ._mi_res_id ,
239
- )
302
+ result = _obtain_token (self ._http_client , self ._managed_identity , resource )
240
303
if self ._token_cache and "access_token" in result :
241
304
self ._token_cache .add (dict (
242
305
client_id = client_id_in_cache ,
0 commit comments