Skip to content

Commit 8b8310b

Browse files
committed
Unittest for VM, AppService
1 parent dd75a28 commit 8b8310b

File tree

2 files changed

+101
-3
lines changed

2 files changed

+101
-3
lines changed

msal/imds.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def _obtain_token_on_azure_vm(http_client, resource, client_id=None):
5252
"resource": payload.get("resource"),
5353
"token_type": payload.get("token_type", "Bearer"),
5454
}
55-
return payload # Typically an error
55+
return payload # Typically an error, but it is undefined in the doc above
5656
except ValueError:
5757
logger.debug("IMDS emits unexpected payload: %s", resp.text)
5858
raise
@@ -123,12 +123,13 @@ def __init__(self, http_client, client_id=None, token_cache=None):
123123

124124
def acquire_token(self, resource):
125125
access_token_from_cache = None
126+
client_id_in_cache = self._client_id or "SYSTEM_ASSIGNED_MANAGED_IDENTITY"
126127
if self._token_cache:
127128
matches = self._token_cache.find(
128129
self._token_cache.CredentialType.ACCESS_TOKEN,
129130
target=[resource],
130131
query=dict(
131-
client_id=self._client_id,
132+
client_id=client_id_in_cache,
132133
environment=self._instance,
133134
realm=self._tenant,
134135
home_account_id=None,
@@ -151,7 +152,7 @@ def acquire_token(self, resource):
151152
result = _obtain_token(self._http_client, resource, client_id=self._client_id)
152153
if self._token_cache and "access_token" in result:
153154
self._token_cache.add(dict(
154-
client_id=self._client_id,
155+
client_id=client_id_in_cache,
155156
scope=[resource],
156157
token_endpoint="https://{}/{}".format(self._instance, self._tenant),
157158
response=result,

tests/test_mi.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
import json
2+
import os
3+
import time
4+
import unittest
5+
try:
6+
from unittest.mock import patch, ANY
7+
except:
8+
from mock import patch, ANY
9+
import requests
10+
11+
from tests.http_client import MinimalResponse
12+
from msal import TokenCache, ManagedIdentity
13+
14+
15+
class ManagedIdentityTestCase(unittest.TestCase):
16+
maxDiff = None
17+
18+
def _test_token_cache(self, app):
19+
cache = app._token_cache._cache
20+
self.assertEqual(1, len(cache.get("AccessToken", [])), "Should have 1 AT")
21+
at = list(cache["AccessToken"].values())[0]
22+
self.assertEqual(
23+
app._client_id or "SYSTEM_ASSIGNED_MANAGED_IDENTITY",
24+
at["client_id"],
25+
"Should have expected client_id")
26+
self.assertEqual("managed_identity", at["realm"], "Should have expected realm")
27+
28+
29+
class VmTestCase(ManagedIdentityTestCase):
30+
31+
def test_happy_path(self):
32+
app = ManagedIdentity(requests.Session(), token_cache=TokenCache())
33+
with patch.object(app._http_client, "get", return_value=MinimalResponse(
34+
status_code=200,
35+
text='{"access_token": "AT", "expires_in": "1234", "resource": "R"}',
36+
)) as mocked_method:
37+
result = app.acquire_token("R")
38+
mocked_method.assert_called_once()
39+
self.assertEqual({
40+
"access_token": "AT",
41+
"expires_in": 1234,
42+
"resource": "R",
43+
"token_type": "Bearer",
44+
}, result, "Should obtain a token response")
45+
self.assertEqual(
46+
result["access_token"], app.acquire_token("R").get("access_token"),
47+
"Should hit the same token from cache")
48+
self._test_token_cache(app)
49+
50+
def test_vm_error_should_be_returned_as_is(self):
51+
raw_error = '{"raw": "error format is undefined"}'
52+
app = ManagedIdentity(requests.Session(), token_cache=TokenCache())
53+
with patch.object(app._http_client, "get", return_value=MinimalResponse(
54+
status_code=400,
55+
text=raw_error,
56+
)) as mocked_method:
57+
self.assertEqual(json.loads(raw_error), app.acquire_token("R"))
58+
self.assertEqual({}, app._token_cache._cache)
59+
60+
61+
@patch.dict(os.environ, {"IDENTITY_ENDPOINT": "http://localhost", "IDENTITY_HEADER": "foo"})
62+
class AppServiceTestCase(ManagedIdentityTestCase):
63+
64+
def test_happy_path(self):
65+
# TODO: Combine this with VM's test case, and move it into base class
66+
app = ManagedIdentity(requests.Session(), token_cache=TokenCache())
67+
now = int(time.time())
68+
with patch.object(app._http_client, "get", return_value=MinimalResponse(
69+
status_code=200,
70+
text='{"access_token": "AT", "expires_on": "%s", "resource": "R"}' % (now + 100),
71+
)) as mocked_method:
72+
result = app.acquire_token("R")
73+
mocked_method.assert_called_once()
74+
self.assertEqual({
75+
"access_token": "AT",
76+
"expires_in": 100,
77+
"resource": "R",
78+
"token_type": "Bearer",
79+
}, result, "Should obtain a token response")
80+
self.assertEqual(
81+
result["access_token"], app.acquire_token("R").get("access_token"),
82+
"Should hit the same token from cache")
83+
self._test_token_cache(app)
84+
85+
def test_app_service_error_should_be_normalized(self):
86+
raw_error = '{"statusCode": 500, "message": "error content is undefined"}'
87+
app = ManagedIdentity(requests.Session(), token_cache=TokenCache())
88+
with patch.object(app._http_client, "get", return_value=MinimalResponse(
89+
status_code=500,
90+
text=raw_error,
91+
)) as mocked_method:
92+
self.assertEqual({
93+
"error": "invalid_scope",
94+
"error_description": "500, error content is undefined",
95+
}, app.acquire_token("R"))
96+
self.assertEqual({}, app._token_cache._cache)
97+

0 commit comments

Comments
 (0)