Skip to content

Commit 86d5976

Browse files
committed
Introduce get_accounts(username=msal.CURRENT_USER)
1 parent 2b12c6a commit 86d5976

File tree

4 files changed

+104
-40
lines changed

4 files changed

+104
-40
lines changed

msal/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
ClientApplication,
3131
ConfidentialClientApplication,
3232
PublicClientApplication,
33+
CURRENT_USER,
3334
)
3435
from .oauth2cli.oidc import Prompt
3536
from .token_cache import TokenCache, SerializableTokenCache

msal/application.py

Lines changed: 23 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,16 @@
2323
import msal.telemetry
2424
from .region import _detect_region
2525
from .throttled_http_client import ThrottledHttpClient
26+
from .cloudshell import _is_running_in_cloud_shell
27+
from .cloudshell import _acquire_token as _acquire_token_by_cloud_shell
2628

2729

2830
# The __init__.py will import this. Not the other way around.
2931
__version__ = "1.16.0"
3032

3133
logger = logging.getLogger(__name__)
32-
_CLOUD_SHELL_USER = "current_cloud_shell_user"
34+
CURRENT_USER = "Current User" # The value is subject to change
35+
_CLOUD_SHELL_USER = "current_cloud_shell_user" # The value is subject to change
3336

3437

3538
def extract_certs(public_cert_content):
@@ -113,8 +116,6 @@ def _preferred_browser():
113116
return None
114117

115118

116-
def _is_running_in_cloud_shell():
117-
return os.environ.get("AZUREPS_HOST_ENVIRONMENT", "").startswith("cloud-shell")
118119

119120

120121
class _ClientWithCcsRoutingInfo(Client):
@@ -945,6 +946,16 @@ def get_accounts(self, username=None):
945946
Your app can choose to display those information to end user,
946947
and allow user to choose one of his/her accounts to proceed.
947948
"""
949+
cloud_shell_pseudo_account = {
950+
"home_account_id": _CLOUD_SHELL_USER,
951+
"environment": "",
952+
"realm": "",
953+
"local_account_id": _CLOUD_SHELL_USER,
954+
"username": CURRENT_USER,
955+
"authority_type": TokenCache.AuthorityType.MSSTS,
956+
}
957+
if _is_running_in_cloud_shell() and username == CURRENT_USER:
958+
return [cloud_shell_pseudo_account]
948959
accounts = self._find_msal_accounts(environment=self.authority.instance)
949960
if not accounts: # Now try other aliases of this authority instance
950961
for alias in self._get_authority_aliases(self.authority.instance):
@@ -963,24 +974,18 @@ def get_accounts(self, username=None):
963974
"they would contain no username for filtering. "
964975
"Consider calling get_accounts(username=None) instead."
965976
).format(username))
977+
if _is_running_in_cloud_shell() and not username:
978+
# In Cloud Shell, user already signed in w/ an account [email protected]
979+
# We pretend we have that account, for acquire_token_silent() to work.
980+
# Note: If user calls acquire_token_by_xyz() with same account later,
981+
# the get_accounts() would return multiple accounts to calling app,
982+
# with different usernames: [email protected] and CURRENT_USER.
983+
accounts.insert(0, cloud_shell_pseudo_account)
966984
# Does not further filter by existing RTs here. It probably won't matter.
967985
# Because in most cases Accounts and RTs co-exist.
968986
# Even in the rare case when an RT is revoked and then removed,
969987
# acquire_token_silent() would then yield no result,
970988
# apps would fall back to other acquire methods. This is the standard pattern.
971-
if _is_running_in_cloud_shell():
972-
# In Cloud Shell, user already signed in with an account.
973-
# We pretend we have that account, for acquire_token_silent() to work.
974-
# Note: If user acquire_token_by_xyz() using that account in MSAL later,
975-
# the get_accounts() would return multiple accounts to calling app.
976-
accounts.insert(0, {
977-
"home_account_id": _CLOUD_SHELL_USER,
978-
"environment": "",
979-
"realm": "",
980-
"local_account_id": _CLOUD_SHELL_USER,
981-
"username": "Current Cloud Shell User",
982-
"authority_type": TokenCache.AuthorityType.MSSTS,
983-
})
984989
return accounts
985990

986991
def _find_msal_accounts(self, environment):
@@ -1158,8 +1163,7 @@ def acquire_token_silent_with_error(
11581163
# Since we don't currently store cloud shell tokens in MSAL's cache,
11591164
# we can have a shortcut here, and semantically bypass all those
11601165
# _acquire_token_silent_from_cache_and_possibly_refresh_it()
1161-
return self._acquire_token_by_cloud_shell(
1162-
scopes, data=kwargs.get("data", {}))
1166+
return _acquire_token_by_cloud_shell(self.http_client, scopes, **kwargs)
11631167
correlation_id = msal.telemetry._get_new_correlation_id()
11641168
if authority:
11651169
warnings.warn("We haven't decided how/if this method will accept authority parameter")
@@ -1255,7 +1259,7 @@ def _acquire_token_silent_from_cache_and_possibly_refresh_it(
12551259
## When/if we will store Cloud Shell tokens into MSAL's token cache,
12561260
# then we will add the following code snippet here.
12571261
#if account and account.get("home_account_id") == _CLOUD_SHELL_USER:
1258-
# result = self._acquire_token_by_cloud_shell(scopes, **kwargs)
1262+
# result = _acquire_token_by_cloud_shell(..., scopes, **kwargs)
12591263
#else:
12601264
result = _clean_up(self._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family(
12611265
authority, self._decorate_scope(scopes), account,
@@ -1269,24 +1273,6 @@ def _acquire_token_silent_from_cache_and_possibly_refresh_it(
12691273
raise # We choose to bubble up the exception
12701274
return access_token_from_cache
12711275

1272-
def _acquire_token_by_cloud_shell(self, scopes, **kwargs):
1273-
kwargs.pop("correlation_id", None) # IMDS does not use correlation_id
1274-
resp = self.http_client.post(
1275-
"http://localhost:50342/oauth2/token",
1276-
data=dict(kwargs.pop("data", {}), resource=" ".join(scopes)),
1277-
headers=dict(kwargs.pop("headers", {}), Metadata="true"),
1278-
**kwargs)
1279-
if resp.status_code >= 300:
1280-
logger.debug("Cloud Shell IMDS error: %s", resp.text)
1281-
cs_error = json.loads(resp.text).get("error", {})
1282-
return {k: v for k, v in {
1283-
"error": cs_error.get("code"),
1284-
"error_description": cs_error.get("message"),
1285-
}.items() if v}
1286-
else:
1287-
# Skip token cache, for now. Cloud Shell IMDS has its own cache anyway.
1288-
return json.loads(resp.text)
1289-
12901276
def _acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family(
12911277
self, authority, scopes, account, **kwargs):
12921278
query = {

msal/cloudshell.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# All rights reserved.
3+
#
4+
# This code is licensed under the MIT License.
5+
6+
"""This module wraps Cloud Shell's IMDS-like interface inside an OAuth2-like helper"""
7+
import json
8+
import logging
9+
import os
10+
11+
12+
logger = logging.getLogger(__name__)
13+
14+
15+
def _is_running_in_cloud_shell():
16+
return os.environ.get("AZUREPS_HOST_ENVIRONMENT", "").startswith("cloud-shell")
17+
18+
19+
def _scope_to_resource(scope):
20+
cloud_shell_supported_audiences = [ # Came from https://msazure.visualstudio.com/One/_git/compute-CloudShell?path=/src/images/agent/env/envconfig.PROD.json
21+
"https://management.core.windows.net/",
22+
"https://management.azure.com/",
23+
"https://graph.windows.net/",
24+
"https://vault.azure.net",
25+
"https://datalake.azure.net/",
26+
"https://outlook.office365.com/",
27+
"https://graph.microsoft.com/",
28+
"https://batch.core.windows.net/",
29+
"https://analysis.windows.net/powerbi/api",
30+
"https://storage.azure.com/",
31+
"https://rest.media.azure.net",
32+
"https://api.loganalytics.io",
33+
"https://ossrdbms-aad.database.windows.net",
34+
"https://www.yammer.com",
35+
"https://digitaltwins.azure.net",
36+
"0b07f429-9f4b-4714-9392-cc5e8e80c8b0",
37+
"822c8694-ad95-4735-9c55-256f7db2f9b4",
38+
"https://dev.azuresynapse.net",
39+
"https://database.windows.net",
40+
"https://quantum.microsoft.com",
41+
"https://iothubs.azure.net",
42+
"2ff814a6-3304-4ab8-85cb-cd0e6f879c1d",
43+
"https://azuredatabricks.net/",
44+
"ce34e7e5-485f-4d76-964f-b3d2b16d1e4f",
45+
"https://azure-devices-provisioning.net"
46+
]
47+
for a in cloud_shell_supported_audiences:
48+
if scope.startswith(a): # This is an experimental approach
49+
return a
50+
return scope # Some scope would work as-is, such as the SSH Cert scope
51+
52+
53+
def _acquire_token(http_client, scopes, **kwargs):
54+
kwargs.pop("correlation_id", None) # IMDS does not use correlation_id
55+
resp = http_client.post(
56+
"http://localhost:50342/oauth2/token",
57+
data=dict(
58+
kwargs.pop("data", {}),
59+
resource=" ".join(map(_scope_to_resource, scopes))),
60+
headers=dict(kwargs.pop("headers", {}), Metadata="true"),
61+
**kwargs)
62+
if resp.status_code >= 300:
63+
logger.debug("Cloud Shell IMDS error: %s", resp.text)
64+
cs_error = json.loads(resp.text).get("error", {})
65+
return {k: v for k, v in {
66+
"error": cs_error.get("code"),
67+
"error_description": cs_error.get("message"),
68+
}.items() if v}
69+
payload = json.loads(resp.text)
70+
oauth2_response = {
71+
"access_token": payload["access_token"],
72+
"expires_in": int(payload["expires_in"]),
73+
"token_type": payload.get("token_type", "Bearer"),
74+
}
75+
if payload.get("refresh_token"):
76+
oauth2_response["refresh_token"] = payload["refresh_token"]
77+
return oauth2_response
78+

tests/test_e2e.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -275,10 +275,9 @@ def test_ssh_cert_for_user_silent_inside_cloud_shell(self):
275275
"Manually run this test case from inside Cloud Shell")
276276
class CloudShellTestCase(E2eTestCase):
277277
app = msal.PublicClientApplication("client_id_wont_matter")
278-
# Scopes came from https://msazure.visualstudio.com/One/_git/compute-CloudShell?path=/src/images/agent/env/envconfig.PROD.json&version=GBmaster&_a=contents
279-
scope_that_requires_no_managed_device = "https://management.core.windows.net/"
278+
scope_that_requires_no_managed_device = "https://management.core.windows.net/" # Scopes came from https://msazure.visualstudio.com/One/_git/compute-CloudShell?path=/src/images/agent/env/envconfig.PROD.json&version=GBmaster&_a=contents
280279
def test_access_token_should_be_obtained_for_a_supported_scope(self):
281-
accounts = self.app.get_accounts()
280+
accounts = self.app.get_accounts(username=msal.CURRENT_USER)
282281
self.assertNotEqual([], accounts)
283282
result = self.app.acquire_token_silent_with_error(
284283
[self.scope_that_requires_no_managed_device], account=accounts[0])

0 commit comments

Comments
 (0)