Skip to content

Commit f8f6d69

Browse files
committed
PoC: Silent flow utilizes Cloud Shell IMDS
1 parent 84c5053 commit f8f6d69

File tree

2 files changed

+59
-0
lines changed

2 files changed

+59
-0
lines changed

msal/application.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
__version__ = "1.16.0"
3030

3131
logger = logging.getLogger(__name__)
32+
_CLOUD_SHELL_USER = "current_cloud_shell_user"
3233

3334

3435
def extract_certs(public_cert_content):
@@ -967,6 +968,19 @@ def get_accounts(self, username=None):
967968
# Even in the rare case when an RT is revoked and then removed,
968969
# acquire_token_silent() would then yield no result,
969970
# apps would fall back to other acquire methods. This is the standard pattern.
971+
if _is_running_in_cloud_shell():
972+
# In Cloud Shell, user already signed in with an account.
973+
# We pretend we have that account, for acquire_token_silent() to work.
974+
# Note: If user acquire_token_by_xyz() using that account in MSAL later,
975+
# the get_accounts() would return multiple accounts to calling app.
976+
accounts.insert(0, {
977+
"home_account_id": _CLOUD_SHELL_USER,
978+
"environment": "",
979+
"realm": "",
980+
"local_account_id": _CLOUD_SHELL_USER,
981+
"username": "Current Cloud Shell User",
982+
"authority_type": TokenCache.AuthorityType.MSSTS,
983+
})
970984
return accounts
971985

972986
def _find_msal_accounts(self, environment):
@@ -1125,6 +1139,12 @@ def acquire_token_silent_with_error(
11251139
"""
11261140
assert isinstance(scopes, list), "Invalid parameter type"
11271141
self._validate_ssh_cert_input_data(kwargs.get("data", {}))
1142+
if account and account.get("home_account_id") == _CLOUD_SHELL_USER:
1143+
# Since we don't currently store cloud shell tokens in MSAL's cache,
1144+
# we can have a shortcut here, and semantically bypass all those
1145+
# _acquire_token_silent_from_cache_and_possibly_refresh_it()
1146+
return self._acquire_token_by_cloud_shell(
1147+
scopes, data=kwargs.get("data", {}))
11281148
correlation_id = msal.telemetry._get_new_correlation_id()
11291149
if authority:
11301150
warnings.warn("We haven't decided how/if this method will accept authority parameter")
@@ -1217,6 +1237,11 @@ def _acquire_token_silent_from_cache_and_possibly_refresh_it(
12171237
refresh_reason = msal.telemetry.FORCE_REFRESH # TODO: It could also mean claims_challenge
12181238
assert refresh_reason, "It should have been established at this point"
12191239
try:
1240+
## When/if we will store Cloud Shell tokens into MSAL's token cache,
1241+
# then we will add the following code snippet here.
1242+
#if account and account.get("home_account_id") == _CLOUD_SHELL_USER:
1243+
# result = self._acquire_token_by_cloud_shell(scopes, **kwargs)
1244+
#else:
12201245
result = _clean_up(self._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family(
12211246
authority, self._decorate_scope(scopes), account,
12221247
refresh_reason=refresh_reason, claims_challenge=claims_challenge,
@@ -1229,6 +1254,24 @@ def _acquire_token_silent_from_cache_and_possibly_refresh_it(
12291254
raise # We choose to bubble up the exception
12301255
return access_token_from_cache
12311256

1257+
def _acquire_token_by_cloud_shell(self, scopes, **kwargs):
1258+
kwargs.pop("correlation_id", None) # IMDS does not use correlation_id
1259+
resp = self.http_client.post(
1260+
"http://localhost:50342/oauth2/token",
1261+
data=dict(kwargs.pop("data", {}), resource=" ".join(scopes)),
1262+
headers=dict(kwargs.pop("headers", {}), Metadata="true"),
1263+
**kwargs)
1264+
if resp.status_code >= 300:
1265+
logger.debug("Cloud Shell IMDS error: %s", resp.text)
1266+
cs_error = json.loads(resp.text).get("error", {})
1267+
return {k: v for k, v in {
1268+
"error": cs_error.get("code"),
1269+
"error_description": cs_error.get("message"),
1270+
}.items() if v}
1271+
else:
1272+
# Skip token cache, for now. Cloud Shell IMDS has its own cache anyway.
1273+
return json.loads(resp.text)
1274+
12321275
def _acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family(
12331276
self, authority, scopes, account, **kwargs):
12341277
query = {

tests/test_e2e.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,9 @@ def test_ssh_cert_for_service_principal(self):
229229
self.assertEqual("ssh-cert", result["token_type"])
230230

231231
@unittest.skipIf(os.getenv("TRAVIS"), "Browser automation is not yet implemented")
232+
@unittest.skipIf(
233+
msal.application._is_running_in_cloud_shell(),
234+
"The test app does not opt in to Cloud Shell")
232235
def test_ssh_cert_for_user(self):
233236
result = self._test_acquire_token_interactive(
234237
client_id="04b07795-8ddb-461a-bbee-02f9e1bf7b46", # Azure CLI is one
@@ -253,6 +256,19 @@ def test_ssh_cert_for_user(self):
253256
self.assertEqual(refreshed_ssh_cert["token_type"], "ssh-cert")
254257
self.assertNotEqual(result["access_token"], refreshed_ssh_cert['access_token'])
255258

259+
@unittest.skipUnless(
260+
msal.application._is_running_in_cloud_shell(),
261+
"Manually run this by python -m unittest tests.test_e2e.SshCertTestCase")
262+
def test_ssh_cert_for_user_silent_inside_cloud_shell(self):
263+
app = msal.PublicClientApplication("client_id_wont_matter")
264+
accounts = app.get_accounts()
265+
self.assertNotEqual([], accounts)
266+
result = app.acquire_token_silent_with_error(
267+
self.SCOPE, account=accounts[0], data=self.DATA1)
268+
self.assertEqual(
269+
"ssh-cert", result.get("token_type"), "Unexpected result: %s" % result)
270+
self.assertIsNotNone(result.get("access_token"))
271+
256272

257273
THIS_FOLDER = os.path.dirname(__file__)
258274
CONFIG = os.path.join(THIS_FOLDER, "config.json")

0 commit comments

Comments
 (0)