Skip to content

Commit 2f48e72

Browse files
committed
Managed Identity for Machine Learning
1 parent 8e8113e commit 2f48e72

File tree

2 files changed

+67
-0
lines changed

2 files changed

+67
-0
lines changed

msal/managed_identity.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -330,6 +330,15 @@ def _obtain_token(http_client, managed_identity, resource):
330330
managed_identity,
331331
resource,
332332
)
333+
if "MSI_ENDPOINT" in os.environ and "MSI_SECRET" in os.environ:
334+
# Back ported from https://github.com/Azure/azure-sdk-for-python/blob/azure-identity_1.15.0/sdk/identity/azure-identity/azure/identity/_credentials/azure_ml.py
335+
return _obtain_token_on_machine_learning(
336+
http_client,
337+
os.environ["MSI_ENDPOINT"],
338+
os.environ["MSI_SECRET"],
339+
managed_identity,
340+
resource,
341+
)
333342
if "IDENTITY_ENDPOINT" in os.environ and "IMDS_ENDPOINT" in os.environ:
334343
if ManagedIdentity.is_user_assigned(managed_identity):
335344
raise ManagedIdentityError( # Note: Azure Identity for Python raised exception too
@@ -346,6 +355,7 @@ def _obtain_token(http_client, managed_identity, resource):
346355

347356

348357
def _adjust_param(params, managed_identity):
358+
# Modify the params dict in place
349359
id_name = ManagedIdentity._types_mapping.get(
350360
managed_identity.get(ManagedIdentity.ID_TYPE))
351361
if id_name:
@@ -422,6 +432,39 @@ def _obtain_token_on_app_service(
422432
logger.debug("IMDS emits unexpected payload: %s", resp.text)
423433
raise
424434

435+
def _obtain_token_on_machine_learning(
436+
http_client, endpoint, secret, managed_identity, resource,
437+
):
438+
# Could not find protocol docs from https://docs.microsoft.com/en-us/azure/machine-learning
439+
# The following implementation is back ported from Azure Identity 1.15.0
440+
logger.debug("Obtaining token via managed identity on Azure Machine Learning")
441+
params = {"api-version": "2017-09-01", "resource": resource}
442+
_adjust_param(params, managed_identity)
443+
if params["api-version"] == "2017-09-01" and "client_id" in params:
444+
# Workaround for a known bug in Azure ML 2017 API
445+
params["clientid"] = params.pop("client_id")
446+
resp = http_client.get(
447+
endpoint,
448+
params=params,
449+
headers={"secret": secret},
450+
)
451+
try:
452+
payload = json.loads(resp.text)
453+
if payload.get("access_token") and payload.get("expires_on"):
454+
return { # Normalizing the payload into OAuth2 format
455+
"access_token": payload["access_token"],
456+
"expires_in": int(payload["expires_on"]) - int(time.time()),
457+
"resource": payload.get("resource"),
458+
"token_type": payload.get("token_type", "Bearer"),
459+
}
460+
return {
461+
"error": "invalid_scope", # TODO: To be tested
462+
"error_description": "{}".format(payload),
463+
}
464+
except json.decoder.JSONDecodeError:
465+
logger.debug("IMDS emits unexpected payload: %s", resp.text)
466+
raise
467+
425468

426469
def _obtain_token_on_service_fabric(
427470
http_client, endpoint, identity_header, server_thumbprint, resource,

tests/test_mi.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,30 @@ def test_app_service_error_should_be_normalized(self):
119119
self.assertEqual({}, self.app._token_cache._cache)
120120

121121

122+
@patch.dict(os.environ, {"MSI_ENDPOINT": "http://localhost", "MSI_SECRET": "foo"})
123+
class MachineLearningTestCase(ClientTestCase):
124+
125+
def test_happy_path(self):
126+
with patch.object(self.app._http_client, "get", return_value=MinimalResponse(
127+
status_code=200,
128+
text='{"access_token": "AT", "expires_on": "%s", "resource": "R"}' % (
129+
int(time.time()) + 1234),
130+
)) as mocked_method:
131+
self._test_happy_path(self.app, mocked_method)
132+
133+
def test_machine_learning_error_should_be_normalized(self):
134+
raw_error = '{"error": "placeholder", "message": "placeholder"}'
135+
with patch.object(self.app._http_client, "get", return_value=MinimalResponse(
136+
status_code=500,
137+
text=raw_error,
138+
)) as mocked_method:
139+
self.assertEqual({
140+
"error": "invalid_scope",
141+
"error_description": "{'error': 'placeholder', 'message': 'placeholder'}",
142+
}, self.app.acquire_token_for_client(resource="R"))
143+
self.assertEqual({}, self.app._token_cache._cache)
144+
145+
122146
@patch.dict(os.environ, {
123147
"IDENTITY_ENDPOINT": "http://localhost",
124148
"IDENTITY_HEADER": "foo",

0 commit comments

Comments
 (0)