Skip to content

Commit dcdca81

Browse files
committed
get_managed_identity_source() for Azure Identity
1 parent d809624 commit dcdca81

File tree

2 files changed

+81
-1
lines changed

2 files changed

+81
-1
lines changed

msal/managed_identity.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from .token_cache import TokenCache
1414
from .individual_cache import _IndividualCache as IndividualCache
1515
from .throttled_http_client import ThrottledHttpClientBase, RetryAfterParser
16+
from .cloudshell import _is_running_in_cloud_shell
1617

1718

1819
logger = logging.getLogger(__name__)
@@ -299,6 +300,35 @@ def _scope_to_resource(scope): # This is an experimental reasonable-effort appr
299300
return scope # There is no much else we can do here
300301

301302

303+
APP_SERVICE = object()
304+
AZURE_ARC = object()
305+
CLOUD_SHELL = object() # In MSAL Python, token acquisition was done by
306+
# PublicClientApplication(...).acquire_token_interactive(..., prompt="none")
307+
MACHINE_LEARNING = object()
308+
SERVICE_FABRIC = object()
309+
DEFAULT_TO_VM = object() # Unknown environment; default to VM; you may want to probe
310+
def get_managed_identity_source():
311+
"""Detect the current environment and return the likely identity source.
312+
313+
When this function returns ``CLOUD_SHELL``, you should use
314+
:func:`msal.PublicClientApplication.acquire_token_interactive` with ``prompt="none"``
315+
to obtain a token.
316+
"""
317+
if ("IDENTITY_ENDPOINT" in os.environ and "IDENTITY_HEADER" in os.environ
318+
and "IDENTITY_SERVER_THUMBPRINT" in os.environ
319+
):
320+
return SERVICE_FABRIC
321+
if "IDENTITY_ENDPOINT" in os.environ and "IDENTITY_HEADER" in os.environ:
322+
return APP_SERVICE
323+
if "MSI_ENDPOINT" in os.environ and "MSI_SECRET" in os.environ:
324+
return MACHINE_LEARNING
325+
if "IDENTITY_ENDPOINT" in os.environ and "IMDS_ENDPOINT" in os.environ:
326+
return AZURE_ARC
327+
if _is_running_in_cloud_shell():
328+
return CLOUD_SHELL
329+
return DEFAULT_TO_VM
330+
331+
302332
def _obtain_token(http_client, managed_identity, resource):
303333
# A unified low-level API that talks to different Managed Identity
304334
if ("IDENTITY_ENDPOINT" in os.environ and "IDENTITY_HEADER" in os.environ

tests/test_mi.py

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,16 @@
1616
ManagedIdentityError,
1717
ArcPlatformNotSupportedError,
1818
)
19-
from msal.managed_identity import _supported_arc_platforms_and_their_prefixes
19+
from msal.managed_identity import (
20+
_supported_arc_platforms_and_their_prefixes,
21+
get_managed_identity_source,
22+
APP_SERVICE,
23+
AZURE_ARC,
24+
CLOUD_SHELL,
25+
MACHINE_LEARNING,
26+
SERVICE_FABRIC,
27+
DEFAULT_TO_VM,
28+
)
2029

2130

2231
class ManagedIdentityTestCase(unittest.TestCase):
@@ -234,3 +243,44 @@ def test_arc_error_should_be_normalized(self, mocked_stat):
234243
if sys.platform in _supported_arc_platforms_and_their_prefixes:
235244
self.fail("Should not raise ArcPlatformNotSupportedError")
236245

246+
247+
class GetManagedIdentitySourceTestCase(unittest.TestCase):
248+
249+
@patch.dict(os.environ, {
250+
"IDENTITY_ENDPOINT": "http://localhost",
251+
"IDENTITY_HEADER": "foo",
252+
"IDENTITY_SERVER_THUMBPRINT": "bar",
253+
})
254+
def test_service_fabric(self):
255+
self.assertEqual(get_managed_identity_source(), SERVICE_FABRIC)
256+
257+
@patch.dict(os.environ, {
258+
"IDENTITY_ENDPOINT": "http://localhost",
259+
"IDENTITY_HEADER": "foo",
260+
})
261+
def test_app_service(self):
262+
self.assertEqual(get_managed_identity_source(), APP_SERVICE)
263+
264+
@patch.dict(os.environ, {
265+
"MSI_ENDPOINT": "http://localhost",
266+
"MSI_SECRET": "foo",
267+
})
268+
def test_machine_learning(self):
269+
self.assertEqual(get_managed_identity_source(), MACHINE_LEARNING)
270+
271+
@patch.dict(os.environ, {
272+
"IDENTITY_ENDPOINT": "http://localhost",
273+
"IMDS_ENDPOINT": "http://localhost",
274+
})
275+
def test_arc(self):
276+
self.assertEqual(get_managed_identity_source(), AZURE_ARC)
277+
278+
@patch.dict(os.environ, {
279+
"AZUREPS_HOST_ENVIRONMENT": "cloud-shell-foo",
280+
})
281+
def test_cloud_shell(self):
282+
self.assertEqual(get_managed_identity_source(), CLOUD_SHELL)
283+
284+
def test_default_to_vm(self):
285+
self.assertEqual(get_managed_identity_source(), DEFAULT_TO_VM)
286+

0 commit comments

Comments
 (0)