Skip to content

Commit 75601db

Browse files
committed
Support Service Fabric
1 parent 23e2e15 commit 75601db

File tree

2 files changed

+114
-30
lines changed

2 files changed

+114
-30
lines changed

msal/imds.py

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,18 @@ def _scope_to_resource(scope): # This is an experimental reasonable-effort appr
2222

2323

2424
def _obtain_token(http_client, resource, client_id=None, object_id=None, mi_res_id=None):
25+
if ("IDENTITY_ENDPOINT" in os.environ and "IDENTITY_HEADER" in os.environ
26+
and "IDENTITY_SERVER_THUMBPRINT" in os.environ
27+
):
28+
if client_id or object_id or mi_res_id:
29+
logger.debug(
30+
"Ignoring client_id/object_id/mi_res_id. "
31+
"Managed Identity in Service Fabric is configured in the cluster, "
32+
"not during runtime. See also "
33+
"https://learn.microsoft.com/en-us/azure/service-fabric/configure-existing-cluster-enable-managed-identity-token-service")
34+
return _obtain_token_on_service_fabric(
35+
http_client, os.environ["IDENTITY_ENDPOINT"], os.environ["IDENTITY_HEADER"],
36+
os.environ["IDENTITY_SERVER_THUMBPRINT"], resource)
2537
if "IDENTITY_ENDPOINT" in os.environ and "IDENTITY_HEADER" in os.environ:
2638
return _obtain_token_on_app_service(
2739
http_client, os.environ["IDENTITY_ENDPOINT"], os.environ["IDENTITY_HEADER"],
@@ -69,7 +81,8 @@ def _obtain_token_on_app_service(http_client, endpoint, identity_header, resourc
6981
client_id=None, object_id=None, mi_res_id=None,
7082
):
7183
"""Obtains token for
72-
`App Service <https://learn.microsoft.com/en-us/azure/app-service/overview-managed-identity?tabs=portal%2Chttp#rest-endpoint-reference>`_
84+
`App Service <https://learn.microsoft.com/en-us/azure/app-service/overview-managed-identity?tabs=portal%2Chttp#rest-endpoint-reference>`_,
85+
Azure Functions, and Azure Automation.
7386
"""
7487
# Prerequisite: Create your app service https://docs.microsoft.com/en-us/azure/app-service/quickstart-python
7588
# Assign it a managed identity https://docs.microsoft.com/en-us/azure/app-service/overview-managed-identity?tabs=portal%2Chttp
@@ -114,6 +127,46 @@ def _obtain_token_on_app_service(http_client, endpoint, identity_header, resourc
114127
raise
115128

116129

130+
def _obtain_token_on_service_fabric(
131+
http_client, endpoint, identity_header, server_thumbprint, resource,
132+
):
133+
"""Obtains token for
134+
`Service Fabric <https://learn.microsoft.com/en-us/azure/service-fabric/>`_
135+
"""
136+
# Deployment https://learn.microsoft.com/en-us/azure/service-fabric/service-fabric-get-started-containers-linux
137+
# See also https://github.com/Azure/azure-sdk-for-python/blob/main/sdk/identity/azure-identity/tests/managed-identity-live/service-fabric/service_fabric.md
138+
# Protocol https://learn.microsoft.com/en-us/azure/service-fabric/how-to-managed-identity-service-fabric-app-code#acquiring-an-access-token-using-rest-api
139+
logger.debug("Obtaining token via managed identity on Azure Service Fabric")
140+
resp = http_client.get(
141+
endpoint,
142+
params={"api-version": "2019-07-01-preview", "resource": resource},
143+
headers={"Secret": identity_header},
144+
)
145+
try:
146+
payload = json.loads(resp.text)
147+
if payload.get("access_token") and payload.get("expires_on"):
148+
return { # Normalizing the payload into OAuth2 format
149+
"access_token": payload["access_token"],
150+
"expires_in": payload["expires_on"] - int(time.time()),
151+
"resource": payload.get("resource"),
152+
"token_type": payload["token_type"],
153+
}
154+
error = payload.get("error", {}) # https://learn.microsoft.com/en-us/azure/service-fabric/how-to-managed-identity-service-fabric-app-code#error-handling
155+
error_mapping = { # Map Service Fabric errors into OAuth2 errors https://www.rfc-editor.org/rfc/rfc6749#section-5.2
156+
"SecretHeaderNotFound": "unauthorized_client",
157+
"ManagedIdentityNotFound": "invalid_client",
158+
"ArgumentNullOrEmpty": "invalid_scope",
159+
}
160+
return {
161+
"error": error_mapping.get(payload["error"]["code"], "invalid_request"),
162+
"error_description": resp.text,
163+
}
164+
except ValueError:
165+
logger.debug("IMDS emits unexpected payload: %s", resp.text)
166+
raise
167+
168+
169+
117170
class ManagedIdentity(object):
118171
_instance, _tenant = socket.getfqdn(), "managed_identity" # Placeholders
119172

tests/test_mi.py

Lines changed: 60 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,21 @@ def _test_token_cache(self, app):
2525
"Should have expected client_id")
2626
self.assertEqual("managed_identity", at["realm"], "Should have expected realm")
2727

28+
def _test_happy_path(self, app, mocked_http):
29+
result = app.acquire_token(resource="R")
30+
mocked_http.assert_called_once()
31+
self.assertEqual({
32+
"access_token": "AT",
33+
"expires_in": 1234,
34+
"resource": "R",
35+
"token_type": "Bearer",
36+
}, result, "Should obtain a token response")
37+
self.assertEqual(
38+
result["access_token"],
39+
app.acquire_token(resource="R").get("access_token"),
40+
"Should hit the same token from cache")
41+
self._test_token_cache(app)
42+
2843

2944
class VmTestCase(ManagedIdentityTestCase):
3045

@@ -34,19 +49,7 @@ def test_happy_path(self):
3449
status_code=200,
3550
text='{"access_token": "AT", "expires_in": "1234", "resource": "R"}',
3651
)) as mocked_method:
37-
result = app.acquire_token(resource="R")
38-
mocked_method.assert_called_once()
39-
self.assertEqual({
40-
"access_token": "AT",
41-
"expires_in": 1234,
42-
"resource": "R",
43-
"token_type": "Bearer",
44-
}, result, "Should obtain a token response")
45-
self.assertEqual(
46-
result["access_token"],
47-
app.acquire_token(resource="R").get("access_token"),
48-
"Should hit the same token from cache")
49-
self._test_token_cache(app)
52+
self._test_happy_path(app, mocked_method)
5053

5154
def test_vm_error_should_be_returned_as_is(self):
5255
raw_error = '{"raw": "error format is undefined"}'
@@ -63,26 +66,13 @@ def test_vm_error_should_be_returned_as_is(self):
6366
class AppServiceTestCase(ManagedIdentityTestCase):
6467

6568
def test_happy_path(self):
66-
# TODO: Combine this with VM's test case, and move it into base class
6769
app = ManagedIdentity(requests.Session(), token_cache=TokenCache())
68-
now = int(time.time())
6970
with patch.object(app._http_client, "get", return_value=MinimalResponse(
7071
status_code=200,
71-
text='{"access_token": "AT", "expires_on": "%s", "resource": "R"}' % (now + 100),
72+
text='{"access_token": "AT", "expires_on": "%s", "resource": "R"}' % (
73+
int(time.time()) + 1234),
7274
)) as mocked_method:
73-
result = app.acquire_token(resource="R")
74-
mocked_method.assert_called_once()
75-
self.assertEqual({
76-
"access_token": "AT",
77-
"expires_in": 100,
78-
"resource": "R",
79-
"token_type": "Bearer",
80-
}, result, "Should obtain a token response")
81-
self.assertEqual(
82-
result["access_token"],
83-
app.acquire_token(resource="R").get("access_token"),
84-
"Should hit the same token from cache")
85-
self._test_token_cache(app)
75+
self._test_happy_path(app, mocked_method)
8676

8777
def test_app_service_error_should_be_normalized(self):
8878
raw_error = '{"statusCode": 500, "message": "error content is undefined"}'
@@ -97,3 +87,44 @@ def test_app_service_error_should_be_normalized(self):
9787
}, app.acquire_token(resource="R"))
9888
self.assertEqual({}, app._token_cache._cache)
9989

90+
@patch.dict(os.environ, {
91+
"IDENTITY_ENDPOINT": "http://localhost",
92+
"IDENTITY_HEADER": "foo",
93+
"IDENTITY_SERVER_THUMBPRINT": "bar",
94+
})
95+
class ServiceFabricTestCase(ManagedIdentityTestCase):
96+
97+
def _test_happy_path(self, app):
98+
with patch.object(app._http_client, "get", return_value=MinimalResponse(
99+
status_code=200,
100+
text='{"access_token": "AT", "expires_on": %s, "resource": "R", "token_type": "Bearer"}' % (
101+
int(time.time()) + 1234),
102+
)) as mocked_method:
103+
super(ServiceFabricTestCase, self)._test_happy_path(app, mocked_method)
104+
105+
def test_happy_path(self):
106+
self._test_happy_path(ManagedIdentity(
107+
requests.Session(), token_cache=TokenCache()))
108+
109+
def test_unified_api_service_should_ignore_unnecessary_client_id(self):
110+
self._test_happy_path(ManagedIdentity(
111+
requests.Session(), client_id="foo", token_cache=TokenCache()))
112+
113+
def test_app_service_error_should_be_normalized(self):
114+
raw_error = '''
115+
{"error": {
116+
"correlationId": "foo",
117+
"code": "SecretHeaderNotFound",
118+
"message": "Secret is not found in the request headers."
119+
}}''' # https://learn.microsoft.com/en-us/azure/service-fabric/how-to-managed-identity-service-fabric-app-code#error-handling
120+
app = ManagedIdentity(requests.Session(), token_cache=TokenCache())
121+
with patch.object(app._http_client, "get", return_value=MinimalResponse(
122+
status_code=404,
123+
text=raw_error,
124+
)) as mocked_method:
125+
self.assertEqual({
126+
"error": "unauthorized_client",
127+
"error_description": raw_error,
128+
}, app.acquire_token(resource="R"))
129+
self.assertEqual({}, app._token_cache._cache)
130+

0 commit comments

Comments
 (0)