Skip to content

Commit 16a9a34

Browse files
committed
Merge branch 'auth-code-receiver-and-ports' into dev
2 parents 6622313 + ef87c00 commit 16a9a34

File tree

2 files changed

+38
-1
lines changed

2 files changed

+38
-1
lines changed

oauth2cli/authcode.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
"""
88
import logging
99
import socket
10+
import sys
1011
from string import Template
1112
import threading
1213
import time
@@ -103,7 +104,17 @@ def log_message(self, format, *args):
103104
logger.debug(format, *args) # To override the default log-to-stderr behavior
104105

105106

106-
class _AuthCodeHttpServer(HTTPServer):
107+
class _AuthCodeHttpServer(HTTPServer, object):
108+
def __init__(self, server_address, *args, **kwargs):
109+
_, port = server_address
110+
if port and (sys.platform == "win32" or is_wsl()):
111+
# The default allow_reuse_address is True. It works fine on non-Windows.
112+
# On Windows, it undesirably allows multiple servers listening on same port,
113+
# yet the second server would not receive any incoming request.
114+
# So, we need to turn it off.
115+
self.allow_reuse_address = False
116+
super(_AuthCodeHttpServer, self).__init__(server_address, *args, **kwargs)
117+
107118
def handle_timeout(self):
108119
# It will be triggered when no request comes in self.timeout seconds.
109120
# See https://docs.python.org/3/library/socketserver.html#socketserver.BaseServer.handle_timeout

tests/test_authcode.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
import unittest
2+
import socket
3+
import sys
4+
5+
from oauth2cli.authcode import AuthCodeReceiver
6+
7+
8+
class TestAuthCodeReceiver(unittest.TestCase):
9+
def test_setup_at_a_given_port_and_teardown(self):
10+
port = 12345 # Assuming this port is available
11+
with AuthCodeReceiver(port=port) as receiver:
12+
self.assertEqual(port, receiver.get_port())
13+
14+
def test_setup_at_a_ephemeral_port_and_teardown(self):
15+
port = 0
16+
with AuthCodeReceiver(port=port) as receiver:
17+
self.assertNotEqual(port, receiver.get_port())
18+
19+
def test_no_two_concurrent_receivers_can_listen_on_same_port(self):
20+
port = 12345 # Assuming this port is available
21+
with AuthCodeReceiver(port=port) as receiver:
22+
expected_error = OSError if sys.version_info[0] > 2 else socket.error
23+
with self.assertRaises(expected_error):
24+
with AuthCodeReceiver(port=port) as receiver2:
25+
pass
26+

0 commit comments

Comments
 (0)