Skip to content

Commit 93aa57a

Browse files
bpo-36801: Fix waiting in StreamWriter.drain for closing SSL transport (GH-13098)
https://bugs.python.org/issue36801 (cherry picked from commit 1cc0ee7) Co-authored-by: Andrew Svetlov <[email protected]>
1 parent 5edd82c commit 93aa57a

File tree

4 files changed

+46
-8
lines changed

4 files changed

+46
-8
lines changed

Lib/asyncio/streams.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,9 @@ async def _drain_helper(self):
208208
self._drain_waiter = waiter
209209
await waiter
210210

211+
def _get_close_waiter(self, stream):
212+
raise NotImplementedError
213+
211214

212215
class StreamReaderProtocol(FlowControlMixin, protocols.Protocol):
213216
"""Helper class to adapt between Protocol and StreamReader.
@@ -265,6 +268,9 @@ def eof_received(self):
265268
return False
266269
return True
267270

271+
def _get_close_waiter(self, stream):
272+
return self._closed
273+
268274
def __del__(self):
269275
# Prevent reports about unhandled exceptions.
270276
# Better than self._closed._log_traceback = False hack
@@ -320,7 +326,7 @@ def is_closing(self):
320326
return self._transport.is_closing()
321327

322328
async def wait_closed(self):
323-
await self._protocol._closed
329+
await self._protocol._get_close_waiter(self)
324330

325331
def get_extra_info(self, name, default=None):
326332
return self._transport.get_extra_info(name, default)
@@ -338,13 +344,12 @@ async def drain(self):
338344
if exc is not None:
339345
raise exc
340346
if self._transport.is_closing():
341-
# Yield to the event loop so connection_lost() may be
342-
# called. Without this, _drain_helper() would return
343-
# immediately, and code that calls
344-
# write(...); await drain()
345-
# in a loop would never call connection_lost(), so it
346-
# would not see an error when the socket is closed.
347-
await sleep(0, loop=self._loop)
347+
# Wait for protocol.connection_lost() call
348+
# Raise connection closing error if any,
349+
# ConnectionResetError otherwise
350+
fut = self._protocol._get_close_waiter(self)
351+
await fut
352+
raise ConnectionResetError('Connection lost')
348353
await self._protocol._drain_helper()
349354

350355

Lib/asyncio/subprocess.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ def __init__(self, limit, loop):
2525
self._transport = None
2626
self._process_exited = False
2727
self._pipe_fds = []
28+
self._stdin_closed = self._loop.create_future()
2829

2930
def __repr__(self):
3031
info = [self.__class__.__name__]
@@ -76,6 +77,10 @@ def pipe_connection_lost(self, fd, exc):
7677
if pipe is not None:
7778
pipe.close()
7879
self.connection_lost(exc)
80+
if exc is None:
81+
self._stdin_closed.set_result(None)
82+
else:
83+
self._stdin_closed.set_exception(exc)
7984
return
8085
if fd == 1:
8186
reader = self.stdout
@@ -102,6 +107,10 @@ def _maybe_close_transport(self):
102107
self._transport.close()
103108
self._transport = None
104109

110+
def _get_close_waiter(self, stream):
111+
if stream is self.stdin:
112+
return self._stdin_closed
113+
105114

106115
class Process:
107116
def __init__(self, transport, protocol, loop):

Lib/test/test_asyncio/test_streams.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,29 @@ def test_open_unix_connection_no_loop_ssl(self):
9999

100100
self._basetest_open_connection_no_loop_ssl(conn_fut)
101101

102+
@unittest.skipIf(ssl is None, 'No ssl module')
103+
def test_drain_on_closed_writer_ssl(self):
104+
105+
async def inner(httpd):
106+
reader, writer = await asyncio.open_connection(
107+
*httpd.address,
108+
ssl=test_utils.dummy_ssl_context())
109+
110+
messages = []
111+
self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
112+
writer.write(b'GET / HTTP/1.0\r\n\r\n')
113+
data = await reader.read()
114+
self.assertTrue(data.endswith(b'\r\n\r\nTest message'))
115+
116+
writer.close()
117+
with self.assertRaises(ConnectionResetError):
118+
await writer.drain()
119+
120+
self.assertEqual(messages, [])
121+
122+
with test_utils.run_test_server(use_ssl=True) as httpd:
123+
self.loop.run_until_complete(inner(httpd))
124+
102125
def _basetest_open_connection_error(self, open_connection_fut):
103126
reader, writer = self.loop.run_until_complete(open_connection_fut)
104127
writer._protocol.connection_lost(ZeroDivisionError())
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Properly handle SSL connection closing in asyncio StreamWriter.drain() call.

0 commit comments

Comments
 (0)