Skip to content

Commit ffea5ef

Browse files
committed
Abandon get_accounts(username=msal.CURRENT_USER) by acquire_token_interactive(..., prompt="none")
1 parent 5a10b39 commit ffea5ef

File tree

5 files changed

+101
-74
lines changed

5 files changed

+101
-74
lines changed

msal/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030
ClientApplication,
3131
ConfidentialClientApplication,
3232
PublicClientApplication,
33-
CURRENT_USER,
3433
)
3534
from .oauth2cli.oidc import Prompt
3635
from .token_cache import TokenCache, SerializableTokenCache

msal/application.py

Lines changed: 29 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -22,16 +22,13 @@
2222
from .region import _detect_region
2323
from .throttled_http_client import ThrottledHttpClient
2424
from .cloudshell import _is_running_in_cloud_shell
25-
from .cloudshell import _acquire_token as _acquire_token_by_cloud_shell
2625

2726

2827
# The __init__.py will import this. Not the other way around.
2928
__version__ = "1.17.0" # When releasing, also check and bump our dependencies's versions if needed
3029

3130
logger = logging.getLogger(__name__)
32-
CURRENT_USER = "Current User" # The value is subject to change
33-
_CLOUD_SHELL_USER = "current_cloud_shell_user" # The value is subject to change
34-
31+
_AUTHORITY_TYPE_CLOUDSHELL = "CLOUDSHELL"
3532

3633
def extract_certs(public_cert_content):
3734
# Parses raw public certificate file contents and returns a list of strings
@@ -118,8 +115,6 @@ def _preferred_browser():
118115
return None
119116

120117

121-
122-
123118
class _ClientWithCcsRoutingInfo(Client):
124119

125120
def initiate_auth_code_flow(self, **kwargs):
@@ -950,16 +945,6 @@ def get_accounts(self, username=None):
950945
Your app can choose to display those information to end user,
951946
and allow user to choose one of his/her accounts to proceed.
952947
"""
953-
cloud_shell_pseudo_account = {
954-
"home_account_id": _CLOUD_SHELL_USER,
955-
"environment": "",
956-
"realm": "",
957-
"local_account_id": _CLOUD_SHELL_USER,
958-
"username": CURRENT_USER,
959-
"authority_type": TokenCache.AuthorityType.MSSTS,
960-
}
961-
if _is_running_in_cloud_shell() and username == CURRENT_USER:
962-
return [cloud_shell_pseudo_account]
963948
accounts = self._find_msal_accounts(environment=self.authority.instance)
964949
if not accounts: # Now try other aliases of this authority instance
965950
for alias in self._get_authority_aliases(self.authority.instance):
@@ -978,13 +963,6 @@ def get_accounts(self, username=None):
978963
"they would contain no username for filtering. "
979964
"Consider calling get_accounts(username=None) instead."
980965
).format(username))
981-
if _is_running_in_cloud_shell() and not username:
982-
# In Cloud Shell, user already signed in w/ an account [email protected]
983-
# We pretend we have that account, for acquire_token_silent() to work.
984-
# Note: If user calls acquire_token_by_xyz() with same account later,
985-
# the get_accounts(username=None) would return multiple accounts,
986-
# with different usernames: [email protected] and CURRENT_USER.
987-
accounts.insert(0, cloud_shell_pseudo_account)
988966
# Does not further filter by existing RTs here. It probably won't matter.
989967
# Because in most cases Accounts and RTs co-exist.
990968
# Even in the rare case when an RT is revoked and then removed,
@@ -993,6 +971,10 @@ def get_accounts(self, username=None):
993971
return accounts
994972

995973
def _find_msal_accounts(self, environment):
974+
interested_authority_types = [
975+
TokenCache.AuthorityType.ADFS, TokenCache.AuthorityType.MSSTS]
976+
if _is_running_in_cloud_shell():
977+
interested_authority_types.append(_AUTHORITY_TYPE_CLOUDSHELL)
996978
grouped_accounts = {
997979
a.get("home_account_id"): # Grouped by home tenant's id
998980
{ # These are minimal amount of non-tenant-specific account info
@@ -1008,8 +990,7 @@ def _find_msal_accounts(self, environment):
1008990
for a in self.token_cache.find(
1009991
TokenCache.CredentialType.ACCOUNT,
1010992
query={"environment": environment})
1011-
if a["authority_type"] in (
1012-
TokenCache.AuthorityType.ADFS, TokenCache.AuthorityType.MSSTS)
993+
if a["authority_type"] in interested_authority_types
1013994
}
1014995
return list(grouped_accounts.values())
1015996

@@ -1069,6 +1050,21 @@ def _forget_me(self, home_account):
10691050
TokenCache.CredentialType.ACCOUNT, query=owned_by_home_account):
10701051
self.token_cache.remove_account(a)
10711052

1053+
def _acquire_token_by_cloud_shell(self, scopes, data=None):
1054+
from .cloudshell import _acquire_token
1055+
response = _acquire_token(
1056+
self.http_client, scopes, client_id=self.client_id, data=data)
1057+
if "error" not in response:
1058+
self.token_cache.add(dict(
1059+
client_id=self.client_id,
1060+
scope=response["scope"].split() if "scope" in response else scopes,
1061+
token_endpoint=self.authority.token_endpoint,
1062+
response=response.copy(),
1063+
data=data or {},
1064+
authority_type=_AUTHORITY_TYPE_CLOUDSHELL,
1065+
))
1066+
return response
1067+
10721068
def acquire_token_silent(
10731069
self,
10741070
scopes, # type: List[str]
@@ -1148,13 +1144,6 @@ def acquire_token_silent_with_error(
11481144
"""
11491145
assert isinstance(scopes, list), "Invalid parameter type"
11501146
self._validate_ssh_cert_input_data(kwargs.get("data", {}))
1151-
1152-
# The special code path only for _CLOUD_SHELL_USER
1153-
if account and account.get("home_account_id") == _CLOUD_SHELL_USER:
1154-
# Since we don't currently store cloud shell tokens in MSAL's cache,
1155-
# we can have a shortcut here, and semantically bypass all those
1156-
# _acquire_token_silent_from_cache_and_possibly_refresh_it()
1157-
return _acquire_token_by_cloud_shell(self.http_client, scopes, **kwargs)
11581147
correlation_id = msal.telemetry._get_new_correlation_id()
11591148
if authority:
11601149
warnings.warn("We haven't decided how/if this method will accept authority parameter")
@@ -1209,6 +1198,7 @@ def _acquire_token_silent_from_cache_and_possibly_refresh_it(
12091198
authority, # This can be different than self.authority
12101199
force_refresh=False, # type: Optional[boolean]
12111200
claims_challenge=None,
1201+
correlation_id=None,
12121202
**kwargs):
12131203
access_token_from_cache = None
12141204
if not (force_refresh or claims_challenge): # Bypass AT when desired or using claims
@@ -1247,14 +1237,13 @@ def _acquire_token_silent_from_cache_and_possibly_refresh_it(
12471237
refresh_reason = msal.telemetry.FORCE_REFRESH # TODO: It could also mean claims_challenge
12481238
assert refresh_reason, "It should have been established at this point"
12491239
try:
1250-
## When/if we will store Cloud Shell tokens into MSAL's token cache,
1251-
# then we will add the following code snippet here.
1252-
#if account and account.get("home_account_id") == _CLOUD_SHELL_USER:
1253-
# result = _acquire_token_by_cloud_shell(..., scopes, **kwargs)
1254-
#else:
1240+
if account and account.get("authority_type") == _AUTHORITY_TYPE_CLOUDSHELL:
1241+
return self._acquire_token_by_cloud_shell(
1242+
scopes, data=kwargs.get("data"))
12551243
result = _clean_up(self._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family(
12561244
authority, self._decorate_scope(scopes), account,
12571245
refresh_reason=refresh_reason, claims_challenge=claims_challenge,
1246+
correlation_id=correlation_id,
12581247
**kwargs))
12591248
if (result and "error" not in result) or (not access_token_from_cache):
12601249
return result
@@ -1593,6 +1582,9 @@ def acquire_token_interactive(
15931582
- A dict containing an "error" key, when token refresh failed.
15941583
"""
15951584
self._validate_ssh_cert_input_data(kwargs.get("data", {}))
1585+
if _is_running_in_cloud_shell() and prompt == "none":
1586+
return self._acquire_token_by_cloud_shell(
1587+
scopes, data=kwargs.pop("data", {}))
15961588
claims = _merge_claims_challenge_and_capabilities(
15971589
self._client_capabilities, claims_challenge)
15981590
telemetry_context = self._build_telemetry_context(

msal/cloudshell.py

Lines changed: 60 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,16 @@
44
# This code is licensed under the MIT License.
55

66
"""This module wraps Cloud Shell's IMDS-like interface inside an OAuth2-like helper"""
7+
import base64
78
import json
89
import logging
910
import os
11+
import time
1012
try: # Python 2
1113
from urlparse import urlparse
1214
except: # Python 3
1315
from urllib.parse import urlparse
16+
from .oauth2cli.oidc import decode_part
1417

1518

1619
logger = logging.getLogger(__name__)
@@ -34,33 +37,78 @@ def _scope_to_resource(scope): # This is an experimental reasonable-effort appr
3437
return scope # There is no much else we can do here
3538

3639

37-
def _acquire_token(http_client, scopes, **kwargs):
40+
def _acquire_token(http_client, scopes, client_id=None, data=None):
3841
resp = http_client.post(
3942
"http://localhost:50342/oauth2/token",
4043
data=dict(
41-
kwargs.pop("data", {}),
44+
data or {},
4245
resource=" ".join(map(_scope_to_resource, scopes))),
43-
headers=dict(kwargs.pop("headers", {}), Metadata="true"),
44-
**kwargs)
46+
headers={"Metadata": "true"},
47+
)
4548
if resp.status_code >= 300:
4649
logger.debug("Cloud Shell IMDS error: %s", resp.text)
4750
cs_error = json.loads(resp.text).get("error", {})
4851
return {k: v for k, v in {
4952
"error": cs_error.get("code"),
5053
"error_description": cs_error.get("message"),
5154
}.items() if v}
52-
payload = json.loads(resp.text)
55+
imds_payload = json.loads(resp.text)
5356
oauth2_response = {
54-
"access_token": payload["access_token"],
55-
"expires_in": int(payload["expires_in"]),
56-
"token_type": payload.get("token_type", "Bearer"),
57+
"access_token": imds_payload["access_token"],
58+
"expires_in": int(imds_payload["expires_in"]),
59+
"token_type": imds_payload.get("token_type", "Bearer"),
5760
}
61+
parts = imds_payload["access_token"].split(".")
62+
63+
# The following default values are useful in SSH Cert scenario
64+
client_info = { # Default value, in case the real value will be unavailable
65+
"uid": "user",
66+
"utid": "cloudshell",
67+
}
68+
now = time.time()
69+
preferred_username = "currentuser@cloudshell"
70+
oauth2_response["id_token_claims"] = { # First 5 claims are required per OIDC
71+
"iss": "cloudshell",
72+
"sub": "user",
73+
"aud": client_id,
74+
"exp": now + 3600,
75+
"iat": now,
76+
"preferred_username": preferred_username, # Useful as MSAL account's username
77+
}
78+
79+
if len(parts) == 3: # Probably a JWT. Use it to derive client_info and id token.
80+
try:
81+
# Data defined in https://docs.microsoft.com/en-us/azure/active-directory/develop/access-tokens#payload-claims
82+
jwt_payload = json.loads(decode_part(parts[1]))
83+
client_info = {
84+
# Mimic a real home_account_id,
85+
# so that this pseudo account and a real account would interop.
86+
"uid": jwt_payload.get("oid", "user"),
87+
"utid": jwt_payload.get("tid", "cloudshell"),
88+
}
89+
oauth2_response["id_token_claims"] = {
90+
"iss": jwt_payload["iss"],
91+
"sub": jwt_payload["sub"], # Could use oid instead
92+
"aud": client_id,
93+
"exp": jwt_payload["exp"],
94+
"iat": jwt_payload["iat"],
95+
"preferred_username": jwt_payload.get("preferred_username") # V2
96+
or jwt_payload.get("unique_name") # V1
97+
or preferred_username,
98+
}
99+
except ValueError:
100+
logger.debug("Unable to decode jwt payload: %s", parts[1])
101+
oauth2_response["client_info"] = base64.b64encode(
102+
# Mimic a client_info, so that MSAL would create an account
103+
json.dumps(client_info).encode("utf-8")).decode("utf-8")
104+
oauth2_response["id_token_claims"]["tid"] = client_info["utid"] # TBD
105+
58106
## Note: Decided to not surface resource back as scope,
59107
## because they would cause the downstream OAuth2 code path to
60108
## cache the token with a different scope and won't hit them later.
61-
#if payload.get("resource"):
62-
# oauth2_response["scope"] = payload["resource"]
63-
if payload.get("refresh_token"):
64-
oauth2_response["refresh_token"] = payload["refresh_token"]
109+
#if imds_payload.get("resource"):
110+
# oauth2_response["scope"] = imds_payload["resource"]
111+
if imds_payload.get("refresh_token"):
112+
oauth2_response["refresh_token"] = imds_payload["refresh_token"]
65113
return oauth2_response
66114

msal/token_cache.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@ def wipe(dictionary, sensitive_fields): # Masks sensitive info
113113
return self.__add(event, now=now)
114114
finally:
115115
wipe(event.get("response", {}), ( # These claims were useful during __add()
116+
"id_token_claims", # Provided by broker
116117
"access_token", "refresh_token", "id_token", "username"))
117118
wipe(event, ["username"]) # Needed for federated ROPC
118119
logger.debug("event=%s", json.dumps(
@@ -150,7 +151,8 @@ def __add(self, event, now=None):
150151
id_token = response.get("id_token")
151152
id_token_claims = (
152153
decode_id_token(id_token, client_id=event["client_id"])
153-
if id_token else {})
154+
if id_token
155+
else response.get("id_token_claims", {})) # Broker would provide id_token_claims
154156
client_info, home_account_id = self.__parse_account(response, id_token_claims)
155157

156158
target = ' '.join(event.get("scope") or []) # Per schema, we don't sort it
@@ -195,9 +197,10 @@ def __add(self, event, now=None):
195197
or data.get("username") # Falls back to ROPC username
196198
or event.get("username") # Falls back to Federated ROPC username
197199
or "", # The schema does not like null
198-
"authority_type":
200+
"authority_type": event.get(
201+
"authority_type", # Honor caller's choice of authority_type
199202
self.AuthorityType.ADFS if realm == "adfs"
200-
else self.AuthorityType.MSSTS,
203+
else self.AuthorityType.MSSTS),
201204
# "client_info": response.get("client_info"), # Optional
202205
}
203206
self.modify(self.CredentialType.ACCOUNT, account, account)

tests/test_e2e.py

Lines changed: 6 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -185,12 +185,14 @@ def _test_acquire_token_interactive(
185185
self, client_id=None, authority=None, scope=None, port=None,
186186
username_uri="", # But you would want to provide one
187187
data=None, # Needed by ssh-cert feature
188+
prompt=None,
188189
**ignored):
189190
assert client_id and authority and scope
190191
self.app = msal.PublicClientApplication(
191192
client_id, authority=authority, http_client=MinimalHttpClient())
192193
result = self.app.acquire_token_interactive(
193194
scope,
195+
prompt=prompt,
194196
timeout=120,
195197
port=port,
196198
welcome_template= # This is an undocumented feature for testing
@@ -229,9 +231,6 @@ def test_ssh_cert_for_service_principal(self):
229231
self.assertEqual("ssh-cert", result["token_type"])
230232

231233
@unittest.skipIf(os.getenv("TRAVIS"), "Browser automation is not yet implemented")
232-
@unittest.skipIf(
233-
msal.application._is_running_in_cloud_shell(),
234-
"The test app does not opt in to Cloud Shell")
235234
def test_ssh_cert_for_user(self):
236235
result = self._test_acquire_token_interactive(
237236
client_id="04b07795-8ddb-461a-bbee-02f9e1bf7b46", # Azure CLI is one
@@ -240,6 +239,7 @@ def test_ssh_cert_for_user(self):
240239
scope=self.SCOPE,
241240
data=self.DATA1,
242241
username_uri="https://msidlab.com/api/user?usertype=cloud",
242+
prompt="none" if msal.application._is_running_in_cloud_shell() else None,
243243
) # It already tests reading AT from cache, and using RT to refresh
244244
# acquire_token_silent() would work because we pass in the same key
245245
self.assertIsNotNone(result.get("access_token"), "Encountered {}: {}".format(
@@ -256,31 +256,16 @@ def test_ssh_cert_for_user(self):
256256
self.assertEqual(refreshed_ssh_cert["token_type"], "ssh-cert")
257257
self.assertNotEqual(result["access_token"], refreshed_ssh_cert['access_token'])
258258

259-
@unittest.skipUnless(
260-
msal.application._is_running_in_cloud_shell(),
261-
"Manually run this test case from inside Cloud Shell")
262-
def test_ssh_cert_for_user_silent_inside_cloud_shell(self):
263-
app = msal.PublicClientApplication("client_id_wont_matter")
264-
accounts = app.get_accounts()
265-
self.assertNotEqual([], accounts)
266-
result = app.acquire_token_silent_with_error(
267-
self.SCOPE, account=accounts[0], data=self.DATA1)
268-
self.assertEqual(
269-
"ssh-cert", result.get("token_type"), "Unexpected result: %s" % result)
270-
self.assertIsNotNone(result.get("access_token"))
271-
272259

273260
@unittest.skipUnless(
274261
msal.application._is_running_in_cloud_shell(),
275262
"Manually run this test case from inside Cloud Shell")
276263
class CloudShellTestCase(E2eTestCase):
277-
app = msal.PublicClientApplication("client_id_wont_matter")
264+
app = msal.PublicClientApplication("client_id")
278265
scope_that_requires_no_managed_device = "https://management.core.windows.net/" # Scopes came from https://msazure.visualstudio.com/One/_git/compute-CloudShell?path=/src/images/agent/env/envconfig.PROD.json&version=GBmaster&_a=contents
279266
def test_access_token_should_be_obtained_for_a_supported_scope(self):
280-
accounts = self.app.get_accounts(username=msal.CURRENT_USER)
281-
self.assertNotEqual([], accounts)
282-
result = self.app.acquire_token_silent_with_error(
283-
[self.scope_that_requires_no_managed_device], account=accounts[0])
267+
result = self.app.acquire_token_interactive(
268+
[self.scope_that_requires_no_managed_device], prompt="none")
284269
self.assertEqual(
285270
"Bearer", result.get("token_type"), "Unexpected result: %s" % result)
286271
self.assertIsNotNone(result.get("access_token"))

0 commit comments

Comments
 (0)