Skip to content

Allow interactive flow to be aborted by CTRL+C even when running on Windows #404

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Sep 9, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 35 additions & 6 deletions msal/oauth2cli/authcode.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import logging
import socket
from string import Template
import threading
import time

try: # Python 3
from http.server import HTTPServer, BaseHTTPRequestHandler
Expand Down Expand Up @@ -149,11 +151,7 @@ def get_port(self):
# https://docs.python.org/2.7/library/socketserver.html#SocketServer.BaseServer.server_address
return self._server.server_address[1]

def get_auth_response(self, auth_uri=None, timeout=None, state=None,
welcome_template=None, success_template=None, error_template=None,
auth_uri_callback=None,
browser_name=None,
):
def get_auth_response(self, timeout=None, **kwargs):
"""Wait and return the auth response. Raise RuntimeError when timeout.

:param str auth_uri:
Expand Down Expand Up @@ -192,6 +190,37 @@ def get_auth_response(self, auth_uri=None, timeout=None, state=None,
and https://openid.net/specs/openid-connect-core-1_0.html#AuthResponse
Returns None when the state was mismatched, or when timeout occurred.
"""
# Historically, the _get_auth_response() uses HTTPServer.handle_request(),
# because its handle-and-retry logic is conceptually as easy as a while loop.
# Also, handle_request() honors server.timeout setting, and CTRL+C simply works.
# All those are true when running on Linux.
#
# However, the behaviors on Windows turns out to be different.
# A socket server waiting for request would freeze the current thread.
# Neither timeout nor CTRL+C would work. End user would have to do CTRL+BREAK.
# https://stackoverflow.com/questions/1364173/stopping-python-using-ctrlc
#
# The solution would need to somehow put the http server into its own thread.
# This could be done by the pattern of ``http.server.test()`` which internally
# use ``ThreadingHTTPServer.serve_forever()`` (only available in Python 3.7).
# Or create our own thread to wrap the HTTPServer.handle_request() inside.
result = {} # A mutable object to be filled with thread's return value
t = threading.Thread(
target=self._get_auth_response, args=(result,), kwargs=kwargs)
t.daemon = True # So that it won't prevent the main thread from exiting
t.start()
begin = time.time()
while (time.time() - begin < timeout) if timeout else True:
time.sleep(1) # Short detection interval to make happy path responsive
if not t.is_alive(): # Then the thread has finished its job and exited
break
return result or None

def _get_auth_response(self, result, auth_uri=None, timeout=None, state=None,
welcome_template=None, success_template=None, error_template=None,
auth_uri_callback=None,
browser_name=None,
):
welcome_uri = "http://localhost:{p}".format(p=self.get_port())
abort_uri = "{loc}?error=abort".format(loc=welcome_uri)
logger.debug("Abort by visit %s", abort_uri)
Expand Down Expand Up @@ -238,7 +267,7 @@ def get_auth_response(self, auth_uri=None, timeout=None, state=None,
logger.debug("State mismatch. Ignoring this noise.")
else:
break
return self._server.auth_response
result.update(self._server.auth_response) # Return via writable result param

def close(self):
"""Either call this eventually; or use the entire class as context manager"""
Expand Down
2 changes: 1 addition & 1 deletion msal/oauth2cli/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def _obtain_token( # The verb "obtain" is influenced by OAUTH2 RFC 6749
_data["client_assertion"] = encoder(
self.client_assertion() # Do lazy on-the-fly computation
if callable(self.client_assertion) else self.client_assertion
) # The type is bytes, which is preferrable. See also:
) # The type is bytes, which is preferable. See also:
# https://github.com/psf/requests/issues/4503#issuecomment-455001070

_data.update(self.default_body) # It may contain authen parameters
Expand Down