|
4 | 4 | import time
|
5 | 5 | import unittest
|
6 | 6 | try:
|
7 |
| - from unittest.mock import patch, ANY, mock_open |
| 7 | + from unittest.mock import patch, ANY, mock_open, Mock |
8 | 8 | except:
|
9 |
| - from mock import patch, ANY, mock_open |
| 9 | + from mock import patch, ANY, mock_open, Mock |
10 | 10 | import requests
|
11 | 11 |
|
12 | 12 | from tests.http_client import MinimalResponse
|
13 | 13 | from msal import (
|
14 | 14 | SystemAssignedManagedIdentity, UserAssignedManagedIdentity,
|
15 | 15 | ManagedIdentityClient,
|
16 | 16 | ManagedIdentityError,
|
| 17 | + ArcPlatformNotSupportedError, |
17 | 18 | )
|
| 19 | +from msal.managed_identity import _supported_arc_platforms_and_their_prefixes |
18 | 20 |
|
19 | 21 |
|
20 | 22 | class ManagedIdentityTestCase(unittest.TestCase):
|
@@ -194,29 +196,41 @@ def test_sf_error_should_be_normalized(self):
|
194 | 196 | new=mock_open(read_data="secret"), # `new` requires no extra argument on the decorated function.
|
195 | 197 | # https://docs.python.org/3/library/unittest.mock.html#unittest.mock.patch
|
196 | 198 | )
|
| 199 | +@patch("os.stat", return_value=Mock(st_size=4096)) |
197 | 200 | class ArcTestCase(ClientTestCase):
|
198 | 201 | challenge = MinimalResponse(status_code=401, text="", headers={
|
199 | 202 | "WWW-Authenticate": "Basic realm=/tmp/foo",
|
200 | 203 | })
|
201 | 204 |
|
202 |
| - def test_happy_path(self): |
| 205 | + def test_happy_path(self, mocked_stat): |
203 | 206 | with patch.object(self.app._http_client, "get", side_effect=[
|
204 | 207 | self.challenge,
|
205 | 208 | MinimalResponse(
|
206 | 209 | status_code=200,
|
207 | 210 | text='{"access_token": "AT", "expires_in": "1234", "resource": "R"}',
|
208 | 211 | ),
|
209 | 212 | ]) as mocked_method:
|
210 |
| - super(ArcTestCase, self)._test_happy_path(self.app, mocked_method) |
211 |
| - |
212 |
| - def test_arc_error_should_be_normalized(self): |
| 213 | + try: |
| 214 | + super(ArcTestCase, self)._test_happy_path(self.app, mocked_method) |
| 215 | + mocked_stat.assert_called_with(os.path.join( |
| 216 | + _supported_arc_platforms_and_their_prefixes[sys.platform], |
| 217 | + "foo.key")) |
| 218 | + except ArcPlatformNotSupportedError: |
| 219 | + if sys.platform in _supported_arc_platforms_and_their_prefixes: |
| 220 | + self.fail("Should not raise ArcPlatformNotSupportedError") |
| 221 | + |
| 222 | + def test_arc_error_should_be_normalized(self, mocked_stat): |
213 | 223 | with patch.object(self.app._http_client, "get", side_effect=[
|
214 | 224 | self.challenge,
|
215 | 225 | MinimalResponse(status_code=400, text="undefined"),
|
216 | 226 | ]) as mocked_method:
|
217 |
| - self.assertEqual({ |
218 |
| - "error": "invalid_request", |
219 |
| - "error_description": "undefined", |
220 |
| - }, self.app.acquire_token_for_client(resource="R")) |
221 |
| - self.assertEqual({}, self.app._token_cache._cache) |
| 227 | + try: |
| 228 | + self.assertEqual({ |
| 229 | + "error": "invalid_request", |
| 230 | + "error_description": "undefined", |
| 231 | + }, self.app.acquire_token_for_client(resource="R")) |
| 232 | + self.assertEqual({}, self.app._token_cache._cache) |
| 233 | + except ArcPlatformNotSupportedError: |
| 234 | + if sys.platform in _supported_arc_platforms_and_their_prefixes: |
| 235 | + self.fail("Should not raise ArcPlatformNotSupportedError") |
222 | 236 |
|
0 commit comments