Skip to content

Commit 1b26417

Browse files
authored
Merge pull request #420 from AzureAD/cloudshell-imds
acquire_token_interactive(..., prompt="none") acquires token via Cloud Shell's IMDS-like interface
2 parents c7e81ba + 292e28b commit 1b26417

File tree

4 files changed

+175
-6
lines changed

4 files changed

+175
-6
lines changed

msal/application.py

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,14 @@
2121
import msal.telemetry
2222
from .region import _detect_region
2323
from .throttled_http_client import ThrottledHttpClient
24+
from .cloudshell import _is_running_in_cloud_shell
2425

2526

2627
# The __init__.py will import this. Not the other way around.
2728
__version__ = "1.17.0" # When releasing, also check and bump our dependencies's versions if needed
2829

2930
logger = logging.getLogger(__name__)
30-
31+
_AUTHORITY_TYPE_CLOUDSHELL = "CLOUDSHELL"
3132

3233
def extract_certs(public_cert_content):
3334
# Parses raw public certificate file contents and returns a list of strings
@@ -986,6 +987,10 @@ def get_accounts(self, username=None):
986987
return accounts
987988

988989
def _find_msal_accounts(self, environment):
990+
interested_authority_types = [
991+
TokenCache.AuthorityType.ADFS, TokenCache.AuthorityType.MSSTS]
992+
if _is_running_in_cloud_shell():
993+
interested_authority_types.append(_AUTHORITY_TYPE_CLOUDSHELL)
989994
grouped_accounts = {
990995
a.get("home_account_id"): # Grouped by home tenant's id
991996
{ # These are minimal amount of non-tenant-specific account info
@@ -1001,8 +1006,7 @@ def _find_msal_accounts(self, environment):
10011006
for a in self.token_cache.find(
10021007
TokenCache.CredentialType.ACCOUNT,
10031008
query={"environment": environment})
1004-
if a["authority_type"] in (
1005-
TokenCache.AuthorityType.ADFS, TokenCache.AuthorityType.MSSTS)
1009+
if a["authority_type"] in interested_authority_types
10061010
}
10071011
return list(grouped_accounts.values())
10081012

@@ -1062,6 +1066,21 @@ def _forget_me(self, home_account):
10621066
TokenCache.CredentialType.ACCOUNT, query=owned_by_home_account):
10631067
self.token_cache.remove_account(a)
10641068

1069+
def _acquire_token_by_cloud_shell(self, scopes, data=None):
1070+
from .cloudshell import _obtain_token
1071+
response = _obtain_token(
1072+
self.http_client, scopes, client_id=self.client_id, data=data)
1073+
if "error" not in response:
1074+
self.token_cache.add(dict(
1075+
client_id=self.client_id,
1076+
scope=response["scope"].split() if "scope" in response else scopes,
1077+
token_endpoint=self.authority.token_endpoint,
1078+
response=response.copy(),
1079+
data=data or {},
1080+
authority_type=_AUTHORITY_TYPE_CLOUDSHELL,
1081+
))
1082+
return response
1083+
10651084
def acquire_token_silent(
10661085
self,
10671086
scopes, # type: List[str]
@@ -1195,6 +1214,7 @@ def _acquire_token_silent_from_cache_and_possibly_refresh_it(
11951214
authority, # This can be different than self.authority
11961215
force_refresh=False, # type: Optional[boolean]
11971216
claims_challenge=None,
1217+
correlation_id=None,
11981218
**kwargs):
11991219
access_token_from_cache = None
12001220
if not (force_refresh or claims_challenge): # Bypass AT when desired or using claims
@@ -1233,9 +1253,13 @@ def _acquire_token_silent_from_cache_and_possibly_refresh_it(
12331253
refresh_reason = msal.telemetry.FORCE_REFRESH # TODO: It could also mean claims_challenge
12341254
assert refresh_reason, "It should have been established at this point"
12351255
try:
1256+
if account and account.get("authority_type") == _AUTHORITY_TYPE_CLOUDSHELL:
1257+
return self._acquire_token_by_cloud_shell(
1258+
scopes, data=kwargs.get("data"))
12361259
result = _clean_up(self._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family(
12371260
authority, self._decorate_scope(scopes), account,
12381261
refresh_reason=refresh_reason, claims_challenge=claims_challenge,
1262+
correlation_id=correlation_id,
12391263
**kwargs))
12401264
if (result and "error" not in result) or (not access_token_from_cache):
12411265
return result
@@ -1574,6 +1598,9 @@ def acquire_token_interactive(
15741598
- A dict containing an "error" key, when token refresh failed.
15751599
"""
15761600
self._validate_ssh_cert_input_data(kwargs.get("data", {}))
1601+
if _is_running_in_cloud_shell() and prompt == "none":
1602+
return self._acquire_token_by_cloud_shell(
1603+
scopes, data=kwargs.pop("data", {}))
15771604
claims = _merge_claims_challenge_and_capabilities(
15781605
self._client_capabilities, claims_challenge)
15791606
telemetry_context = self._build_telemetry_context(

msal/cloudshell.py

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# All rights reserved.
3+
#
4+
# This code is licensed under the MIT License.
5+
6+
"""This module wraps Cloud Shell's IMDS-like interface inside an OAuth2-like helper"""
7+
import base64
8+
import json
9+
import logging
10+
import os
11+
import time
12+
try: # Python 2
13+
from urlparse import urlparse
14+
except: # Python 3
15+
from urllib.parse import urlparse
16+
from .oauth2cli.oidc import decode_part
17+
18+
19+
logger = logging.getLogger(__name__)
20+
21+
22+
def _is_running_in_cloud_shell():
23+
return os.environ.get("AZUREPS_HOST_ENVIRONMENT", "").startswith("cloud-shell")
24+
25+
26+
def _scope_to_resource(scope): # This is an experimental reasonable-effort approach
27+
cloud_shell_supported_audiences = [
28+
"https://analysis.windows.net/powerbi/api", # Came from https://msazure.visualstudio.com/One/_git/compute-CloudShell?path=/src/images/agent/env/envconfig.PROD.json
29+
"https://pas.windows.net/CheckMyAccess/Linux/.default", # Cloud Shell accepts it as-is
30+
]
31+
for a in cloud_shell_supported_audiences:
32+
if scope.startswith(a):
33+
return a
34+
u = urlparse(scope)
35+
if u.scheme:
36+
return "{}://{}".format(u.scheme, u.netloc)
37+
return scope # There is no much else we can do here
38+
39+
40+
def _obtain_token(http_client, scopes, client_id=None, data=None):
41+
resp = http_client.post(
42+
"http://localhost:50342/oauth2/token",
43+
data=dict(
44+
data or {},
45+
resource=" ".join(map(_scope_to_resource, scopes))),
46+
headers={"Metadata": "true"},
47+
)
48+
if resp.status_code >= 300:
49+
logger.debug("Cloud Shell IMDS error: %s", resp.text)
50+
cs_error = json.loads(resp.text).get("error", {})
51+
return {k: v for k, v in {
52+
"error": cs_error.get("code"),
53+
"error_description": cs_error.get("message"),
54+
}.items() if v}
55+
imds_payload = json.loads(resp.text)
56+
BEARER = "Bearer"
57+
oauth2_response = {
58+
"access_token": imds_payload["access_token"],
59+
"expires_in": int(imds_payload["expires_in"]),
60+
"token_type": imds_payload.get("token_type", BEARER),
61+
}
62+
expected_token_type = (data or {}).get("token_type", BEARER)
63+
if oauth2_response["token_type"] != expected_token_type:
64+
return { # Generate a normal error (rather than an intrusive exception)
65+
"error": "broker_error",
66+
"error_description": "token_type {} is not supported by this version of Azure Portal".format(
67+
expected_token_type),
68+
}
69+
parts = imds_payload["access_token"].split(".")
70+
71+
# The following default values are useful in SSH Cert scenario
72+
client_info = { # Default value, in case the real value will be unavailable
73+
"uid": "user",
74+
"utid": "cloudshell",
75+
}
76+
now = time.time()
77+
preferred_username = "currentuser@cloudshell"
78+
oauth2_response["id_token_claims"] = { # First 5 claims are required per OIDC
79+
"iss": "cloudshell",
80+
"sub": "user",
81+
"aud": client_id,
82+
"exp": now + 3600,
83+
"iat": now,
84+
"preferred_username": preferred_username, # Useful as MSAL account's username
85+
}
86+
87+
if len(parts) == 3: # Probably a JWT. Use it to derive client_info and id token.
88+
try:
89+
# Data defined in https://docs.microsoft.com/en-us/azure/active-directory/develop/access-tokens#payload-claims
90+
jwt_payload = json.loads(decode_part(parts[1]))
91+
client_info = {
92+
# Mimic a real home_account_id,
93+
# so that this pseudo account and a real account would interop.
94+
"uid": jwt_payload.get("oid", "user"),
95+
"utid": jwt_payload.get("tid", "cloudshell"),
96+
}
97+
oauth2_response["id_token_claims"] = {
98+
"iss": jwt_payload["iss"],
99+
"sub": jwt_payload["sub"], # Could use oid instead
100+
"aud": client_id,
101+
"exp": jwt_payload["exp"],
102+
"iat": jwt_payload["iat"],
103+
"preferred_username": jwt_payload.get("preferred_username") # V2
104+
or jwt_payload.get("unique_name") # V1
105+
or preferred_username,
106+
}
107+
except ValueError:
108+
logger.debug("Unable to decode jwt payload: %s", parts[1])
109+
oauth2_response["client_info"] = base64.b64encode(
110+
# Mimic a client_info, so that MSAL would create an account
111+
json.dumps(client_info).encode("utf-8")).decode("utf-8")
112+
oauth2_response["id_token_claims"]["tid"] = client_info["utid"] # TBD
113+
114+
## Note: Decided to not surface resource back as scope,
115+
## because they would cause the downstream OAuth2 code path to
116+
## cache the token with a different scope and won't hit them later.
117+
#if imds_payload.get("resource"):
118+
# oauth2_response["scope"] = imds_payload["resource"]
119+
if imds_payload.get("refresh_token"):
120+
oauth2_response["refresh_token"] = imds_payload["refresh_token"]
121+
return oauth2_response
122+

msal/token_cache.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@ def wipe(dictionary, sensitive_fields): # Masks sensitive info
113113
return self.__add(event, now=now)
114114
finally:
115115
wipe(event.get("response", {}), ( # These claims were useful during __add()
116+
"id_token_claims", # Provided by broker
116117
"access_token", "refresh_token", "id_token", "username"))
117118
wipe(event, ["username"]) # Needed for federated ROPC
118119
logger.debug("event=%s", json.dumps(
@@ -150,7 +151,8 @@ def __add(self, event, now=None):
150151
id_token = response.get("id_token")
151152
id_token_claims = (
152153
decode_id_token(id_token, client_id=event["client_id"])
153-
if id_token else {})
154+
if id_token
155+
else response.get("id_token_claims", {})) # Broker would provide id_token_claims
154156
client_info, home_account_id = self.__parse_account(response, id_token_claims)
155157

156158
target = ' '.join(event.get("scope") or []) # Per schema, we don't sort it
@@ -195,9 +197,10 @@ def __add(self, event, now=None):
195197
or data.get("username") # Falls back to ROPC username
196198
or event.get("username") # Falls back to Federated ROPC username
197199
or "", # The schema does not like null
198-
"authority_type":
200+
"authority_type": event.get(
201+
"authority_type", # Honor caller's choice of authority_type
199202
self.AuthorityType.ADFS if realm == "adfs"
200-
else self.AuthorityType.MSSTS,
203+
else self.AuthorityType.MSSTS),
201204
# "client_info": response.get("client_info"), # Optional
202205
}
203206
self.modify(self.CredentialType.ACCOUNT, account, account)

tests/test_e2e.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,12 +185,14 @@ def _test_acquire_token_interactive(
185185
self, client_id=None, authority=None, scope=None, port=None,
186186
username_uri="", # But you would want to provide one
187187
data=None, # Needed by ssh-cert feature
188+
prompt=None,
188189
**ignored):
189190
assert client_id and authority and scope
190191
self.app = msal.PublicClientApplication(
191192
client_id, authority=authority, http_client=MinimalHttpClient())
192193
result = self.app.acquire_token_interactive(
193194
scope,
195+
prompt=prompt,
194196
timeout=120,
195197
port=port,
196198
welcome_template= # This is an undocumented feature for testing
@@ -237,6 +239,7 @@ def test_ssh_cert_for_user(self):
237239
scope=self.SCOPE,
238240
data=self.DATA1,
239241
username_uri="https://msidlab.com/api/user?usertype=cloud",
242+
prompt="none" if msal.application._is_running_in_cloud_shell() else None,
240243
) # It already tests reading AT from cache, and using RT to refresh
241244
# acquire_token_silent() would work because we pass in the same key
242245
self.assertIsNotNone(result.get("access_token"), "Encountered {}: {}".format(
@@ -254,6 +257,20 @@ def test_ssh_cert_for_user(self):
254257
self.assertNotEqual(result["access_token"], refreshed_ssh_cert['access_token'])
255258

256259

260+
@unittest.skipUnless(
261+
msal.application._is_running_in_cloud_shell(),
262+
"Manually run this test case from inside Cloud Shell")
263+
class CloudShellTestCase(E2eTestCase):
264+
app = msal.PublicClientApplication("client_id")
265+
scope_that_requires_no_managed_device = "https://management.core.windows.net/" # Scopes came from https://msazure.visualstudio.com/One/_git/compute-CloudShell?path=/src/images/agent/env/envconfig.PROD.json&version=GBmaster&_a=contents
266+
def test_access_token_should_be_obtained_for_a_supported_scope(self):
267+
result = self.app.acquire_token_interactive(
268+
[self.scope_that_requires_no_managed_device], prompt="none")
269+
self.assertEqual(
270+
"Bearer", result.get("token_type"), "Unexpected result: %s" % result)
271+
self.assertIsNotNone(result.get("access_token"))
272+
273+
257274
THIS_FOLDER = os.path.dirname(__file__)
258275
CONFIG = os.path.join(THIS_FOLDER, "config.json")
259276
@unittest.skipUnless(os.path.exists(CONFIG), "Optional %s not found" % CONFIG)

0 commit comments

Comments
 (0)