Skip to content

Commit ee75018

Browse files
committed
bpo-30064: Fix asyncio loop.sock_* race condition issue
1 parent c73914a commit ee75018

File tree

3 files changed

+151
-14
lines changed

3 files changed

+151
-14
lines changed

Lib/asyncio/selector_events.py

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,7 @@ def _add_reader(self, fd, callback, *args):
266266
(handle, writer))
267267
if reader is not None:
268268
reader.cancel()
269+
return handle
269270

270271
def _remove_reader(self, fd):
271272
if self.is_closed():
@@ -302,6 +303,7 @@ def _add_writer(self, fd, callback, *args):
302303
(reader, handle))
303304
if writer is not None:
304305
writer.cancel()
306+
return handle
305307

306308
def _remove_writer(self, fd):
307309
"""Remove a writer callback."""
@@ -362,13 +364,14 @@ async def sock_recv(self, sock, n):
362364
pass
363365
fut = self.create_future()
364366
fd = sock.fileno()
365-
self.add_reader(fd, self._sock_recv, fut, sock, n)
367+
handle = self.add_reader(fd, self._sock_recv, fut, sock, n)
366368
fut.add_done_callback(
367-
functools.partial(self._sock_read_done, fd))
369+
functools.partial(self._sock_read_done, fd, handle=handle))
368370
return await fut
369371

370-
def _sock_read_done(self, fd, fut):
371-
self.remove_reader(fd)
372+
def _sock_read_done(self, fd, fut, handle=None):
373+
if handle is None or not handle.cancelled():
374+
self.remove_reader(fd)
372375

373376
def _sock_recv(self, fut, sock, n):
374377
# _sock_recv() can add itself as an I/O callback if the operation can't
@@ -401,9 +404,9 @@ async def sock_recv_into(self, sock, buf):
401404
pass
402405
fut = self.create_future()
403406
fd = sock.fileno()
404-
self.add_reader(fd, self._sock_recv_into, fut, sock, buf)
407+
handle = self.add_reader(fd, self._sock_recv_into, fut, sock, buf)
405408
fut.add_done_callback(
406-
functools.partial(self._sock_read_done, fd))
409+
functools.partial(self._sock_read_done, fd, handle=handle))
407410
return await fut
408411

409412
def _sock_recv_into(self, fut, sock, buf):
@@ -446,11 +449,11 @@ async def sock_sendall(self, sock, data):
446449

447450
fut = self.create_future()
448451
fd = sock.fileno()
449-
fut.add_done_callback(
450-
functools.partial(self._sock_write_done, fd))
451452
# use a trick with a list in closure to store a mutable state
452-
self.add_writer(fd, self._sock_sendall, fut, sock,
453-
memoryview(data), [n])
453+
handle = self.add_writer(fd, self._sock_sendall, fut, sock,
454+
memoryview(data), [n])
455+
fut.add_done_callback(
456+
functools.partial(self._sock_write_done, fd, handle=handle))
454457
return await fut
455458

456459
def _sock_sendall(self, fut, sock, view, pos):
@@ -502,18 +505,20 @@ def _sock_connect(self, fut, sock, address):
502505
# connection runs in background. We have to wait until the socket
503506
# becomes writable to be notified when the connection succeed or
504507
# fails.
508+
handle = self.add_writer(
509+
fd, self._sock_connect_cb, fut, sock, address)
505510
fut.add_done_callback(
506-
functools.partial(self._sock_write_done, fd))
507-
self.add_writer(fd, self._sock_connect_cb, fut, sock, address)
511+
functools.partial(self._sock_write_done, fd, handle=handle))
508512
except (SystemExit, KeyboardInterrupt):
509513
raise
510514
except BaseException as exc:
511515
fut.set_exception(exc)
512516
else:
513517
fut.set_result(None)
514518

515-
def _sock_write_done(self, fd, fut):
516-
self.remove_writer(fd)
519+
def _sock_write_done(self, fd, fut, handle=None):
520+
if handle is None or not handle.cancelled():
521+
self.remove_writer(fd)
517522

518523
def _sock_connect_cb(self, fut, sock, address):
519524
if fut.done():

Lib/test/test_asyncio/test_sock_lowlevel.py

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import socket
2+
import time
23
import asyncio
34
import sys
45
from asyncio import proactor_events
@@ -122,6 +123,136 @@ def test_sock_client_ops(self):
122123
sock = socket.socket()
123124
self._basetest_sock_recv_into(httpd, sock)
124125

126+
async def _basetest_sock_recv_racing(self, httpd, sock):
127+
sock.setblocking(False)
128+
await self.loop.sock_connect(sock, httpd.address)
129+
130+
task = asyncio.create_task(self.loop.sock_recv(sock, 1024))
131+
await asyncio.sleep(0)
132+
task.cancel()
133+
134+
asyncio.create_task(
135+
self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n'))
136+
data = await self.loop.sock_recv(sock, 1024)
137+
# consume data
138+
await self.loop.sock_recv(sock, 1024)
139+
140+
self.assertTrue(data.startswith(b'HTTP/1.0 200 OK'))
141+
142+
async def _basetest_sock_recv_into_racing(self, httpd, sock):
143+
sock.setblocking(False)
144+
await self.loop.sock_connect(sock, httpd.address)
145+
146+
data = bytearray(1024)
147+
with memoryview(data) as buf:
148+
task = asyncio.create_task(
149+
self.loop.sock_recv_into(sock, buf[:1024]))
150+
await asyncio.sleep(0)
151+
task.cancel()
152+
153+
task = asyncio.create_task(
154+
self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n'))
155+
nbytes = await self.loop.sock_recv_into(sock, buf[:1024])
156+
# consume data
157+
await self.loop.sock_recv_into(sock, buf[nbytes:])
158+
self.assertTrue(data.startswith(b'HTTP/1.0 200 OK'))
159+
160+
await task
161+
162+
async def _basetest_sock_send_racing(self, listener, sock):
163+
listener.bind(('127.0.0.1', 0))
164+
listener.listen(1)
165+
166+
# make connection
167+
sock.setblocking(False)
168+
task = asyncio.create_task(
169+
self.loop.sock_connect(sock, listener.getsockname()))
170+
await asyncio.sleep(0)
171+
server = listener.accept()[0]
172+
server.setblocking(False)
173+
174+
with server:
175+
await task
176+
177+
# fill the buffer
178+
with self.assertRaises(BlockingIOError):
179+
while True:
180+
sock.send(b' ' * 5)
181+
182+
# cancel a blocked sock_sendall
183+
task = asyncio.create_task(
184+
self.loop.sock_sendall(sock, b'hello'))
185+
await asyncio.sleep(0)
186+
task.cancel()
187+
188+
# clear the buffer
189+
async def recv_until():
190+
data = b''
191+
while not data:
192+
data = await self.loop.sock_recv(server, 1024)
193+
data = data.strip()
194+
return data
195+
task = asyncio.create_task(recv_until())
196+
197+
# immediately register another sock_sendall
198+
await self.loop.sock_sendall(sock, b'world')
199+
data = await task
200+
# ProactorEventLoop could deliver hello
201+
self.assertTrue(data.endswith(b'world'))
202+
203+
async def _basetest_sock_connect_racing(self, listener, sock):
204+
listener.bind(('127.0.0.1', 0))
205+
addr = listener.getsockname()
206+
sock.setblocking(False)
207+
208+
task = asyncio.create_task(self.loop.sock_connect(sock, addr))
209+
await asyncio.sleep(0)
210+
task.cancel()
211+
212+
listener.listen(1)
213+
i = 0
214+
while True:
215+
try:
216+
await self.loop.sock_connect(sock, addr)
217+
break
218+
except ConnectionRefusedError: # on Linux we need another retry
219+
await self.loop.sock_connect(sock, addr)
220+
break
221+
except OSError as e: # on Windows we need more retries
222+
# A connect request was made on an already connected socket
223+
if getattr(e, 'winerror', 0) == 10056:
224+
break
225+
226+
# https://stackoverflow.com/a/54437602/3316267
227+
if getattr(e, 'winerror', 0) != 10022:
228+
raise
229+
i += 1
230+
if i >= 128:
231+
raise # too many retries
232+
# avoid touching event loop to maintain race condition
233+
time.sleep(0.01)
234+
235+
def test_sock_client_racing(self):
236+
with test_utils.run_test_server() as httpd:
237+
sock = socket.socket()
238+
with sock:
239+
self.loop.run_until_complete(asyncio.wait_for(
240+
self._basetest_sock_recv_racing(httpd, sock), 10))
241+
sock = socket.socket()
242+
with sock:
243+
self.loop.run_until_complete(asyncio.wait_for(
244+
self._basetest_sock_recv_into_racing(httpd, sock), 10))
245+
listener = socket.socket()
246+
sock = socket.socket()
247+
with listener, sock:
248+
self.loop.run_until_complete(asyncio.wait_for(
249+
self._basetest_sock_send_racing(listener, sock), 10))
250+
listener = socket.socket()
251+
sock = socket.socket()
252+
with listener, sock:
253+
self.loop.run_until_complete(asyncio.wait_for(
254+
self._basetest_sock_connect_racing(listener, sock), 10))
255+
125256
async def _basetest_huge_content(self, address):
126257
sock = socket.socket()
127258
sock.setblocking(False)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fix asyncio ``loop.sock_*`` race condition issue

0 commit comments

Comments
 (0)