Skip to content

acquire_token_interactive(..., prompt="none") acquires token via Cloud Shell's IMDS-like interface #420

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 31, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 30 additions & 3 deletions msal/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,14 @@
import msal.telemetry
from .region import _detect_region
from .throttled_http_client import ThrottledHttpClient
from .cloudshell import _is_running_in_cloud_shell


# The __init__.py will import this. Not the other way around.
__version__ = "1.17.0" # When releasing, also check and bump our dependencies's versions if needed

logger = logging.getLogger(__name__)

_AUTHORITY_TYPE_CLOUDSHELL = "CLOUDSHELL"

def extract_certs(public_cert_content):
# Parses raw public certificate file contents and returns a list of strings
Expand Down Expand Up @@ -986,6 +987,10 @@ def get_accounts(self, username=None):
return accounts

def _find_msal_accounts(self, environment):
interested_authority_types = [
TokenCache.AuthorityType.ADFS, TokenCache.AuthorityType.MSSTS]
if _is_running_in_cloud_shell():
interested_authority_types.append(_AUTHORITY_TYPE_CLOUDSHELL)
grouped_accounts = {
a.get("home_account_id"): # Grouped by home tenant's id
{ # These are minimal amount of non-tenant-specific account info
Expand All @@ -1001,8 +1006,7 @@ def _find_msal_accounts(self, environment):
for a in self.token_cache.find(
TokenCache.CredentialType.ACCOUNT,
query={"environment": environment})
if a["authority_type"] in (
TokenCache.AuthorityType.ADFS, TokenCache.AuthorityType.MSSTS)
if a["authority_type"] in interested_authority_types
}
return list(grouped_accounts.values())

Expand Down Expand Up @@ -1062,6 +1066,21 @@ def _forget_me(self, home_account):
TokenCache.CredentialType.ACCOUNT, query=owned_by_home_account):
self.token_cache.remove_account(a)

def _acquire_token_by_cloud_shell(self, scopes, data=None):
from .cloudshell import _obtain_token
response = _obtain_token(
self.http_client, scopes, client_id=self.client_id, data=data)
if "error" not in response:
self.token_cache.add(dict(
client_id=self.client_id,
scope=response["scope"].split() if "scope" in response else scopes,
token_endpoint=self.authority.token_endpoint,
response=response.copy(),
data=data or {},
authority_type=_AUTHORITY_TYPE_CLOUDSHELL,
))
Comment on lines +1074 to +1081
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think Cloud Shell's pseudo managed identity endpoint already has a cache, so does a normal managed identity endpoint. Do we really need to save it to MSAL's cache?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The devil is in the details.

I think Cloud Shell's pseudo managed identity endpoint already has a cache,

The SSH Cert feature has its own quirk in caching. Currently, the Cloud Shell (actually the Azure Portal) does not yet really support caching for SSH Cert.

so does a normal managed identity endpoint.

Yes but there is a catch. Quoted from normal managed identity's document: "The managed identities subsystem caches tokens but we still recommend that you implement token caching in your code." They even defined an HTTP 429 throttling error when you use Managed Identity endpoint too often. FWIW, I'm currently working on another project that is trying to enable MSAL token cache for managed identity.

Do we really need to save it to MSAL's cache?

The MSAL cache has long been designed to be used as often as you want, and it also stores SSH cert. Adding this (technically still a) one-liner here can allow existing MSAL-powered apps' acquire_token_silent() to work, for free (meaning without any extra code change in app's code). Why not?

return response

def acquire_token_silent(
self,
scopes, # type: List[str]
Expand Down Expand Up @@ -1195,6 +1214,7 @@ def _acquire_token_silent_from_cache_and_possibly_refresh_it(
authority, # This can be different than self.authority
force_refresh=False, # type: Optional[boolean]
claims_challenge=None,
correlation_id=None,
**kwargs):
access_token_from_cache = None
if not (force_refresh or claims_challenge): # Bypass AT when desired or using claims
Expand Down Expand Up @@ -1233,9 +1253,13 @@ def _acquire_token_silent_from_cache_and_possibly_refresh_it(
refresh_reason = msal.telemetry.FORCE_REFRESH # TODO: It could also mean claims_challenge
assert refresh_reason, "It should have been established at this point"
try:
if account and account.get("authority_type") == _AUTHORITY_TYPE_CLOUDSHELL:
return self._acquire_token_by_cloud_shell(
scopes, data=kwargs.get("data"))
result = _clean_up(self._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family(
authority, self._decorate_scope(scopes), account,
refresh_reason=refresh_reason, claims_challenge=claims_challenge,
correlation_id=correlation_id,
**kwargs))
if (result and "error" not in result) or (not access_token_from_cache):
return result
Expand Down Expand Up @@ -1574,6 +1598,9 @@ def acquire_token_interactive(
- A dict containing an "error" key, when token refresh failed.
"""
self._validate_ssh_cert_input_data(kwargs.get("data", {}))
if _is_running_in_cloud_shell() and prompt == "none":
return self._acquire_token_by_cloud_shell(
scopes, data=kwargs.pop("data", {}))
claims = _merge_claims_challenge_and_capabilities(
self._client_capabilities, claims_challenge)
telemetry_context = self._build_telemetry_context(
Expand Down
122 changes: 122 additions & 0 deletions msal/cloudshell.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
# Copyright (c) Microsoft Corporation.
# All rights reserved.
#
# This code is licensed under the MIT License.

"""This module wraps Cloud Shell's IMDS-like interface inside an OAuth2-like helper"""
import base64
import json
import logging
import os
import time
try: # Python 2
from urlparse import urlparse
except: # Python 3
from urllib.parse import urlparse
from .oauth2cli.oidc import decode_part


logger = logging.getLogger(__name__)


def _is_running_in_cloud_shell():
return os.environ.get("AZUREPS_HOST_ENVIRONMENT", "").startswith("cloud-shell")


def _scope_to_resource(scope): # This is an experimental reasonable-effort approach
cloud_shell_supported_audiences = [
"https://analysis.windows.net/powerbi/api", # Came from https://msazure.visualstudio.com/One/_git/compute-CloudShell?path=/src/images/agent/env/envconfig.PROD.json
"https://pas.windows.net/CheckMyAccess/Linux/.default", # Cloud Shell accepts it as-is
]
for a in cloud_shell_supported_audiences:
if scope.startswith(a):
return a
u = urlparse(scope)
if u.scheme:
return "{}://{}".format(u.scheme, u.netloc)
return scope # There is no much else we can do here


def _obtain_token(http_client, scopes, client_id=None, data=None):
resp = http_client.post(
"http://localhost:50342/oauth2/token",
data=dict(
data or {},
resource=" ".join(map(_scope_to_resource, scopes))),
headers={"Metadata": "true"},
)
if resp.status_code >= 300:
logger.debug("Cloud Shell IMDS error: %s", resp.text)
cs_error = json.loads(resp.text).get("error", {})
return {k: v for k, v in {
"error": cs_error.get("code"),
"error_description": cs_error.get("message"),
}.items() if v}
imds_payload = json.loads(resp.text)
BEARER = "Bearer"
oauth2_response = {
"access_token": imds_payload["access_token"],
"expires_in": int(imds_payload["expires_in"]),
"token_type": imds_payload.get("token_type", BEARER),
}
expected_token_type = (data or {}).get("token_type", BEARER)
if oauth2_response["token_type"] != expected_token_type:
return { # Generate a normal error (rather than an intrusive exception)
"error": "broker_error",
"error_description": "token_type {} is not supported by this version of Azure Portal".format(
expected_token_type),
}
parts = imds_payload["access_token"].split(".")

# The following default values are useful in SSH Cert scenario
client_info = { # Default value, in case the real value will be unavailable
"uid": "user",
"utid": "cloudshell",
}
now = time.time()
preferred_username = "currentuser@cloudshell"
oauth2_response["id_token_claims"] = { # First 5 claims are required per OIDC
"iss": "cloudshell",
"sub": "user",
"aud": client_id,
"exp": now + 3600,
"iat": now,
"preferred_username": preferred_username, # Useful as MSAL account's username
}

if len(parts) == 3: # Probably a JWT. Use it to derive client_info and id token.
try:
# Data defined in https://docs.microsoft.com/en-us/azure/active-directory/develop/access-tokens#payload-claims
jwt_payload = json.loads(decode_part(parts[1]))
client_info = {
# Mimic a real home_account_id,
# so that this pseudo account and a real account would interop.
"uid": jwt_payload.get("oid", "user"),
"utid": jwt_payload.get("tid", "cloudshell"),
}
oauth2_response["id_token_claims"] = {
"iss": jwt_payload["iss"],
"sub": jwt_payload["sub"], # Could use oid instead
"aud": client_id,
"exp": jwt_payload["exp"],
"iat": jwt_payload["iat"],
"preferred_username": jwt_payload.get("preferred_username") # V2
or jwt_payload.get("unique_name") # V1
or preferred_username,
}
except ValueError:
logger.debug("Unable to decode jwt payload: %s", parts[1])
oauth2_response["client_info"] = base64.b64encode(
# Mimic a client_info, so that MSAL would create an account
json.dumps(client_info).encode("utf-8")).decode("utf-8")
oauth2_response["id_token_claims"]["tid"] = client_info["utid"] # TBD

## Note: Decided to not surface resource back as scope,
## because they would cause the downstream OAuth2 code path to
## cache the token with a different scope and won't hit them later.
#if imds_payload.get("resource"):
# oauth2_response["scope"] = imds_payload["resource"]
if imds_payload.get("refresh_token"):
oauth2_response["refresh_token"] = imds_payload["refresh_token"]
return oauth2_response

9 changes: 6 additions & 3 deletions msal/token_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ def wipe(dictionary, sensitive_fields): # Masks sensitive info
return self.__add(event, now=now)
finally:
wipe(event.get("response", {}), ( # These claims were useful during __add()
"id_token_claims", # Provided by broker
"access_token", "refresh_token", "id_token", "username"))
wipe(event, ["username"]) # Needed for federated ROPC
logger.debug("event=%s", json.dumps(
Expand Down Expand Up @@ -150,7 +151,8 @@ def __add(self, event, now=None):
id_token = response.get("id_token")
id_token_claims = (
decode_id_token(id_token, client_id=event["client_id"])
if id_token else {})
if id_token
else response.get("id_token_claims", {})) # Broker would provide id_token_claims
client_info, home_account_id = self.__parse_account(response, id_token_claims)

target = ' '.join(event.get("scope") or []) # Per schema, we don't sort it
Expand Down Expand Up @@ -195,9 +197,10 @@ def __add(self, event, now=None):
or data.get("username") # Falls back to ROPC username
or event.get("username") # Falls back to Federated ROPC username
or "", # The schema does not like null
"authority_type":
"authority_type": event.get(
"authority_type", # Honor caller's choice of authority_type
self.AuthorityType.ADFS if realm == "adfs"
else self.AuthorityType.MSSTS,
else self.AuthorityType.MSSTS),
# "client_info": response.get("client_info"), # Optional
}
self.modify(self.CredentialType.ACCOUNT, account, account)
Expand Down
17 changes: 17 additions & 0 deletions tests/test_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,12 +185,14 @@ def _test_acquire_token_interactive(
self, client_id=None, authority=None, scope=None, port=None,
username_uri="", # But you would want to provide one
data=None, # Needed by ssh-cert feature
prompt=None,
**ignored):
assert client_id and authority and scope
self.app = msal.PublicClientApplication(
client_id, authority=authority, http_client=MinimalHttpClient())
result = self.app.acquire_token_interactive(
scope,
prompt=prompt,
timeout=120,
port=port,
welcome_template= # This is an undocumented feature for testing
Expand Down Expand Up @@ -237,6 +239,7 @@ def test_ssh_cert_for_user(self):
scope=self.SCOPE,
data=self.DATA1,
username_uri="https://msidlab.com/api/user?usertype=cloud",
prompt="none" if msal.application._is_running_in_cloud_shell() else None,
) # It already tests reading AT from cache, and using RT to refresh
# acquire_token_silent() would work because we pass in the same key
self.assertIsNotNone(result.get("access_token"), "Encountered {}: {}".format(
Expand All @@ -254,6 +257,20 @@ def test_ssh_cert_for_user(self):
self.assertNotEqual(result["access_token"], refreshed_ssh_cert['access_token'])


@unittest.skipUnless(
msal.application._is_running_in_cloud_shell(),
"Manually run this test case from inside Cloud Shell")
class CloudShellTestCase(E2eTestCase):
app = msal.PublicClientApplication("client_id")
scope_that_requires_no_managed_device = "https://management.core.windows.net/" # Scopes came from https://msazure.visualstudio.com/One/_git/compute-CloudShell?path=/src/images/agent/env/envconfig.PROD.json&version=GBmaster&_a=contents
def test_access_token_should_be_obtained_for_a_supported_scope(self):
result = self.app.acquire_token_interactive(
[self.scope_that_requires_no_managed_device], prompt="none")
self.assertEqual(
"Bearer", result.get("token_type"), "Unexpected result: %s" % result)
self.assertIsNotNone(result.get("access_token"))


THIS_FOLDER = os.path.dirname(__file__)
CONFIG = os.path.join(THIS_FOLDER, "config.json")
@unittest.skipUnless(os.path.exists(CONFIG), "Optional %s not found" % CONFIG)
Expand Down