|
16 | 16 | ManagedIdentityError,
|
17 | 17 | ArcPlatformNotSupportedError,
|
18 | 18 | )
|
19 |
| -from msal.managed_identity import _supported_arc_platforms_and_their_prefixes |
| 19 | +from msal.managed_identity import ( |
| 20 | + _supported_arc_platforms_and_their_prefixes, |
| 21 | + get_managed_identity_source, |
| 22 | + APP_SERVICE, |
| 23 | + AZURE_ARC, |
| 24 | + CLOUD_SHELL, |
| 25 | + MACHINE_LEARNING, |
| 26 | + SERVICE_FABRIC, |
| 27 | + DEFAULT_TO_VM, |
| 28 | +) |
20 | 29 |
|
21 | 30 |
|
22 | 31 | class ManagedIdentityTestCase(unittest.TestCase):
|
@@ -234,3 +243,44 @@ def test_arc_error_should_be_normalized(self, mocked_stat):
|
234 | 243 | if sys.platform in _supported_arc_platforms_and_their_prefixes:
|
235 | 244 | self.fail("Should not raise ArcPlatformNotSupportedError")
|
236 | 245 |
|
| 246 | + |
| 247 | +class GetManagedIdentitySourceTestCase(unittest.TestCase): |
| 248 | + |
| 249 | + @patch.dict(os.environ, { |
| 250 | + "IDENTITY_ENDPOINT": "http://localhost", |
| 251 | + "IDENTITY_HEADER": "foo", |
| 252 | + "IDENTITY_SERVER_THUMBPRINT": "bar", |
| 253 | + }) |
| 254 | + def test_service_fabric(self): |
| 255 | + self.assertEqual(get_managed_identity_source(), SERVICE_FABRIC) |
| 256 | + |
| 257 | + @patch.dict(os.environ, { |
| 258 | + "IDENTITY_ENDPOINT": "http://localhost", |
| 259 | + "IDENTITY_HEADER": "foo", |
| 260 | + }) |
| 261 | + def test_app_service(self): |
| 262 | + self.assertEqual(get_managed_identity_source(), APP_SERVICE) |
| 263 | + |
| 264 | + @patch.dict(os.environ, { |
| 265 | + "MSI_ENDPOINT": "http://localhost", |
| 266 | + "MSI_SECRET": "foo", |
| 267 | + }) |
| 268 | + def test_machine_learning(self): |
| 269 | + self.assertEqual(get_managed_identity_source(), MACHINE_LEARNING) |
| 270 | + |
| 271 | + @patch.dict(os.environ, { |
| 272 | + "IDENTITY_ENDPOINT": "http://localhost", |
| 273 | + "IMDS_ENDPOINT": "http://localhost", |
| 274 | + }) |
| 275 | + def test_arc(self): |
| 276 | + self.assertEqual(get_managed_identity_source(), AZURE_ARC) |
| 277 | + |
| 278 | + @patch.dict(os.environ, { |
| 279 | + "AZUREPS_HOST_ENVIRONMENT": "cloud-shell-foo", |
| 280 | + }) |
| 281 | + def test_cloud_shell(self): |
| 282 | + self.assertEqual(get_managed_identity_source(), CLOUD_SHELL) |
| 283 | + |
| 284 | + def test_default_to_vm(self): |
| 285 | + self.assertEqual(get_managed_identity_source(), DEFAULT_TO_VM) |
| 286 | + |
0 commit comments