Skip to content

MSAL Python 1.24.1 #601

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Sep 29, 2023
114 changes: 62 additions & 52 deletions tests/msaltest.py → msal/__main__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,19 @@
# It is currently shipped inside msal library.
# Pros: It is always available wherever msal is installed.
# Cons: Its 3rd-party dependencies (if any) may become msal's dependency.
"""MSAL Python Tester

Usage 1: Run it on the fly.
python -m msal

Usage 2: Build an all-in-one executable file for bug bash.
shiv -e msal.__main__._main -o msaltest-on-os-name.pyz .
Note: We choose to not define a console script to avoid name conflict.
"""
import base64, getpass, json, logging, sys, msal


AZURE_CLI = "04b07795-8ddb-461a-bbee-02f9e1bf7b46"
VISUAL_STUDIO = "04f0c124-f2bc-4f59-8241-bf6df9866bbd"
_AZURE_CLI = "04b07795-8ddb-461a-bbee-02f9e1bf7b46"
_VISUAL_STUDIO = "04f0c124-f2bc-4f59-8241-bf6df9866bbd"

def print_json(blob):
print(json.dumps(blob, indent=2, sort_keys=True))
Expand Down Expand Up @@ -61,7 +72,7 @@ def _select_account(app):
else:
print("No account available inside MSAL Python. Use other methods to acquire token first.")

def acquire_token_silent(app):
def _acquire_token_silent(app):
"""acquire_token_silent() - with an account already signed into MSAL Python."""
account = _select_account(app)
if account:
Expand All @@ -71,95 +82,94 @@ def acquire_token_silent(app):
force_refresh=_input_boolean("Bypass MSAL Python's token cache?"),
))

def _acquire_token_interactive(app, scopes, data=None):
def _acquire_token_interactive(app, scopes=None, data=None):
"""acquire_token_interactive() - User will be prompted if app opts to do select_account."""
scopes = scopes or _input_scopes() # Let user input scope param before less important prompt and login_hint
prompt = _select_options([
{"value": None, "description": "Unspecified. Proceed silently with a default account (if any), fallback to prompt."},
{"value": "none", "description": "none. Proceed silently with a default account (if any), or error out."},
{"value": "select_account", "description": "select_account. Prompt with an account picker."},
],
option_renderer=lambda o: o["description"],
header="Prompt behavior?")["value"]
raw_login_hint = _select_options(
# login_hint is unnecessary when prompt=select_account,
# but we still let tester input login_hint, just for testing purpose.
[None] + [a["username"] for a in app.get_accounts()],
header="login_hint? (If you have multiple signed-in sessions in browser/broker, and you specify a login_hint to match one of them, you will bypass the account picker.)",
accept_nonempty_string=True,
)
login_hint = raw_login_hint["username"] if isinstance(raw_login_hint, dict) else raw_login_hint
if prompt == "select_account":
login_hint = None # login_hint is unnecessary when prompt=select_account
else:
raw_login_hint = _select_options(
[None] + [a["username"] for a in app.get_accounts()],
header="login_hint? (If you have multiple signed-in sessions in browser/broker, and you specify a login_hint to match one of them, you will bypass the account picker.)",
accept_nonempty_string=True,
)
login_hint = raw_login_hint["username"] if isinstance(raw_login_hint, dict) else raw_login_hint
result = app.acquire_token_interactive(
scopes,
parent_window_handle=app.CONSOLE_WINDOW_HANDLE, # This test app is a console app
enable_msa_passthrough=app.client_id in [ # Apps are expected to set this right
AZURE_CLI, VISUAL_STUDIO,
_AZURE_CLI, _VISUAL_STUDIO,
], # Here this test app mimics the setting for some known MSA-PT apps
prompt=prompt, login_hint=login_hint, data=data or {})
prompt=prompt, login_hint=login_hint, data=data or {},
)
if login_hint and "id_token_claims" in result:
signed_in_user = result.get("id_token_claims", {}).get("preferred_username")
if signed_in_user != login_hint:
logging.warning('Signed-in user "%s" does not match login_hint', signed_in_user)
print_json(result)
return result

def acquire_token_interactive(app):
"""acquire_token_interactive() - User will be prompted if app opts to do select_account."""
print_json(_acquire_token_interactive(app, _input_scopes()))

def acquire_token_by_username_password(app):
def _acquire_token_by_username_password(app):
"""acquire_token_by_username_password() - See constraints here: https://docs.microsoft.com/en-us/azure/active-directory/develop/msal-authentication-flows#constraints-for-ropc"""
print_json(app.acquire_token_by_username_password(
_input("username: "), getpass.getpass("password: "), scopes=_input_scopes()))

_JWK1 = """{"kty":"RSA", "n":"2tNr73xwcj6lH7bqRZrFzgSLj7OeLfbn8216uOMDHuaZ6TEUBDN8Uz0ve8jAlKsP9CQFCSVoSNovdE-fs7c15MxEGHjDcNKLWonznximj8pDGZQjVdfK-7mG6P6z-lgVcLuYu5JcWU_PeEqIKg5llOaz-qeQ4LEDS4T1D2qWRGpAra4rJX1-kmrWmX_XIamq30C9EIO0gGuT4rc2hJBWQ-4-FnE1NXmy125wfT3NdotAJGq5lMIfhjfglDbJCwhc8Oe17ORjO3FsB5CLuBRpYmP7Nzn66lRY3Fe11Xz8AEBl3anKFSJcTvlMnFtu3EpD-eiaHfTgRBU7CztGQqVbiQ", "e":"AQAB"}"""
SSH_CERT_DATA = {"token_type": "ssh-cert", "key_id": "key1", "req_cnf": _JWK1}
SSH_CERT_SCOPE = ["https://pas.windows.net/CheckMyAccess/Linux/.default"]
_SSH_CERT_DATA = {"token_type": "ssh-cert", "key_id": "key1", "req_cnf": _JWK1}
_SSH_CERT_SCOPE = ["https://pas.windows.net/CheckMyAccess/Linux/.default"]

def acquire_ssh_cert_silently(app):
def _acquire_ssh_cert_silently(app):
"""Acquire an SSH Cert silently- This typically only works with Azure CLI"""
account = _select_account(app)
if account:
result = app.acquire_token_silent(
SSH_CERT_SCOPE,
_SSH_CERT_SCOPE,
account,
data=SSH_CERT_DATA,
data=_SSH_CERT_DATA,
force_refresh=_input_boolean("Bypass MSAL Python's token cache?"),
)
print_json(result)
if result and result.get("token_type") != "ssh-cert":
logging.error("Unable to acquire an ssh-cert.")

def acquire_ssh_cert_interactive(app):
def _acquire_ssh_cert_interactive(app):
"""Acquire an SSH Cert interactively - This typically only works with Azure CLI"""
result = _acquire_token_interactive(app, SSH_CERT_SCOPE, data=SSH_CERT_DATA)
print_json(result)
result = _acquire_token_interactive(app, scopes=_SSH_CERT_SCOPE, data=_SSH_CERT_DATA)
if result.get("token_type") != "ssh-cert":
logging.error("Unable to acquire an ssh-cert")

POP_KEY_ID = 'AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA-AAAAAAAA' # Fake key with a certain format and length
RAW_REQ_CNF = json.dumps({"kid": POP_KEY_ID, "xms_ksl": "sw"})
POP_DATA = { # Sampled from Azure CLI's plugin connectedk8s
_POP_KEY_ID = 'AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA-AAAAAAAA' # Fake key with a certain format and length
_RAW_REQ_CNF = json.dumps({"kid": _POP_KEY_ID, "xms_ksl": "sw"})
_POP_DATA = { # Sampled from Azure CLI's plugin connectedk8s
'token_type': 'pop',
'key_id': POP_KEY_ID,
"req_cnf": base64.urlsafe_b64encode(RAW_REQ_CNF.encode('utf-8')).decode('utf-8').rstrip('='),
# Note: Sending RAW_REQ_CNF without base64 encoding would result in an http 500 error
'key_id': _POP_KEY_ID,
"req_cnf": base64.urlsafe_b64encode(_RAW_REQ_CNF.encode('utf-8')).decode('utf-8').rstrip('='),
# Note: Sending _RAW_REQ_CNF without base64 encoding would result in an http 500 error
} # See also https://github.com/Azure/azure-cli-extensions/blob/main/src/connectedk8s/azext_connectedk8s/_clientproxyutils.py#L86-L92

def acquire_pop_token_interactive(app):
def _acquire_pop_token_interactive(app):
"""Acquire a POP token interactively - This typically only works with Azure CLI"""
POP_SCOPE = ['6256c85f-0aad-4d50-b960-e6e9b21efe35/.default'] # KAP 1P Server App Scope, obtained from https://github.com/Azure/azure-cli-extensions/pull/4468/files#diff-a47efa3186c7eb4f1176e07d0b858ead0bf4a58bfd51e448ee3607a5b4ef47f6R116
result = _acquire_token_interactive(app, POP_SCOPE, data=POP_DATA)
result = _acquire_token_interactive(app, scopes=POP_SCOPE, data=_POP_DATA)
print_json(result)
if result.get("token_type") != "pop":
logging.error("Unable to acquire a pop token")


def remove_account(app):
def _remove_account(app):
"""remove_account() - Invalidate account and/or token(s) from cache, so that acquire_token_silent() would be reset"""
account = _select_account(app)
if account:
app.remove_account(account)
print('Account "{}" and/or its token(s) are signed out from MSAL Python'.format(account["username"]))

def exit(app):
def _exit(app):
"""Exit"""
bug_link = (
"https://identitydivision.visualstudio.com/Engineering/_queries/query/79b3a352-a775-406f-87cd-a487c382a8ed/"
Expand All @@ -169,11 +179,11 @@ def exit(app):
print("Bye. If you found a bug, please report it here: {}".format(bug_link))
sys.exit()

def main():
print("Welcome to the Msal Python {} Tester\n".format(msal.__version__))
def _main():
print("Welcome to the Msal Python {} Tester (Experimental)\n".format(msal.__version__))
chosen_app = _select_options([
{"client_id": AZURE_CLI, "name": "Azure CLI (Correctly configured for MSA-PT)"},
{"client_id": VISUAL_STUDIO, "name": "Visual Studio (Correctly configured for MSA-PT)"},
{"client_id": _AZURE_CLI, "name": "Azure CLI (Correctly configured for MSA-PT)"},
{"client_id": _VISUAL_STUDIO, "name": "Visual Studio (Correctly configured for MSA-PT)"},
{"client_id": "95de633a-083e-42f5-b444-a4295d8e9314", "name": "Whiteboard Services (Non MSA-PT app. Accepts AAD & MSA accounts.)"},
],
option_renderer=lambda a: a["name"],
Expand Down Expand Up @@ -201,14 +211,14 @@ def main():
logging.basicConfig(level=logging.DEBUG)
while True:
func = _select_options([
acquire_token_silent,
acquire_token_interactive,
acquire_token_by_username_password,
acquire_ssh_cert_silently,
acquire_ssh_cert_interactive,
acquire_pop_token_interactive,
remove_account,
exit,
_acquire_token_silent,
_acquire_token_interactive,
_acquire_token_by_username_password,
_acquire_ssh_cert_silently,
_acquire_ssh_cert_interactive,
_acquire_pop_token_interactive,
_remove_account,
_exit,
], option_renderer=lambda f: f.__doc__, header="MSAL Python APIs:")
try:
func(app)
Expand All @@ -218,5 +228,5 @@ def main():
print("Aborted")

if __name__ == "__main__":
main()
_main()

2 changes: 1 addition & 1 deletion msal/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@


# The __init__.py will import this. Not the other way around.
__version__ = "1.24.0" # When releasing, also check and bump our dependencies's versions if needed
__version__ = "1.24.1" # When releasing, also check and bump our dependencies's versions if needed

logger = logging.getLogger(__name__)
_AUTHORITY_TYPE_CLOUDSHELL = "CLOUDSHELL"
Expand Down
46 changes: 32 additions & 14 deletions msal/oauth2cli/authcode.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@
try: # Python 3
from http.server import HTTPServer, BaseHTTPRequestHandler
from urllib.parse import urlparse, parse_qs, urlencode
from html import escape
except ImportError: # Fall back to Python 2
from BaseHTTPServer import HTTPServer, BaseHTTPRequestHandler
from urlparse import urlparse, parse_qs
from urllib import urlencode
from cgi import escape


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -77,25 +79,42 @@ def _qs2kv(qs):
for k, v in qs.items()}


def _is_html(text):
return text.startswith("<") # Good enough for our purpose


def _escape(key_value_pairs):
return {k: escape(v) for k, v in key_value_pairs.items()}


class _AuthCodeHandler(BaseHTTPRequestHandler):
def do_GET(self):
# For flexibility, we choose to not check self.path matching redirect_uri
#assert self.path.startswith('/THE_PATH_REGISTERED_BY_THE_APP')
qs = parse_qs(urlparse(self.path).query)
if qs.get('code') or qs.get("error"): # So, it is an auth response
self.server.auth_response = _qs2kv(qs)
logger.debug("Got auth response: %s", self.server.auth_response)
template = (self.server.success_template
if "code" in qs else self.server.error_template)
self._send_full_response(
template.safe_substitute(**self.server.auth_response))
# NOTE: Don't do self.server.shutdown() here. It'll halt the server.
auth_response = _qs2kv(qs)
logger.debug("Got auth response: %s", auth_response)
if self.server.auth_state and self.server.auth_state != auth_response.get("state"):
# OAuth2 successful and error responses contain state when it was used
# https://www.rfc-editor.org/rfc/rfc6749#section-4.2.2.1
self._send_full_response("State mismatch") # Possibly an attack
else:
template = (self.server.success_template
if "code" in qs else self.server.error_template)
if _is_html(template.template):
safe_data = _escape(auth_response) # Foiling an XSS attack
else:
safe_data = auth_response
self._send_full_response(template.safe_substitute(**safe_data))
self.server.auth_response = auth_response # Set it now, after the response is likely sent
else:
self._send_full_response(self.server.welcome_page)
# NOTE: Don't do self.server.shutdown() here. It'll halt the server.

def _send_full_response(self, body, is_ok=True):
self.send_response(200 if is_ok else 400)
content_type = 'text/html' if body.startswith('<') else 'text/plain'
content_type = 'text/html' if _is_html(body) else 'text/plain'
self.send_header('Content-type', content_type)
self.end_headers()
self.wfile.write(body.encode("utf-8"))
Expand Down Expand Up @@ -281,16 +300,14 @@ def _get_auth_response(self, result, auth_uri=None, timeout=None, state=None,

self._server.timeout = timeout # Otherwise its handle_timeout() won't work
self._server.auth_response = {} # Shared with _AuthCodeHandler
self._server.auth_state = state # So handler will check it before sending response
while not self._closing: # Otherwise, the handle_request() attempt
# would yield noisy ValueError trace
# Derived from
# https://docs.python.org/2/library/basehttpserver.html#more-examples
self._server.handle_request()
if self._server.auth_response:
if state and state != self._server.auth_response.get("state"):
logger.debug("State mismatch. Ignoring this noise.")
else:
break
break
result.update(self._server.auth_response) # Return via writable result param

def close(self):
Expand Down Expand Up @@ -318,6 +335,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
default="https://login.microsoftonline.com/common/oauth2/v2.0/authorize")
p.add_argument('client_id', help="The client_id of your application")
p.add_argument('--port', type=int, default=0, help="The port in redirect_uri")
p.add_argument('--timeout', type=int, default=60, help="Timeout value, in second")
p.add_argument('--host', default="127.0.0.1", help="The host of redirect_uri")
p.add_argument('--scope', default=None, help="The scope list")
args = parser.parse_args()
Expand All @@ -331,8 +349,8 @@ def __exit__(self, exc_type, exc_val, exc_tb):
auth_uri=flow["auth_uri"],
welcome_template=
"<a href='$auth_uri'>Sign In</a>, or <a href='$abort_uri'>Abort</a",
error_template="Oh no. $error",
error_template="<html>Oh no. $error</html>",
success_template="Oh yeah. Got $code",
timeout=60,
timeout=args.timeout,
state=flow["state"], # Optional
), indent=4))
2 changes: 1 addition & 1 deletion msal/oauth2cli/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -666,7 +666,7 @@ def _obtain_token_by_browser(
**(auth_params or {}))
auth_response = auth_code_receiver.get_auth_response(
auth_uri=flow["auth_uri"],
state=flow["state"], # Optional but we choose to do it upfront
state=flow["state"], # So receiver can check it early
timeout=timeout,
welcome_template=welcome_template,
success_template=success_template,
Expand Down
22 changes: 19 additions & 3 deletions tests/test_authcode.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import socket
import sys

import requests

from msal.oauth2cli.authcode import AuthCodeReceiver


Expand All @@ -17,10 +19,24 @@ def test_setup_at_a_ephemeral_port_and_teardown(self):
self.assertNotEqual(port, receiver.get_port())

def test_no_two_concurrent_receivers_can_listen_on_same_port(self):
port = 12345 # Assuming this port is available
with AuthCodeReceiver(port=port) as receiver:
with AuthCodeReceiver() as receiver:
expected_error = OSError if sys.version_info[0] > 2 else socket.error
with self.assertRaises(expected_error):
with AuthCodeReceiver(port=port) as receiver2:
with AuthCodeReceiver(port=receiver.get_port()):
pass

def test_template_should_escape_input(self):
with AuthCodeReceiver() as receiver:
receiver._scheduled_actions = [( # Injection happens here when the port is known
1, # Delay it until the receiver is activated by get_auth_response()
lambda: self.assertEqual(
"<html>&lt;tag&gt;foo&lt;/tag&gt;</html>",
requests.get("http://localhost:{}?error=<tag>foo</tag>".format(
receiver.get_port())).text,
"Unsafe data in HTML should be escaped",
))]
receiver.get_auth_response( # Starts server and hang until timeout
timeout=3,
error_template="<html>$error</html>",
)