Skip to content

Implementation based on feature requirement #706

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions msal/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
SystemAssignedManagedIdentity, UserAssignedManagedIdentity,
ManagedIdentityClient,
ManagedIdentityError,
ArcPlatformNotSupportedError,
)

# Putting module-level exceptions into the package namespace, to make them
Expand Down
19 changes: 18 additions & 1 deletion msal/managed_identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import logging
import os
import socket
import sys
import time
from urllib.parse import urlparse # Python 3+
from collections import UserDict # Python 3+
Expand Down Expand Up @@ -490,6 +491,14 @@ def _obtain_token_on_service_fabric(
raise


_supported_arc_platforms_and_their_prefixes = {
"linux": "/var/opt/azcmagent/tokens",
"win32": os.path.expandvars(r"%ProgramData%\AzureConnectedMachineAgent\Tokens"),
}

class ArcPlatformNotSupportedError(ManagedIdentityError):
pass

def _obtain_token_on_arc(http_client, endpoint, resource):
# https://learn.microsoft.com/en-us/azure/azure-arc/servers/managed-identity-authentication
logger.debug("Obtaining token via managed identity on Azure Arc")
Expand All @@ -508,7 +517,15 @@ def _obtain_token_on_arc(http_client, endpoint, resource):
len(challenge) == 2 and challenge[0].lower() == "basic realm"):
raise ManagedIdentityError(
"Unrecognizable WWW-Authenticate header: {}".format(resp.headers))
with open(challenge[1]) as f:
if sys.platform not in _supported_arc_platforms_and_their_prefixes:
raise ArcPlatformNotSupportedError(
f"Platform {sys.platform} was undefined and unsupported")
filename = os.path.join(
_supported_arc_platforms_and_their_prefixes[sys.platform],
os.path.splitext(os.path.basename(challenge[1]))[0] + ".key")
if os.stat(filename).st_size > 4096: # Check size BEFORE loading its content
raise ManagedIdentityError("Local key file shall not be larger than 4KB")
with open(filename) as f:
secret = f.read()
response = http_client.get(
endpoint,
Expand Down
36 changes: 25 additions & 11 deletions tests/test_mi.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,19 @@
import time
import unittest
try:
from unittest.mock import patch, ANY, mock_open
from unittest.mock import patch, ANY, mock_open, Mock
except:
from mock import patch, ANY, mock_open
from mock import patch, ANY, mock_open, Mock
import requests

from tests.http_client import MinimalResponse
from msal import (
SystemAssignedManagedIdentity, UserAssignedManagedIdentity,
ManagedIdentityClient,
ManagedIdentityError,
ArcPlatformNotSupportedError,
)
from msal.managed_identity import _supported_arc_platforms_and_their_prefixes


class ManagedIdentityTestCase(unittest.TestCase):
Expand Down Expand Up @@ -194,29 +196,41 @@ def test_sf_error_should_be_normalized(self):
new=mock_open(read_data="secret"), # `new` requires no extra argument on the decorated function.
# https://docs.python.org/3/library/unittest.mock.html#unittest.mock.patch
)
@patch("os.stat", return_value=Mock(st_size=4096))
class ArcTestCase(ClientTestCase):
challenge = MinimalResponse(status_code=401, text="", headers={
"WWW-Authenticate": "Basic realm=/tmp/foo",
})

def test_happy_path(self):
def test_happy_path(self, mocked_stat):
with patch.object(self.app._http_client, "get", side_effect=[
self.challenge,
MinimalResponse(
status_code=200,
text='{"access_token": "AT", "expires_in": "1234", "resource": "R"}',
),
]) as mocked_method:
super(ArcTestCase, self)._test_happy_path(self.app, mocked_method)

def test_arc_error_should_be_normalized(self):
try:
super(ArcTestCase, self)._test_happy_path(self.app, mocked_method)
mocked_stat.assert_called_with(os.path.join(
_supported_arc_platforms_and_their_prefixes[sys.platform],
"foo.key"))
except ArcPlatformNotSupportedError:
if sys.platform in _supported_arc_platforms_and_their_prefixes:
self.fail("Should not raise ArcPlatformNotSupportedError")

def test_arc_error_should_be_normalized(self, mocked_stat):
with patch.object(self.app._http_client, "get", side_effect=[
self.challenge,
MinimalResponse(status_code=400, text="undefined"),
]) as mocked_method:
self.assertEqual({
"error": "invalid_request",
"error_description": "undefined",
}, self.app.acquire_token_for_client(resource="R"))
self.assertEqual({}, self.app._token_cache._cache)
try:
self.assertEqual({
"error": "invalid_request",
"error_description": "undefined",
}, self.app.acquire_token_for_client(resource="R"))
self.assertEqual({}, self.app._token_cache._cache)
except ArcPlatformNotSupportedError:
if sys.platform in _supported_arc_platforms_and_their_prefixes:
self.fail("Should not raise ArcPlatformNotSupportedError")

Loading