Skip to content

Ssh cert tests #300

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 29, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion msal/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,9 @@ def decorate_scope(
CLIENT_CURRENT_TELEMETRY = 'x-client-current-telemetry'

def _get_new_correlation_id():
return str(uuid.uuid4())
correlation_id = str(uuid.uuid4())
logger.debug("Generates correlation_id: %s", correlation_id)
return correlation_id


def _build_current_telemetry_request_header(public_api_id, force_refresh=False):
Expand Down Expand Up @@ -1233,6 +1235,7 @@ def acquire_token_for_client(self, scopes, claims_challenge=None, **kwargs):
- an error response would contain "error" and usually "error_description".
"""
# TBD: force_refresh behavior
self._validate_ssh_cert_input_data(kwargs.get("data", {}))
return self.client.obtain_token_for_client(
scope=scopes, # This grant flow requires no scope decoration
headers={
Expand Down
165 changes: 86 additions & 79 deletions tests/test_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import json
import time
import unittest
import sys

import requests

Expand All @@ -11,7 +12,7 @@
from msal.oauth2cli import AuthCodeReceiver

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
logging.basicConfig(level=logging.DEBUG if "-v" in sys.argv else logging.INFO)


def _get_app_and_auth_code(
Expand Down Expand Up @@ -49,7 +50,8 @@ def assertLoosely(self, response, assertion=None,
error_description=response.get("error_description")))
assertion()

def assertCacheWorksForUser(self, result_from_wire, scope, username=None):
def assertCacheWorksForUser(
self, result_from_wire, scope, username=None, data=None):
# You can filter by predefined username, or let end user to choose one
accounts = self.app.get_accounts(username=username)
self.assertNotEqual(0, len(accounts))
Expand All @@ -59,7 +61,8 @@ def assertCacheWorksForUser(self, result_from_wire, scope, username=None):
set(scope) <= set(result_from_wire["scope"].split(" "))
):
# Going to test acquire_token_silent(...) to locate an AT from cache
result_from_cache = self.app.acquire_token_silent(scope, account=account)
result_from_cache = self.app.acquire_token_silent(
scope, account=account, data=data or {})
self.assertIsNotNone(result_from_cache)
self.assertIsNone(
result_from_cache.get("refresh_token"), "A cache hit returns no RT")
Expand All @@ -69,7 +72,8 @@ def assertCacheWorksForUser(self, result_from_wire, scope, username=None):

# Going to test acquire_token_silent(...) to obtain an AT by a RT from cache
self.app.token_cache._cache["AccessToken"] = {} # A hacky way to clear ATs
result_from_cache = self.app.acquire_token_silent(scope, account=account)
result_from_cache = self.app.acquire_token_silent(
scope, account=account, data=data or {})
self.assertIsNotNone(result_from_cache,
"We should get a result from acquire_token_silent(...) call")
self.assertIsNotNone(
Expand Down Expand Up @@ -131,6 +135,84 @@ def _test_device_flow(
logger.info(
"%s obtained tokens: %s", self.id(), json.dumps(result, indent=4))

def _test_acquire_token_interactive(
self, client_id=None, authority=None, scope=None, port=None,
username_uri="", # But you would want to provide one
data=None, # Needed by ssh-cert feature
**ignored):
assert client_id and authority and scope
self.app = msal.PublicClientApplication(
client_id, authority=authority, http_client=MinimalHttpClient())
result = self.app.acquire_token_interactive(
scope,
timeout=120,
port=port,
welcome_template= # This is an undocumented feature for testing
"""<html><body><h1>{id}</h1><ol>
<li>Get a username from the upn shown at <a href="{username_uri}">here</a></li>
<li>Get its password from https://aka.ms/GetLabUserSecret?Secret=msidlabXYZ
(replace the lab name with the labName from the link above).</li>
<li><a href="$auth_uri">Sign In</a> or <a href="$abort_uri">Abort</a></li>
</ol></body></html>""".format(id=self.id(), username_uri=username_uri),
data=data or {},
)
logger.debug(
"%s: cache = %s, id_token_claims = %s",
self.id(),
json.dumps(self.app.token_cache._cache, indent=4),
json.dumps(result.get("id_token_claims"), indent=4),
)
self.assertIn(
"access_token", result,
"{error}: {error_description}".format(
# Note: No interpolation here, cause error won't always present
error=result.get("error"),
error_description=result.get("error_description")))
self.assertCacheWorksForUser(result, scope, username=None, data=data or {})
return result # For further testing


class SshCertTestCase(E2eTestCase):
_JWK1 = """{"kty":"RSA", "n":"2tNr73xwcj6lH7bqRZrFzgSLj7OeLfbn8216uOMDHuaZ6TEUBDN8Uz0ve8jAlKsP9CQFCSVoSNovdE-fs7c15MxEGHjDcNKLWonznximj8pDGZQjVdfK-7mG6P6z-lgVcLuYu5JcWU_PeEqIKg5llOaz-qeQ4LEDS4T1D2qWRGpAra4rJX1-kmrWmX_XIamq30C9EIO0gGuT4rc2hJBWQ-4-FnE1NXmy125wfT3NdotAJGq5lMIfhjfglDbJCwhc8Oe17ORjO3FsB5CLuBRpYmP7Nzn66lRY3Fe11Xz8AEBl3anKFSJcTvlMnFtu3EpD-eiaHfTgRBU7CztGQqVbiQ", "e":"AQAB"}"""
_JWK2 = """{"kty":"RSA", "n":"72u07mew8rw-ssw3tUs9clKstGO2lvD7ZNxJU7OPNKz5PGYx3gjkhUmtNah4I4FP0DuF1ogb_qSS5eD86w10Wb1ftjWcoY8zjNO9V3ph-Q2tMQWdDW5kLdeU3-EDzc0HQeou9E0udqmfQoPbuXFQcOkdcbh3eeYejs8sWn3TQprXRwGh_TRYi-CAurXXLxQ8rp-pltUVRIr1B63fXmXhMeCAGwCPEFX9FRRs-YHUszUJl9F9-E0nmdOitiAkKfCC9LhwB9_xKtjmHUM9VaEC9jWOcdvXZutwEoW2XPMOg0Ky-s197F9rfpgHle2gBrXsbvVMvS0D-wXg6vsq6BAHzQ", "e":"AQAB"}"""
DATA1 = {"token_type": "ssh-cert", "key_id": "key1", "req_cnf": _JWK1}
DATA2 = {"token_type": "ssh-cert", "key_id": "key2", "req_cnf": _JWK2}
_SCOPE_USER = ["https://pas.windows.net/CheckMyAccess/Linux/user_impersonation"]
_SCOPE_SP = ["https://pas.windows.net/CheckMyAccess/Linux/.default"]
SCOPE = _SCOPE_SP # Historically there was a separation, at 2021 it is unified

def test_ssh_cert_for_service_principal(self):
# Any SP can obtain an ssh-cert. Here we use the lab app.
result = get_lab_app().acquire_token_for_client(self.SCOPE, data=self.DATA1)
self.assertIsNotNone(result.get("access_token"), "Encountered {}: {}".format(
result.get("error"), result.get("error_description")))
self.assertEqual("ssh-cert", result["token_type"])

@unittest.skipIf(os.getenv("TRAVIS"), "Browser automation is not yet implemented")
def test_ssh_cert_for_user(self):
result = self._test_acquire_token_interactive(
client_id="04b07795-8ddb-461a-bbee-02f9e1bf7b46", # Azure CLI is one
# of the only 2 clients that are PreAuthz to use ssh cert feature
authority="https://login.microsoftonline.com/common",
scope=self.SCOPE,
data=self.DATA1,
username_uri="https://msidlab.com/api/user?usertype=cloud",
) # It already tests reading AT from cache, and using RT to refresh
# acquire_token_silent() would work because we pass in the same key
self.assertIsNotNone(result.get("access_token"), "Encountered {}: {}".format(
result.get("error"), result.get("error_description")))
self.assertEqual("ssh-cert", result["token_type"])
logger.debug("%s.cache = %s",
self.id(), json.dumps(self.app.token_cache._cache, indent=4))

# refresh_token grant can fetch an ssh-cert bound to a different key
account = self.app.get_accounts()[0]
refreshed_ssh_cert = self.app.acquire_token_silent(
self.SCOPE, account=account, data=self.DATA2)
self.assertIsNotNone(refreshed_ssh_cert)
self.assertEqual(refreshed_ssh_cert["token_type"], "ssh-cert")
self.assertNotEqual(result["access_token"], refreshed_ssh_cert['access_token'])


THIS_FOLDER = os.path.dirname(__file__)
CONFIG = os.path.join(THIS_FOLDER, "config.json")
Expand Down Expand Up @@ -190,48 +272,6 @@ def test_auth_code_with_mismatching_nonce(self):
self.app.acquire_token_by_authorization_code(
ac, self.config["scope"], redirect_uri=redirect_uri, nonce="bar")

def test_ssh_cert(self):
self.skipUnlessWithConfig(["client_id", "scope"])

JWK1 = """{"kty":"RSA", "n":"2tNr73xwcj6lH7bqRZrFzgSLj7OeLfbn8216uOMDHuaZ6TEUBDN8Uz0ve8jAlKsP9CQFCSVoSNovdE-fs7c15MxEGHjDcNKLWonznximj8pDGZQjVdfK-7mG6P6z-lgVcLuYu5JcWU_PeEqIKg5llOaz-qeQ4LEDS4T1D2qWRGpAra4rJX1-kmrWmX_XIamq30C9EIO0gGuT4rc2hJBWQ-4-FnE1NXmy125wfT3NdotAJGq5lMIfhjfglDbJCwhc8Oe17ORjO3FsB5CLuBRpYmP7Nzn66lRY3Fe11Xz8AEBl3anKFSJcTvlMnFtu3EpD-eiaHfTgRBU7CztGQqVbiQ", "e":"AQAB"}"""
JWK2 = """{"kty":"RSA", "n":"72u07mew8rw-ssw3tUs9clKstGO2lvD7ZNxJU7OPNKz5PGYx3gjkhUmtNah4I4FP0DuF1ogb_qSS5eD86w10Wb1ftjWcoY8zjNO9V3ph-Q2tMQWdDW5kLdeU3-EDzc0HQeou9E0udqmfQoPbuXFQcOkdcbh3eeYejs8sWn3TQprXRwGh_TRYi-CAurXXLxQ8rp-pltUVRIr1B63fXmXhMeCAGwCPEFX9FRRs-YHUszUJl9F9-E0nmdOitiAkKfCC9LhwB9_xKtjmHUM9VaEC9jWOcdvXZutwEoW2XPMOg0Ky-s197F9rfpgHle2gBrXsbvVMvS0D-wXg6vsq6BAHzQ", "e":"AQAB"}"""
data1 = {"token_type": "ssh-cert", "key_id": "key1", "req_cnf": JWK1}
ssh_test_slice = {
"dc": "prod-wst-test1",
"slice": "test",
"sshcrt": "true",
}

scopes = [ # Only this scope would result in an SSH-Cert
"https://pas.windows.net/CheckMyAccess/Linux/user_impersonation"]
(self.app, ac, redirect_uri) = self._get_app_and_auth_code(scopes=scopes)

result = self.app.acquire_token_by_authorization_code(
ac, scopes, redirect_uri=redirect_uri, data=data1,
params=ssh_test_slice)
self.assertIsNotNone(result.get("access_token"), "Encountered {}: {}".format(
result.get("error"), result.get("error_description")))
self.assertEqual("ssh-cert", result["token_type"])
logger.debug("%s.cache = %s",
self.id(), json.dumps(self.app.token_cache._cache, indent=4))

# acquire_token_silent() needs to be passed the same key to work
account = self.app.get_accounts()[0]
result_from_cache = self.app.acquire_token_silent(
scopes, account=account, data=data1)
self.assertIsNotNone(result_from_cache)
self.assertEqual(
result['access_token'], result_from_cache['access_token'],
"We should get the cached SSH-cert")

# refresh_token grant can fetch an ssh-cert bound to a different key
refreshed_ssh_cert = self.app.acquire_token_silent(
scopes, account=account, params=ssh_test_slice,
data={"token_type": "ssh-cert", "key_id": "key2", "req_cnf": JWK2})
self.assertIsNotNone(refreshed_ssh_cert)
self.assertEqual(refreshed_ssh_cert["token_type"], "ssh-cert")
self.assertNotEqual(result["access_token"], refreshed_ssh_cert['access_token'])

def test_client_secret(self):
self.skipUnlessWithConfig(["client_id", "client_secret"])
self.app = msal.ConfidentialClientApplication(
Expand Down Expand Up @@ -445,39 +485,6 @@ def _test_acquire_token_by_auth_code_flow(
error_description=result.get("error_description")))
self.assertCacheWorksForUser(result, scope, username=None)

def _test_acquire_token_interactive(
self, client_id=None, authority=None, scope=None, port=None,
username_uri="", # But you would want to provide one
**ignored):
assert client_id and authority and scope
self.app = msal.PublicClientApplication(
client_id, authority=authority, http_client=MinimalHttpClient())
result = self.app.acquire_token_interactive(
scope,
timeout=60,
port=port,
welcome_template= # This is an undocumented feature for testing
"""<html><body><h1>{id}</h1><ol>
<li>Get a username from the upn shown at <a href="{username_uri}">here</a></li>
<li>Get its password from https://aka.ms/GetLabUserSecret?Secret=msidlabXYZ
(replace the lab name with the labName from the link above).</li>
<li><a href="$auth_uri">Sign In</a> or <a href="$abort_uri">Abort</a></li>
</ol></body></html>""".format(id=self.id(), username_uri=username_uri),
)
logger.debug(
"%s: cache = %s, id_token_claims = %s",
self.id(),
json.dumps(self.app.token_cache._cache, indent=4),
json.dumps(result.get("id_token_claims"), indent=4),
)
self.assertIn(
"access_token", result,
"{error}: {error_description}".format(
# Note: No interpolation here, cause error won't always present
error=result.get("error"),
error_description=result.get("error_description")))
self.assertCacheWorksForUser(result, scope, username=None)

def _test_acquire_token_obo(self, config_pca, config_cca):
# 1. An app obtains a token representing a user, for our mid-tier service
pca = msal.PublicClientApplication(
Expand Down