Skip to content

Commit 072436b

Browse files
committed
PoC: Silent flow utilizes Cloud Shell IMDS
1 parent b3aa473 commit 072436b

File tree

2 files changed

+58
-0
lines changed

2 files changed

+58
-0
lines changed

msal/application.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
__version__ = "1.16.0"
3030

3131
logger = logging.getLogger(__name__)
32+
_CLOUD_SHELL_USER = "current_cloud_shell_user"
3233

3334

3435
def extract_certs(public_cert_content):
@@ -954,6 +955,19 @@ def get_accounts(self, username=None):
954955
# Even in the rare case when an RT is revoked and then removed,
955956
# acquire_token_silent() would then yield no result,
956957
# apps would fall back to other acquire methods. This is the standard pattern.
958+
if _is_running_in_cloud_shell():
959+
# In Cloud Shell, user already signed in with an account.
960+
# We pretend we have that account, for acquire_token_silent() to work.
961+
# Note: If user acquire_token_by_xyz() using that account in MSAL later,
962+
# the get_accounts() would return multiple accounts to calling app.
963+
accounts.insert(0, {
964+
"home_account_id": _CLOUD_SHELL_USER,
965+
"environment": "",
966+
"realm": "",
967+
"local_account_id": _CLOUD_SHELL_USER,
968+
"username": "Current Cloud Shell User",
969+
"authority_type": TokenCache.AuthorityType.MSSTS,
970+
})
957971
return accounts
958972

959973
def _find_msal_accounts(self, environment):
@@ -1112,6 +1126,12 @@ def acquire_token_silent_with_error(
11121126
"""
11131127
assert isinstance(scopes, list), "Invalid parameter type"
11141128
self._validate_ssh_cert_input_data(kwargs.get("data", {}))
1129+
if account and account.get("home_account_id") == _CLOUD_SHELL_USER:
1130+
# Since we don't currently store cloud shell tokens in MSAL's cache,
1131+
# we can have a shortcut here, and semantically bypass all those
1132+
# _acquire_token_silent_from_cache_and_possibly_refresh_it()
1133+
return self._acquire_token_by_cloud_shell(
1134+
scopes, data=kwargs.get("data", {}))
11151135
correlation_id = msal.telemetry._get_new_correlation_id()
11161136
if authority:
11171137
warnings.warn("We haven't decided how/if this method will accept authority parameter")
@@ -1204,6 +1224,11 @@ def _acquire_token_silent_from_cache_and_possibly_refresh_it(
12041224
refresh_reason = msal.telemetry.FORCE_REFRESH # TODO: It could also mean claims_challenge
12051225
assert refresh_reason, "It should have been established at this point"
12061226
try:
1227+
## When/if we will store Cloud Shell tokens into MSAL's token cache,
1228+
# then we will add the following code snippet here.
1229+
#if account and account.get("home_account_id") == _CLOUD_SHELL_USER:
1230+
# result = self._acquire_token_by_cloud_shell(scopes, **kwargs)
1231+
#else:
12071232
result = _clean_up(self._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family(
12081233
authority, self._decorate_scope(scopes), account,
12091234
refresh_reason=refresh_reason, claims_challenge=claims_challenge,
@@ -1216,6 +1241,24 @@ def _acquire_token_silent_from_cache_and_possibly_refresh_it(
12161241
raise # We choose to bubble up the exception
12171242
return access_token_from_cache
12181243

1244+
def _acquire_token_by_cloud_shell(self, scopes, **kwargs):
1245+
kwargs.pop("correlation_id", None) # IMDS does not use correlation_id
1246+
resp = self.http_client.post(
1247+
"http://localhost:50342/oauth2/token",
1248+
data=dict(kwargs.pop("data", {}), resource=" ".join(scopes)),
1249+
headers=dict(kwargs.pop("headers", {}), Metadata="true"),
1250+
**kwargs)
1251+
if resp.status_code >= 300:
1252+
logger.debug("Cloud Shell IMDS error: %s", resp.text)
1253+
cs_error = json.loads(resp.text).get("error", {})
1254+
return {k: v for k, v in {
1255+
"error": cs_error.get("code"),
1256+
"error_description": cs_error.get("message"),
1257+
}.items() if v}
1258+
else:
1259+
# Skip token cache, for now. Cloud Shell IMDS has its own cache anyway.
1260+
return json.loads(resp.text)
1261+
12191262
def _acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family(
12201263
self, authority, scopes, account, **kwargs):
12211264
query = {

tests/test_e2e.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,9 @@ def test_ssh_cert_for_service_principal(self):
229229
self.assertEqual("ssh-cert", result["token_type"])
230230

231231
@unittest.skipIf(os.getenv("TRAVIS"), "Browser automation is not yet implemented")
232+
@unittest.skipIf(
233+
msal.application._is_running_in_cloud_shell(),
234+
"The test app does not opt in to Cloud Shell")
232235
def test_ssh_cert_for_user(self):
233236
result = self._test_acquire_token_interactive(
234237
client_id="04b07795-8ddb-461a-bbee-02f9e1bf7b46", # Azure CLI is one
@@ -253,6 +256,18 @@ def test_ssh_cert_for_user(self):
253256
self.assertEqual(refreshed_ssh_cert["token_type"], "ssh-cert")
254257
self.assertNotEqual(result["access_token"], refreshed_ssh_cert['access_token'])
255258

259+
@unittest.skipUnless(
260+
msal.application._is_running_in_cloud_shell(), "Test case for Cloud Shell IMDS")
261+
def test_ssh_cert_for_user_silent(self):
262+
app = msal.PublicClientApplication("client_id_wont_matter")
263+
accounts = app.get_accounts()
264+
self.assertNotEqual([], accounts)
265+
result = app.acquire_token_silent_with_error(
266+
self.SCOPE, account=accounts[0], data=self.DATA1)
267+
self.assertEqual(
268+
"ssh-cert", result.get("token_type"), "Unexpected result: %s" % result)
269+
self.assertIsNotNone(result.get("access_token"))
270+
256271

257272
THIS_FOLDER = os.path.dirname(__file__)
258273
CONFIG = os.path.join(THIS_FOLDER, "config.json")

0 commit comments

Comments
 (0)