@@ -330,6 +330,15 @@ def _obtain_token(http_client, managed_identity, resource):
330
330
managed_identity ,
331
331
resource ,
332
332
)
333
+ if "MSI_ENDPOINT" in os .environ and "MSI_SECRET" in os .environ :
334
+ # Back ported from https://github.com/Azure/azure-sdk-for-python/blob/azure-identity_1.15.0/sdk/identity/azure-identity/azure/identity/_credentials/azure_ml.py
335
+ return _obtain_token_on_machine_learning (
336
+ http_client ,
337
+ os .environ ["MSI_ENDPOINT" ],
338
+ os .environ ["MSI_SECRET" ],
339
+ managed_identity ,
340
+ resource ,
341
+ )
333
342
if "IDENTITY_ENDPOINT" in os .environ and "IMDS_ENDPOINT" in os .environ :
334
343
if ManagedIdentity .is_user_assigned (managed_identity ):
335
344
raise ManagedIdentityError ( # Note: Azure Identity for Python raised exception too
@@ -346,6 +355,7 @@ def _obtain_token(http_client, managed_identity, resource):
346
355
347
356
348
357
def _adjust_param (params , managed_identity ):
358
+ # Modify the params dict in place
349
359
id_name = ManagedIdentity ._types_mapping .get (
350
360
managed_identity .get (ManagedIdentity .ID_TYPE ))
351
361
if id_name :
@@ -422,6 +432,39 @@ def _obtain_token_on_app_service(
422
432
logger .debug ("IMDS emits unexpected payload: %s" , resp .text )
423
433
raise
424
434
435
+ def _obtain_token_on_machine_learning (
436
+ http_client , endpoint , secret , managed_identity , resource ,
437
+ ):
438
+ # Could not find protocol docs from https://docs.microsoft.com/en-us/azure/machine-learning
439
+ # The following implementation is back ported from Azure Identity 1.15.0
440
+ logger .debug ("Obtaining token via managed identity on Azure Machine Learning" )
441
+ params = {"api-version" : "2017-09-01" , "resource" : resource }
442
+ _adjust_param (params , managed_identity )
443
+ if params ["api-version" ] == "2017-09-01" and "client_id" in params :
444
+ # Workaround for a known bug in Azure ML 2017 API
445
+ params ["clientid" ] = params .pop ("client_id" )
446
+ resp = http_client .get (
447
+ endpoint ,
448
+ params = params ,
449
+ headers = {"secret" : secret },
450
+ )
451
+ try :
452
+ payload = json .loads (resp .text )
453
+ if payload .get ("access_token" ) and payload .get ("expires_on" ):
454
+ return { # Normalizing the payload into OAuth2 format
455
+ "access_token" : payload ["access_token" ],
456
+ "expires_in" : int (payload ["expires_on" ]) - int (time .time ()),
457
+ "resource" : payload .get ("resource" ),
458
+ "token_type" : payload .get ("token_type" , "Bearer" ),
459
+ }
460
+ return {
461
+ "error" : "invalid_scope" , # TODO: To be tested
462
+ "error_description" : "{}" .format (payload ),
463
+ }
464
+ except json .decoder .JSONDecodeError :
465
+ logger .debug ("IMDS emits unexpected payload: %s" , resp .text )
466
+ raise
467
+
425
468
426
469
def _obtain_token_on_service_fabric (
427
470
http_client , endpoint , identity_header , server_thumbprint , resource ,
0 commit comments