@@ -313,6 +313,15 @@ def _obtain_token(http_client, managed_identity, resource):
313
313
managed_identity ,
314
314
resource ,
315
315
)
316
+ if "MSI_ENDPOINT" in os .environ and "MSI_SECRET" in os .environ :
317
+ # Back ported from https://github.com/Azure/azure-sdk-for-python/blob/azure-identity_1.15.0/sdk/identity/azure-identity/azure/identity/_credentials/azure_ml.py
318
+ return _obtain_token_on_machine_learning (
319
+ http_client ,
320
+ os .environ ["MSI_ENDPOINT" ],
321
+ os .environ ["MSI_SECRET" ],
322
+ managed_identity ,
323
+ resource ,
324
+ )
316
325
if "IDENTITY_ENDPOINT" in os .environ and "IMDS_ENDPOINT" in os .environ :
317
326
if ManagedIdentity .is_user_assigned (managed_identity ):
318
327
raise ManagedIdentityError ( # Note: Azure Identity for Python raised exception too
@@ -329,6 +338,7 @@ def _obtain_token(http_client, managed_identity, resource):
329
338
330
339
331
340
def _adjust_param (params , managed_identity ):
341
+ # Modify the params dict in place
332
342
id_name = ManagedIdentity ._types_mapping .get (
333
343
managed_identity .get (ManagedIdentity .ID_TYPE ))
334
344
if id_name :
@@ -405,6 +415,39 @@ def _obtain_token_on_app_service(
405
415
logger .debug ("IMDS emits unexpected payload: %s" , resp .text )
406
416
raise
407
417
418
+ def _obtain_token_on_machine_learning (
419
+ http_client , endpoint , secret , managed_identity , resource ,
420
+ ):
421
+ # Could not find protocol docs from https://docs.microsoft.com/en-us/azure/machine-learning
422
+ # The following implementation is back ported from Azure Identity 1.15.0
423
+ logger .debug ("Obtaining token via managed identity on Azure Machine Learning" )
424
+ params = {"api-version" : "2017-09-01" , "resource" : resource }
425
+ _adjust_param (params , managed_identity )
426
+ if params ["api-version" ] == "2017-09-01" and "client_id" in params :
427
+ # Workaround for a known bug in Azure ML 2017 API
428
+ params ["clientid" ] = params .pop ("client_id" )
429
+ resp = http_client .get (
430
+ endpoint ,
431
+ params = params ,
432
+ headers = {"secret" : secret },
433
+ )
434
+ try :
435
+ payload = json .loads (resp .text )
436
+ if payload .get ("access_token" ) and payload .get ("expires_on" ):
437
+ return { # Normalizing the payload into OAuth2 format
438
+ "access_token" : payload ["access_token" ],
439
+ "expires_in" : int (payload ["expires_on" ]) - int (time .time ()),
440
+ "resource" : payload .get ("resource" ),
441
+ "token_type" : payload .get ("token_type" , "Bearer" ),
442
+ }
443
+ return {
444
+ "error" : "invalid_scope" , # TODO: To be tested
445
+ "error_description" : "{}" .format (payload ),
446
+ }
447
+ except json .decoder .JSONDecodeError :
448
+ logger .debug ("IMDS emits unexpected payload: %s" , resp .text )
449
+ raise
450
+
408
451
409
452
def _obtain_token_on_service_fabric (
410
453
http_client , endpoint , identity_header , server_thumbprint , resource ,
0 commit comments