Skip to content

Commit c122390

Browse files
GH-91166: Implement zero copy writes for SelectorSocketTransport in asyncio (#31871)
Co-authored-by: Guido van Rossum <[email protected]>
1 parent 0f64206 commit c122390

File tree

3 files changed

+175
-29
lines changed

3 files changed

+175
-29
lines changed

Lib/asyncio/selector_events.py

Lines changed: 75 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
import collections
1010
import errno
1111
import functools
12+
import itertools
13+
import os
1214
import selectors
1315
import socket
1416
import warnings
@@ -28,6 +30,14 @@
2830
from . import trsock
2931
from .log import logger
3032

33+
# sendmsg() is only used when the socket type provides it AND the platform
# reports a usable upper bound on the number of buffers per call; that bound
# is later used to slice the pending write buffer.
_HAS_SENDMSG = hasattr(socket.socket, 'sendmsg')

if _HAS_SENDMSG:
    try:
        SC_IOV_MAX = os.sysconf('SC_IOV_MAX')
    except (AttributeError, OSError, ValueError):
        # os.sysconf() is unavailable, failed, or does not know the name.
        # Fallback to send
        _HAS_SENDMSG = False
    else:
        if SC_IOV_MAX <= 0:
            # No definite limit was reported (e.g. -1); such a value cannot
            # be used as a slice bound.
            # Fallback to send
            _HAS_SENDMSG = False
3141

3242
def _test_selector_event(selector, fd, event):
3343
# Test if the selector is monitoring 'event' events
@@ -757,8 +767,6 @@ class _SelectorTransport(transports._FlowControlMixin,
757767

758768
max_size = 256 * 1024 # Buffer size passed to recv().
759769

760-
_buffer_factory = bytearray # Constructs initial value for self._buffer.
761-
762770
# Attribute used in the destructor: it must be set even if the constructor
763771
# is not called (see _SelectorSslTransport which may start by raising an
764772
# exception)
@@ -783,7 +791,7 @@ def __init__(self, loop, sock, protocol, extra=None, server=None):
783791
self.set_protocol(protocol)
784792

785793
self._server = server
786-
self._buffer = self._buffer_factory()
794+
self._buffer = collections.deque()
787795
self._conn_lost = 0 # Set when call to connection_lost scheduled.
788796
self._closing = False # Set when close() called.
789797
if self._server is not None:
@@ -887,7 +895,7 @@ def _call_connection_lost(self, exc):
887895
self._server = None
888896

889897
def get_write_buffer_size(self):
    """Return the combined length of every chunk queued in the write buffer."""
    total = 0
    for chunk in self._buffer:
        total += len(chunk)
    return total
891899

892900
def _add_reader(self, fd, callback, *args):
893901
if self._closing:
@@ -909,7 +917,10 @@ def __init__(self, loop, sock, protocol, waiter=None,
909917
self._eof = False
910918
self._paused = False
911919
self._empty_waiter = None
912-
920+
if _HAS_SENDMSG:
921+
self._write_ready = self._write_sendmsg
922+
else:
923+
self._write_ready = self._write_send
913924
# Disable the Nagle algorithm -- small writes will be
914925
# sent without waiting for the TCP ACK. This generally
915926
# decreases the latency (in some cases significantly.)
@@ -1066,23 +1077,68 @@ def write(self, data):
10661077
self._fatal_error(exc, 'Fatal write error on socket transport')
10671078
return
10681079
else:
1069-
data = data[n:]
1080+
data = memoryview(data)[n:]
10701081
if not data:
10711082
return
10721083
# Not all was written; register write handler.
10731084
self._loop._add_writer(self._sock_fd, self._write_ready)
10741085

10751086
# Add it to the buffer.
1076-
self._buffer.extend(data)
1087+
self._buffer.append(data)
10771088
self._maybe_pause_protocol()
10781089

1079-
def _write_ready(self):
1090+
def _get_sendmsg_buffer(self):
    """Return an iterator over at most SC_IOV_MAX leading buffered chunks."""
    pending = self._buffer
    return itertools.islice(pending, SC_IOV_MAX)
1092+
1093+
def _write_sendmsg(self):
10801094
assert self._buffer, 'Data should not be empty'
1095+
if self._conn_lost:
1096+
return
1097+
try:
1098+
nbytes = self._sock.sendmsg(self._get_sendmsg_buffer())
1099+
self._adjust_leftover_buffer(nbytes)
1100+
except (BlockingIOError, InterruptedError):
1101+
pass
1102+
except (SystemExit, KeyboardInterrupt):
1103+
raise
1104+
except BaseException as exc:
1105+
self._loop._remove_writer(self._sock_fd)
1106+
self._buffer.clear()
1107+
self._fatal_error(exc, 'Fatal write error on socket transport')
1108+
if self._empty_waiter is not None:
1109+
self._empty_waiter.set_exception(exc)
1110+
else:
1111+
self._maybe_resume_protocol() # May append to buffer.
1112+
if not self._buffer:
1113+
self._loop._remove_writer(self._sock_fd)
1114+
if self._empty_waiter is not None:
1115+
self._empty_waiter.set_result(None)
1116+
if self._closing:
1117+
self._call_connection_lost(None)
1118+
elif self._eof:
1119+
self._sock.shutdown(socket.SHUT_WR)
10811120

1121+
def _adjust_leftover_buffer(self, nbytes: int) -> None:
1122+
buffer = self._buffer
1123+
while nbytes:
1124+
b = buffer.popleft()
1125+
b_len = len(b)
1126+
if b_len <= nbytes:
1127+
nbytes -= b_len
1128+
else:
1129+
buffer.appendleft(b[nbytes:])
1130+
break
1131+
1132+
def _write_send(self):
1133+
assert self._buffer, 'Data should not be empty'
10821134
if self._conn_lost:
10831135
return
10841136
try:
1085-
n = self._sock.send(self._buffer)
1137+
buffer = self._buffer.popleft()
1138+
n = self._sock.send(buffer)
1139+
if n != len(buffer):
1140+
# Not all data was written
1141+
self._buffer.appendleft(buffer[n:])
10861142
except (BlockingIOError, InterruptedError):
10871143
pass
10881144
except (SystemExit, KeyboardInterrupt):
@@ -1094,8 +1150,6 @@ def _write_ready(self):
10941150
if self._empty_waiter is not None:
10951151
self._empty_waiter.set_exception(exc)
10961152
else:
1097-
if n:
1098-
del self._buffer[:n]
10991153
self._maybe_resume_protocol() # May append to buffer.
11001154
if not self._buffer:
11011155
self._loop._remove_writer(self._sock_fd)
@@ -1113,6 +1167,16 @@ def write_eof(self):
11131167
if not self._buffer:
11141168
self._sock.shutdown(socket.SHUT_WR)
11151169

1170+
def writelines(self, list_of_data):
    """Queue every item of *list_of_data* and try to send it immediately.

    Each item is wrapped in a memoryview so it is buffered without being
    copied.  Zero-length items are skipped.  If the immediate write attempt
    leaves data in the buffer, the write handler is registered with the
    loop and the protocol is given a chance to be paused.

    Raises RuntimeError if write_eof() was already called or while a
    sendfile operation is in progress.
    """
    if self._eof:
        raise RuntimeError('Cannot call writelines() after write_eof()')
    if self._empty_waiter is not None:
        raise RuntimeError('unable to writelines; sendfile is in progress')
    if not list_of_data:
        return
    # Build the full list first so a bad item raises before anything is
    # queued.
    views = [memoryview(data) for data in list_of_data]
    # A zero-length chunk carries no data, and _adjust_leftover_buffer()
    # only pops chunks while written bytes remain to be accounted for, so
    # such a chunk could stay in the buffer forever.
    self._buffer.extend(view for view in views if view)
    if not self._buffer:
        # Nothing to send; the write handlers require a non-empty buffer.
        return
    self._write_ready()
    # If the entire buffer couldn't be written, register a write handler;
    # otherwise the leftover data would never be flushed.
    if self._buffer and not self._conn_lost:
        self._loop._add_writer(self._sock_fd, self._write_ready)
        self._maybe_pause_protocol()
1179+
11161180
def can_write_eof(self):
11171181
return True
11181182

Lib/test/test_asyncio/test_selector_events.py

Lines changed: 99 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,25 @@
11
"""Tests for selector_events.py"""
22

3-
import sys
3+
import collections
44
import selectors
55
import socket
6+
import sys
67
import unittest
8+
from asyncio import selector_events
79
from unittest import mock
10+
811
try:
912
import ssl
1013
except ImportError:
1114
ssl = None
1215

1316
import asyncio
14-
from asyncio.selector_events import BaseSelectorEventLoop
15-
from asyncio.selector_events import _SelectorTransport
16-
from asyncio.selector_events import _SelectorSocketTransport
17-
from asyncio.selector_events import _SelectorDatagramTransport
17+
from asyncio.selector_events import (BaseSelectorEventLoop,
18+
_SelectorDatagramTransport,
19+
_SelectorSocketTransport,
20+
_SelectorTransport)
1821
from test.test_asyncio import utils as test_utils
1922

20-
2123
MOCK_ANY = mock.ANY
2224

2325

@@ -37,7 +39,10 @@ def _close_self_pipe(self):
3739

3840

3941
def list_to_buffer(l=()):
    # Build the deque-of-memoryviews shape that the transport uses for its
    # write buffer, one memoryview per item of *l*.
    return collections.deque(memoryview(item) for item in l)
45+
4146

4247

4348
def close_transport(transport):
@@ -493,9 +498,13 @@ def setUp(self):
493498
self.sock = mock.Mock(socket.socket)
494499
self.sock_fd = self.sock.fileno.return_value = 7
495500

496-
def socket_transport(self, waiter=None, sendmsg=False):
    # Create a transport over the mocked socket and pin its write handler
    # explicitly, so each test picks the sendmsg() or send() code path
    # regardless of what the transport selected on its own.
    transport = _SelectorSocketTransport(
        self.loop, self.sock, self.protocol, waiter=waiter)
    transport._write_ready = (
        transport._write_sendmsg if sendmsg else transport._write_send)
    self.addCleanup(close_transport, transport)
    return transport
501510

@@ -664,14 +673,14 @@ def test_write_memoryview(self):
664673

665674
def test_write_no_data(self):
    # Writing an empty payload must neither call send() nor disturb data
    # that is already buffered.
    pending = memoryview(b'data')
    transport = self.socket_transport()
    transport._buffer.append(pending)

    transport.write(b'')

    self.assertFalse(self.sock.send.called)
    self.assertEqual(list_to_buffer([b'data']), transport._buffer)
671680

672681
def test_write_buffer(self):
673682
transport = self.socket_transport()
674-
transport._buffer.extend(b'data1')
683+
transport._buffer.append(b'data1')
675684
transport.write(b'data2')
676685
self.assertFalse(self.sock.send.called)
677686
self.assertEqual(list_to_buffer([b'data1', b'data2']),
@@ -729,6 +738,77 @@ def test_write_tryagain(self):
729738
self.loop.assert_writer(7, transport._write_ready)
730739
self.assertEqual(list_to_buffer([b'data']), transport._buffer)
731740

741+
def test_write_sendmsg_no_data(self):
    # Writing an empty payload must neither call sendmsg() nor disturb
    # data that is already buffered.
    self.sock.sendmsg = mock.Mock(return_value=0)
    pending = memoryview(b'data')
    transport = self.socket_transport(sendmsg=True)
    transport._buffer.append(pending)

    transport.write(b'')

    self.assertFalse(self.sock.sendmsg.called)
    self.assertEqual(list_to_buffer([b'data']), transport._buffer)
749+
750+
@unittest.skipUnless(selector_events._HAS_SENDMSG, 'no sendmsg')
def test_write_sendmsg_full(self):
    # sendmsg() reports that everything was written, so the writer must
    # be unregistered afterwards.
    payload = memoryview(b'data')
    self.sock.sendmsg = mock.Mock(return_value=len(payload))

    transport = self.socket_transport(sendmsg=True)
    transport._buffer.append(payload)
    self.loop._add_writer(7, transport._write_ready)

    transport._write_ready()

    self.assertTrue(self.sock.sendmsg.called)
    self.assertFalse(self.loop.writers)
762+
763+
@unittest.skipUnless(selector_events._HAS_SENDMSG, 'no sendmsg')
def test_write_sendmsg_partial(self):
    # sendmsg() writes only 2 of 4 bytes: the writer stays registered
    # and the unsent tail stays buffered.
    payload = memoryview(b'data')
    self.sock.sendmsg = mock.Mock(return_value=2)

    transport = self.socket_transport(sendmsg=True)
    transport._buffer.append(payload)
    self.loop._add_writer(7, transport._write_ready)

    transport._write_ready()

    self.assertTrue(self.sock.sendmsg.called)
    self.assertTrue(self.loop.writers)
    self.assertEqual(list_to_buffer([b'ta']), transport._buffer)
778+
779+
@unittest.skipUnless(selector_events._HAS_SENDMSG, 'no sendmsg')
def test_write_sendmsg_half_buffer(self):
    # With two chunks queued and only 2 bytes written, the first chunk
    # is trimmed and the second is left untouched.
    chunks = [memoryview(b'data1'), memoryview(b'data2')]
    self.sock.sendmsg = mock.Mock(return_value=2)

    transport = self.socket_transport(sendmsg=True)
    transport._buffer.extend(chunks)
    self.loop._add_writer(7, transport._write_ready)

    transport._write_ready()

    self.assertTrue(self.sock.sendmsg.called)
    self.assertTrue(self.loop.writers)
    self.assertEqual(list_to_buffer([b'ta1', b'data2']), transport._buffer)
793+
794+
@unittest.skipUnless(selector_events._HAS_SENDMSG, 'no sendmsg')
def test_write_sendmsg_OSError(self):
    # A failing sendmsg() must drop the writer, clear the buffer and
    # report the error through _fatal_error().
    data = memoryview(b'data')
    self.sock.sendmsg = mock.Mock()
    err = self.sock.sendmsg.side_effect = OSError()

    transport = self.socket_transport(sendmsg=True)
    transport._fatal_error = mock.Mock()
    # append(), not extend(): extending a deque with a memoryview would
    # queue its individual bytes as ints instead of one chunk.
    transport._buffer.append(data)
    # Calls _fatal_error and clears the buffer
    transport._write_ready()
    self.assertTrue(self.sock.sendmsg.called)
    self.assertFalse(self.loop.writers)
    self.assertEqual(list_to_buffer([]), transport._buffer)
    transport._fatal_error.assert_called_with(
        err,
        'Fatal write error on socket transport')
811+
732812
@mock.patch('asyncio.selector_events.logger')
733813
def test_write_exception(self, m_log):
734814
err = self.sock.send.side_effect = OSError()
@@ -768,19 +848,19 @@ def test_write_ready(self):
768848
self.sock.send.return_value = len(data)
769849

770850
transport = self.socket_transport()
771-
transport._buffer.extend(data)
851+
transport._buffer.append(data)
772852
self.loop._add_writer(7, transport._write_ready)
773853
transport._write_ready()
774854
self.assertTrue(self.sock.send.called)
775855
self.assertFalse(self.loop.writers)
776856

777857
def test_write_ready_closing(self):
778-
data = b'data'
858+
data = memoryview(b'data')
779859
self.sock.send.return_value = len(data)
780860

781861
transport = self.socket_transport()
782862
transport._closing = True
783-
transport._buffer.extend(data)
863+
transport._buffer.append(data)
784864
self.loop._add_writer(7, transport._write_ready)
785865
transport._write_ready()
786866
self.assertTrue(self.sock.send.called)
@@ -795,11 +875,11 @@ def test_write_ready_no_data(self):
795875
self.assertRaises(AssertionError, transport._write_ready)
796876

797877
def test_write_ready_partial(self):
798-
data = b'data'
878+
data = memoryview(b'data')
799879
self.sock.send.return_value = 2
800880

801881
transport = self.socket_transport()
802-
transport._buffer.extend(data)
882+
transport._buffer.append(data)
803883
self.loop._add_writer(7, transport._write_ready)
804884
transport._write_ready()
805885
self.loop.assert_writer(7, transport._write_ready)
@@ -810,7 +890,7 @@ def test_write_ready_partial_none(self):
810890
self.sock.send.return_value = 0
811891

812892
transport = self.socket_transport()
813-
transport._buffer.extend(data)
893+
transport._buffer.append(data)
814894
self.loop._add_writer(7, transport._write_ready)
815895
transport._write_ready()
816896
self.loop.assert_writer(7, transport._write_ready)
@@ -820,12 +900,13 @@ def test_write_ready_tryagain(self):
820900
self.sock.send.side_effect = BlockingIOError
821901

822902
transport = self.socket_transport()
823-
transport._buffer = list_to_buffer([b'data1', b'data2'])
903+
buffer = list_to_buffer([b'data1', b'data2'])
904+
transport._buffer = buffer
824905
self.loop._add_writer(7, transport._write_ready)
825906
transport._write_ready()
826907

827908
self.loop.assert_writer(7, transport._write_ready)
828-
self.assertEqual(list_to_buffer([b'data1data2']), transport._buffer)
909+
self.assertEqual(buffer, transport._buffer)
829910

830911
def test_write_ready_exception(self):
831912
err = self.sock.send.side_effect = OSError()
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
:mod:`asyncio` is optimized to avoid excessive copying when writing to a socket and uses :meth:`~socket.socket.sendmsg` if the platform supports it. Patch by Kumar Aditya.

0 commit comments

Comments
 (0)