Skip to content

Commit 23e2e15

Browse files
committed
Also support object_id and msi_res_id
1 parent 8b8310b commit 23e2e15

File tree

2 files changed

+51
-15
lines changed

2 files changed

+51
-15
lines changed

msal/imds.py

Lines changed: 43 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,19 @@ def _scope_to_resource(scope): # This is an experimental reasonable-effort appr
2121
return scope # There is no much else we can do here
2222

2323

24-
def _obtain_token(http_client, resource, client_id=None):
24+
def _obtain_token(http_client, resource, client_id=None, object_id=None, mi_res_id=None):
2525
if "IDENTITY_ENDPOINT" in os.environ and "IDENTITY_HEADER" in os.environ:
2626
return _obtain_token_on_app_service(
2727
http_client, os.environ["IDENTITY_ENDPOINT"], os.environ["IDENTITY_HEADER"],
28-
resource, client_id=client_id)
29-
return _obtain_token_on_azure_vm(http_client, resource, client_id=client_id)
28+
resource, client_id=client_id, object_id=object_id, mi_res_id=mi_res_id)
29+
return _obtain_token_on_azure_vm(
30+
http_client,
31+
resource, client_id=client_id, object_id=object_id, mi_res_id=mi_res_id)
3032

3133

32-
def _obtain_token_on_azure_vm(http_client, resource, client_id=None):
34+
def _obtain_token_on_azure_vm(http_client, resource,
35+
client_id=None, object_id=None, mi_res_id=None,
36+
):
3337
# Based on https://docs.microsoft.com/en-us/azure/active-directory/managed-identities-azure-resources/how-to-use-vm-token#get-a-token-using-http
3438
logger.debug("Obtaining token via managed identity on Azure VM")
3539
params = {
@@ -38,6 +42,10 @@ def _obtain_token_on_azure_vm(http_client, resource, client_id=None):
3842
}
3943
if client_id:
4044
params["client_id"] = client_id
45+
if object_id:
46+
params["object_id"] = object_id
47+
if mi_res_id:
48+
params["mi_res_id"] = mi_res_id
4149
resp = http_client.get(
4250
"http://169.254.169.254/metadata/identity/oauth2/token",
4351
params=params,
@@ -57,7 +65,9 @@ def _obtain_token_on_azure_vm(http_client, resource, client_id=None):
5765
logger.debug("IMDS emits unexpected payload: %s", resp.text)
5866
raise
5967

60-
def _obtain_token_on_app_service(http_client, endpoint, identity_header, resource, client_id=None):
68+
def _obtain_token_on_app_service(http_client, endpoint, identity_header, resource,
69+
client_id=None, object_id=None, mi_res_id=None,
70+
):
6171
"""Obtains token for
6272
`App Service <https://learn.microsoft.com/en-us/azure/app-service/overview-managed-identity?tabs=portal%2Chttp#rest-endpoint-reference>`_
6373
"""
@@ -71,6 +81,10 @@ def _obtain_token_on_app_service(http_client, endpoint, identity_header, resourc
7181
}
7282
if client_id:
7383
params["client_id"] = client_id
84+
if object_id:
85+
params["object_id"] = object_id
86+
if mi_res_id:
87+
params["mi_res_id"] = mi_res_id
7488
resp = http_client.get(
7589
endpoint,
7690
params=params,
@@ -103,7 +117,10 @@ def _obtain_token_on_app_service(http_client, endpoint, identity_header, resourc
103117
class ManagedIdentity(object):
104118
_instance, _tenant = socket.getfqdn(), "managed_identity" # Placeholders
105119

106-
def __init__(self, http_client, client_id=None, token_cache=None):
120+
def __init__(self, http_client,
121+
client_id=None, object_id=None, mi_res_id=None,
122+
token_cache=None,
123+
):
107124
"""Create a managed identity object.
108125
109126
:param http_client:
@@ -117,13 +134,24 @@ def __init__(self, http_client, client_id=None, token_cache=None):
117134
:param token_cache:
118135
Optional. It accepts a :class:`msal.TokenCache` instance to store tokens.
119136
"""
137+
if len(list(filter(bool, [client_id, object_id, mi_res_id]))) > 1:
138+
raise ValueError("You can use up to one of these: client_id, object_id, mi_res_id")
120139
self._http_client = http_client
121140
self._client_id = client_id
141+
self._object_id = object_id
142+
self._mi_res_id = mi_res_id
122143
self._token_cache = token_cache
123144

124-
def acquire_token(self, resource):
145+
def acquire_token(self, resource=None):
146+
if not resource:
147+
raise ValueError(
148+
"The resource parameter is currently required. "
149+
"It is only declared as optional in method signature, "
150+
"in case we want to support scope parameter in the future.")
125151
access_token_from_cache = None
126-
client_id_in_cache = self._client_id or "SYSTEM_ASSIGNED_MANAGED_IDENTITY"
152+
client_id_in_cache = (
153+
self._client_id or self._object_id or self._mi_res_id
154+
or "SYSTEM_ASSIGNED_MANAGED_IDENTITY")
127155
if self._token_cache:
128156
matches = self._token_cache.find(
129157
self._token_cache.CredentialType.ACCESS_TOKEN,
@@ -149,7 +177,13 @@ def acquire_token(self, resource):
149177
if "refresh_on" in entry and int(entry["refresh_on"]) < now: # aging
150178
break # With a fallback in hand, we break here to go refresh
151179
return access_token_from_cache # It is still good as new
152-
result = _obtain_token(self._http_client, resource, client_id=self._client_id)
180+
result = _obtain_token(
181+
self._http_client,
182+
resource,
183+
client_id=self._client_id,
184+
object_id=self._object_id,
185+
mi_res_id=self._mi_res_id,
186+
)
153187
if self._token_cache and "access_token" in result:
154188
self._token_cache.add(dict(
155189
client_id=client_id_in_cache,

tests/test_mi.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def test_happy_path(self):
3434
status_code=200,
3535
text='{"access_token": "AT", "expires_in": "1234", "resource": "R"}',
3636
)) as mocked_method:
37-
result = app.acquire_token("R")
37+
result = app.acquire_token(resource="R")
3838
mocked_method.assert_called_once()
3939
self.assertEqual({
4040
"access_token": "AT",
@@ -43,7 +43,8 @@ def test_happy_path(self):
4343
"token_type": "Bearer",
4444
}, result, "Should obtain a token response")
4545
self.assertEqual(
46-
result["access_token"], app.acquire_token("R").get("access_token"),
46+
result["access_token"],
47+
app.acquire_token(resource="R").get("access_token"),
4748
"Should hit the same token from cache")
4849
self._test_token_cache(app)
4950

@@ -54,7 +55,7 @@ def test_vm_error_should_be_returned_as_is(self):
5455
status_code=400,
5556
text=raw_error,
5657
)) as mocked_method:
57-
self.assertEqual(json.loads(raw_error), app.acquire_token("R"))
58+
self.assertEqual(json.loads(raw_error), app.acquire_token(resource="R"))
5859
self.assertEqual({}, app._token_cache._cache)
5960

6061

@@ -69,7 +70,7 @@ def test_happy_path(self):
6970
status_code=200,
7071
text='{"access_token": "AT", "expires_on": "%s", "resource": "R"}' % (now + 100),
7172
)) as mocked_method:
72-
result = app.acquire_token("R")
73+
result = app.acquire_token(resource="R")
7374
mocked_method.assert_called_once()
7475
self.assertEqual({
7576
"access_token": "AT",
@@ -78,7 +79,8 @@ def test_happy_path(self):
7879
"token_type": "Bearer",
7980
}, result, "Should obtain a token response")
8081
self.assertEqual(
81-
result["access_token"], app.acquire_token("R").get("access_token"),
82+
result["access_token"],
83+
app.acquire_token(resource="R").get("access_token"),
8284
"Should hit the same token from cache")
8385
self._test_token_cache(app)
8486

@@ -92,6 +94,6 @@ def test_app_service_error_should_be_normalized(self):
9294
self.assertEqual({
9395
"error": "invalid_scope",
9496
"error_description": "500, error content is undefined",
95-
}, app.acquire_token("R"))
97+
}, app.acquire_token(resource="R"))
9698
self.assertEqual({}, app._token_cache._cache)
9799

0 commit comments

Comments
 (0)