Skip to content

Commit 9350391

Browse files
committed
Managed Identity for Machine Learning
1 parent 2c8c5ba commit 9350391

File tree

2 files changed

+67
-0
lines changed

2 files changed

+67
-0
lines changed

msal/managed_identity.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,15 @@ def _obtain_token(http_client, managed_identity, resource):
313313
managed_identity,
314314
resource,
315315
)
316+
if "MSI_ENDPOINT" in os.environ and "MSI_SECRET" in os.environ:
317+
# Back ported from https://github.com/Azure/azure-sdk-for-python/blob/azure-identity_1.15.0/sdk/identity/azure-identity/azure/identity/_credentials/azure_ml.py
318+
return _obtain_token_on_machine_learning(
319+
http_client,
320+
os.environ["MSI_ENDPOINT"],
321+
os.environ["MSI_SECRET"],
322+
managed_identity,
323+
resource,
324+
)
316325
if "IDENTITY_ENDPOINT" in os.environ and "IMDS_ENDPOINT" in os.environ:
317326
if ManagedIdentity.is_user_assigned(managed_identity):
318327
raise ManagedIdentityError( # Note: Azure Identity for Python raised exception too
@@ -329,6 +338,7 @@ def _obtain_token(http_client, managed_identity, resource):
329338

330339

331340
def _adjust_param(params, managed_identity):
341+
# Modify the params dict in place
332342
id_name = ManagedIdentity._types_mapping.get(
333343
managed_identity.get(ManagedIdentity.ID_TYPE))
334344
if id_name:
@@ -405,6 +415,39 @@ def _obtain_token_on_app_service(
405415
logger.debug("IMDS emits unexpected payload: %s", resp.text)
406416
raise
407417

418+
def _obtain_token_on_machine_learning(
419+
http_client, endpoint, secret, managed_identity, resource,
420+
):
421+
# Could not find protocol docs from https://docs.microsoft.com/en-us/azure/machine-learning
422+
# The following implementation is back ported from Azure Identity 1.15.0
423+
logger.debug("Obtaining token via managed identity on Azure Machine Learning")
424+
params = {"api-version": "2017-09-01", "resource": resource}
425+
_adjust_param(params, managed_identity)
426+
if params["api-version"] == "2017-09-01" and "client_id" in params:
427+
# Workaround for a known bug in Azure ML 2017 API
428+
params["clientid"] = params.pop("client_id")
429+
resp = http_client.get(
430+
endpoint,
431+
params=params,
432+
headers={"secret": secret},
433+
)
434+
try:
435+
payload = json.loads(resp.text)
436+
if payload.get("access_token") and payload.get("expires_on"):
437+
return { # Normalizing the payload into OAuth2 format
438+
"access_token": payload["access_token"],
439+
"expires_in": int(payload["expires_on"]) - int(time.time()),
440+
"resource": payload.get("resource"),
441+
"token_type": payload.get("token_type", "Bearer"),
442+
}
443+
return {
444+
"error": "invalid_scope", # TODO: To be tested
445+
"error_description": "{}".format(payload),
446+
}
447+
except json.decoder.JSONDecodeError:
448+
logger.debug("IMDS emits unexpected payload: %s", resp.text)
449+
raise
450+
408451

409452
def _obtain_token_on_service_fabric(
410453
http_client, endpoint, identity_header, server_thumbprint, resource,

tests/test_mi.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,30 @@ def test_app_service_error_should_be_normalized(self):
119119
self.assertEqual({}, self.app._token_cache._cache)
120120

121121

122+
@patch.dict(os.environ, {"MSI_ENDPOINT": "http://localhost", "MSI_SECRET": "foo"})
123+
class MachineLearningTestCase(ClientTestCase):
124+
125+
def test_happy_path(self):
126+
with patch.object(self.app._http_client, "get", return_value=MinimalResponse(
127+
status_code=200,
128+
text='{"access_token": "AT", "expires_on": "%s", "resource": "R"}' % (
129+
int(time.time()) + 1234),
130+
)) as mocked_method:
131+
self._test_happy_path(self.app, mocked_method)
132+
133+
def test_machine_learning_error_should_be_normalized(self):
134+
raw_error = '{"error": "placeholder", "message": "placeholder"}'
135+
with patch.object(self.app._http_client, "get", return_value=MinimalResponse(
136+
status_code=500,
137+
text=raw_error,
138+
)) as mocked_method:
139+
self.assertEqual({
140+
"error": "invalid_scope",
141+
"error_description": "{'error': 'placeholder', 'message': 'placeholder'}",
142+
}, self.app.acquire_token_for_client(resource="R"))
143+
self.assertEqual({}, self.app._token_cache._cache)
144+
145+
122146
@patch.dict(os.environ, {
123147
"IDENTITY_ENDPOINT": "http://localhost",
124148
"IDENTITY_HEADER": "foo",

0 commit comments

Comments
 (0)