Skip to content

Commit 0753928

Browse files
committed
Implementation based on feature requirement
Error out on platforms other than Linux and Windows
1 parent 9350391 commit 0753928

File tree

3 files changed

+44
-12
lines changed

3 files changed

+44
-12
lines changed

msal/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
SystemAssignedManagedIdentity, UserAssignedManagedIdentity,
3939
ManagedIdentityClient,
4040
ManagedIdentityError,
41+
ArcPlatformNotSupportedError,
4142
)
4243

4344
# Putting module-level exceptions into the package namespace, to make them

msal/managed_identity.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import logging
77
import os
88
import socket
9+
import sys
910
import time
1011
from urllib.parse import urlparse # Python 3+
1112
from collections import UserDict # Python 3+
@@ -490,6 +491,14 @@ def _obtain_token_on_service_fabric(
490491
raise
491492

492493

494+
_supported_arc_platforms_and_their_prefixes = {
495+
"linux": "/var/opt/azcmagent/tokens",
496+
"win32": os.path.expandvars(r"%ProgramData%\AzureConnectedMachineAgent\Tokens"),
497+
}
498+
499+
class ArcPlatformNotSupportedError(ManagedIdentityError):
500+
pass
501+
493502
def _obtain_token_on_arc(http_client, endpoint, resource):
494503
# https://learn.microsoft.com/en-us/azure/azure-arc/servers/managed-identity-authentication
495504
logger.debug("Obtaining token via managed identity on Azure Arc")
@@ -508,7 +517,15 @@ def _obtain_token_on_arc(http_client, endpoint, resource):
508517
len(challenge) == 2 and challenge[0].lower() == "basic realm"):
509518
raise ManagedIdentityError(
510519
"Unrecognizable WWW-Authenticate header: {}".format(resp.headers))
511-
with open(challenge[1]) as f:
520+
if sys.platform not in _supported_arc_platforms_and_their_prefixes:
521+
raise ArcPlatformNotSupportedError(
522+
f"Platform {sys.platform} was undefined and unsupported")
523+
filename = os.path.join(
524+
_supported_arc_platforms_and_their_prefixes[sys.platform],
525+
os.path.splitext(os.path.basename(challenge[1]))[0] + ".key")
526+
if os.stat(filename).st_size > 4096: # Check size BEFORE loading its content
527+
raise ManagedIdentityError("Local key file shall not be larger than 4KB")
528+
with open(filename) as f:
512529
secret = f.read()
513530
response = http_client.get(
514531
endpoint,

tests/test_mi.py

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,19 @@
44
import time
55
import unittest
66
try:
7-
from unittest.mock import patch, ANY, mock_open
7+
from unittest.mock import patch, ANY, mock_open, Mock
88
except:
9-
from mock import patch, ANY, mock_open
9+
from mock import patch, ANY, mock_open, Mock
1010
import requests
1111

1212
from tests.http_client import MinimalResponse
1313
from msal import (
1414
SystemAssignedManagedIdentity, UserAssignedManagedIdentity,
1515
ManagedIdentityClient,
1616
ManagedIdentityError,
17+
ArcPlatformNotSupportedError,
1718
)
19+
from msal.managed_identity import _supported_arc_platforms_and_their_prefixes
1820

1921

2022
class ManagedIdentityTestCase(unittest.TestCase):
@@ -194,29 +196,41 @@ def test_sf_error_should_be_normalized(self):
194196
new=mock_open(read_data="secret"), # `new` requires no extra argument on the decorated function.
195197
# https://docs.python.org/3/library/unittest.mock.html#unittest.mock.patch
196198
)
199+
@patch("os.stat", return_value=Mock(st_size=4096))
197200
class ArcTestCase(ClientTestCase):
198201
challenge = MinimalResponse(status_code=401, text="", headers={
199202
"WWW-Authenticate": "Basic realm=/tmp/foo",
200203
})
201204

202-
def test_happy_path(self):
205+
def test_happy_path(self, mocked_stat):
203206
with patch.object(self.app._http_client, "get", side_effect=[
204207
self.challenge,
205208
MinimalResponse(
206209
status_code=200,
207210
text='{"access_token": "AT", "expires_in": "1234", "resource": "R"}',
208211
),
209212
]) as mocked_method:
210-
super(ArcTestCase, self)._test_happy_path(self.app, mocked_method)
211-
212-
def test_arc_error_should_be_normalized(self):
213+
try:
214+
super(ArcTestCase, self)._test_happy_path(self.app, mocked_method)
215+
mocked_stat.assert_called_with(os.path.join(
216+
_supported_arc_platforms_and_their_prefixes[sys.platform],
217+
"foo.key"))
218+
except ArcPlatformNotSupportedError:
219+
if sys.platform in _supported_arc_platforms_and_their_prefixes:
220+
self.fail("Should not raise ArcPlatformNotSupportedError")
221+
222+
def test_arc_error_should_be_normalized(self, mocked_stat):
213223
with patch.object(self.app._http_client, "get", side_effect=[
214224
self.challenge,
215225
MinimalResponse(status_code=400, text="undefined"),
216226
]) as mocked_method:
217-
self.assertEqual({
218-
"error": "invalid_request",
219-
"error_description": "undefined",
220-
}, self.app.acquire_token_for_client(resource="R"))
221-
self.assertEqual({}, self.app._token_cache._cache)
227+
try:
228+
self.assertEqual({
229+
"error": "invalid_request",
230+
"error_description": "undefined",
231+
}, self.app.acquire_token_for_client(resource="R"))
232+
self.assertEqual({}, self.app._token_cache._cache)
233+
except ArcPlatformNotSupportedError:
234+
if sys.platform in _supported_arc_platforms_and_their_prefixes:
235+
self.fail("Should not raise ArcPlatformNotSupportedError")
222236

0 commit comments

Comments
 (0)