Skip to content

Commit bd43e1d

Browse files
committed
bpo-33654: Support BufferedProtocol in set_protocol() and start_tls()
1 parent 6e33f81 commit bd43e1d

File tree

10 files changed

+302
-33
lines changed

10 files changed

+302
-33
lines changed

Doc/library/asyncio-protocol.rst

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -469,10 +469,16 @@ the buffer only once at creation time.
469469
The following callbacks are called on :class:`BufferedProtocol`
470470
instances:
471471

472-
.. method:: BufferedProtocol.get_buffer()
472+
.. method:: BufferedProtocol.get_buffer(sizehint)
473473

474-
Called to allocate a new receive buffer. Must return an object
475-
that implements the :ref:`buffer protocol <bufferobjects>`.
474+
Called to allocate a new receive buffer.
475+
476+
*sizehint* is a recommended minimal size for the returned
477+
buffer. When set to -1, the buffer size can be arbitrary.
478+
479+
Must return an object that implements the
480+
:ref:`buffer protocol <bufferobjects>`.
481+
It is an error to return a zero-sized buffer.
476482

477483
.. method:: BufferedProtocol.buffer_updated(nbytes)
478484

Lib/asyncio/proactor_events.py

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def __init__(self, loop, sock, protocol, waiter=None,
3030
super().__init__(extra, loop)
3131
self._set_extra(sock)
3232
self._sock = sock
33-
self._protocol = protocol
33+
self.set_protocol(protocol)
3434
self._server = server
3535
self._buffer = None # None or bytearray.
3636
self._read_fut = None
@@ -159,16 +159,27 @@ class _ProactorReadPipeTransport(_ProactorBasePipeTransport,
159159

160160
def __init__(self, loop, sock, protocol, waiter=None,
161161
extra=None, server=None):
162+
self._loop_reading_cb = None
163+
self._paused = True
162164
super().__init__(loop, sock, protocol, waiter, extra, server)
163-
self._paused = False
165+
164166
self._reschedule_on_resume = False
167+
self._loop.call_soon(self._loop_reading)
168+
self._paused = False
165169

166-
if protocols._is_buffered_protocol(protocol):
167-
self._loop_reading = self._loop_reading__get_buffer
170+
def set_protocol(self, protocol):
171+
if isinstance(protocol, protocols.BufferedProtocol):
172+
self._loop_reading_cb = self._loop_reading__get_buffer
168173
else:
169-
self._loop_reading = self._loop_reading__data_received
174+
self._loop_reading_cb = self._loop_reading__data_received
170175

171-
self._loop.call_soon(self._loop_reading)
176+
super().set_protocol(protocol)
177+
178+
if not self._paused:
179+
# reset reading callback / buffers / self._read_fut
180+
self.pause_reading()
181+
self._reschedule_on_resume = True
182+
self.resume_reading()
172183

173184
def is_reading(self):
174185
return not self._paused and not self._closing
@@ -210,7 +221,10 @@ def _loop_reading__on_eof(self):
210221
if not keep_open:
211222
self.close()
212223

213-
def _loop_reading__data_received(self, fut=None):
224+
def _loop_reading(self, fut=None):
225+
self._loop_reading_cb(fut)
226+
227+
def _loop_reading__data_received(self, fut):
214228
if self._paused:
215229
self._reschedule_on_resume = True
216230
return
@@ -260,7 +274,7 @@ def _loop_reading__data_received(self, fut=None):
260274
elif data == b'':
261275
self._loop_reading__on_eof()
262276

263-
def _loop_reading__get_buffer(self, fut=None):
277+
def _loop_reading__get_buffer(self, fut):
264278
if self._paused:
265279
self._reschedule_on_resume = True
266280
return
@@ -310,7 +324,9 @@ def _loop_reading__get_buffer(self, fut=None):
310324
return
311325

312326
try:
313-
buf = self._protocol.get_buffer()
327+
buf = self._protocol.get_buffer(-1)
328+
if not len(buf):
329+
raise RuntimeError('get_buffer() returned an empty buffer')
314330
except Exception as exc:
315331
self._fatal_error(
316332
exc, 'Fatal error: protocol.get_buffer() call failed.')

Lib/asyncio/protocols.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -130,11 +130,15 @@ class BufferedProtocol(BaseProtocol):
130130
* CL: connection_lost()
131131
"""
132132

133-
def get_buffer(self):
133+
def get_buffer(self, sizehint):
134134
"""Called to allocate a new receive buffer.
135135
136+
*sizehint* is a recommended minimal size for the returned
137+
buffer. When set to -1, the buffer size can be arbitrary.
138+
136139
Must return an object that implements the
137140
:ref:`buffer protocol <bufferobjects>`.
141+
It is an error to return a zero-sized buffer.
138142
"""
139143

140144
def buffer_updated(self, nbytes):
@@ -185,7 +189,3 @@ def pipe_connection_lost(self, fd, exc):
185189

186190
def process_exited(self):
187191
"""Called when subprocess has exited."""
188-
189-
190-
def _is_buffered_protocol(proto):
191-
return hasattr(proto, 'get_buffer') and not hasattr(proto, 'data_received')

Lib/asyncio/selector_events.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -597,8 +597,10 @@ def __init__(self, loop, sock, protocol, extra=None, server=None):
597597
self._extra['peername'] = None
598598
self._sock = sock
599599
self._sock_fd = sock.fileno()
600-
self._protocol = protocol
601-
self._protocol_connected = True
600+
601+
self._protocol_connected = False
602+
self.set_protocol(protocol)
603+
602604
self._server = server
603605
self._buffer = self._buffer_factory()
604606
self._conn_lost = 0 # Set when call to connection_lost scheduled.
@@ -640,6 +642,7 @@ def abort(self):
640642

641643
def set_protocol(self, protocol):
642644
self._protocol = protocol
645+
self._protocol_connected = True
643646

644647
def get_protocol(self):
645648
return self._protocol
@@ -721,11 +724,7 @@ class _SelectorSocketTransport(_SelectorTransport):
721724
def __init__(self, loop, sock, protocol, waiter=None,
722725
extra=None, server=None):
723726

724-
if protocols._is_buffered_protocol(protocol):
725-
self._read_ready = self._read_ready__get_buffer
726-
else:
727-
self._read_ready = self._read_ready__data_received
728-
727+
self._read_ready_cb = None
729728
super().__init__(loop, sock, protocol, extra, server)
730729
self._eof = False
731730
self._paused = False
@@ -745,6 +744,14 @@ def __init__(self, loop, sock, protocol, waiter=None,
745744
self._loop.call_soon(futures._set_result_unless_cancelled,
746745
waiter, None)
747746

747+
def set_protocol(self, protocol):
748+
if isinstance(protocol, protocols.BufferedProtocol):
749+
self._read_ready_cb = self._read_ready__get_buffer
750+
else:
751+
self._read_ready_cb = self._read_ready__data_received
752+
753+
super().set_protocol(protocol)
754+
748755
def is_reading(self):
749756
return not self._paused and not self._closing
750757

@@ -764,12 +771,17 @@ def resume_reading(self):
764771
if self._loop.get_debug():
765772
logger.debug("%r resumes reading", self)
766773

774+
def _read_ready(self):
775+
self._read_ready_cb()
776+
767777
def _read_ready__get_buffer(self):
768778
if self._conn_lost:
769779
return
770780

771781
try:
772-
buf = self._protocol.get_buffer()
782+
buf = self._protocol.get_buffer(-1)
783+
if not len(buf):
784+
raise RuntimeError('get_buffer() returned an empty buffer')
773785
except Exception as exc:
774786
self._fatal_error(
775787
exc, 'Fatal error: protocol.get_buffer() call failed.')

Lib/asyncio/sslproto.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -441,6 +441,8 @@ def __init__(self, loop, app_protocol, sslcontext, waiter,
441441
self._waiter = waiter
442442
self._loop = loop
443443
self._app_protocol = app_protocol
444+
self._app_protocol_is_buffer = \
445+
isinstance(app_protocol, protocols.BufferedProtocol)
444446
self._app_transport = _SSLProtocolTransport(self._loop, self)
445447
# _SSLPipe instance (None until the connection is made)
446448
self._sslpipe = None
@@ -522,7 +524,16 @@ def data_received(self, data):
522524

523525
for chunk in appdata:
524526
if chunk:
525-
self._app_protocol.data_received(chunk)
527+
try:
528+
if self._app_protocol_is_buffer:
529+
_feed_data_to_bufferred_proto(
530+
self._app_protocol, chunk)
531+
else:
532+
self._app_protocol.data_received(chunk)
533+
except Exception as ex:
534+
self._fatal_error(
535+
ex, 'application protocol failed to receive SSL data')
536+
return
526537
else:
527538
self._start_shutdown()
528539
break
@@ -709,3 +720,22 @@ def _abort(self):
709720
self._transport.abort()
710721
finally:
711722
self._finalize()
723+
724+
725+
def _feed_data_to_bufferred_proto(proto, data):
726+
data_len = len(data)
727+
while data_len:
728+
buf = proto.get_buffer(data_len)
729+
buf_len = len(buf)
730+
if not buf_len:
731+
raise RuntimeError('get_buffer() returned an empty buffer')
732+
733+
if buf_len >= data_len:
734+
buf[:data_len] = data
735+
proto.buffer_updated(data_len)
736+
return
737+
else:
738+
buf[:buf_len] = data[:buf_len]
739+
proto.buffer_updated(buf_len)
740+
data = data[buf_len:]
741+
data_len = len(data)

Lib/test/test_asyncio/test_buffered_proto.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ def __init__(self, cb, con_lost_fut):
99
self.cb = cb
1010
self.con_lost_fut = con_lost_fut
1111

12-
def get_buffer(self):
12+
def get_buffer(self, sizehint):
1313
self.buffer = bytearray(100)
1414
return self.buffer
1515

Lib/test/test_asyncio/test_proactor_events.py

Lines changed: 58 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -465,8 +465,8 @@ def setUp(self):
465465
self.loop._proactor = self.proactor
466466

467467
self.protocol = test_utils.make_test_protocol(asyncio.BufferedProtocol)
468-
self.buf = mock.Mock()
469-
self.protocol.get_buffer.side_effect = lambda: self.buf
468+
self.buf = bytearray(1)
469+
self.protocol.get_buffer.side_effect = lambda hint: self.buf
470470

471471
self.sock = mock.Mock(socket.socket)
472472

@@ -505,6 +505,62 @@ def test_get_buffer_error(self):
505505
self.assertTrue(self.protocol.get_buffer.called)
506506
self.assertFalse(self.protocol.buffer_updated.called)
507507

508+
def test_get_buffer_zerosized(self):
509+
transport = self.socket_transport()
510+
transport._fatal_error = mock.Mock()
511+
512+
self.loop.call_exception_handler = mock.Mock()
513+
self.protocol.get_buffer.side_effect = lambda hint: bytearray(0)
514+
515+
transport._loop_reading()
516+
517+
self.assertTrue(transport._fatal_error.called)
518+
self.assertTrue(self.protocol.get_buffer.called)
519+
self.assertFalse(self.protocol.buffer_updated.called)
520+
521+
def test_proto_type_switch(self):
522+
self.protocol = test_utils.make_test_protocol(asyncio.Protocol)
523+
tr = self.socket_transport()
524+
525+
res = asyncio.Future(loop=self.loop)
526+
res.set_result(b'data')
527+
528+
tr = self.socket_transport()
529+
tr._read_fut = res
530+
tr._loop_reading(res)
531+
self.loop._proactor.recv.assert_called_with(self.sock, 32768)
532+
self.protocol.data_received.assert_called_with(b'data')
533+
534+
# switch protocol to a BufferedProtocol
535+
536+
buf_proto = test_utils.make_test_protocol(asyncio.BufferedProtocol)
537+
buf = bytearray(4)
538+
buf_proto.get_buffer.side_effect = lambda hint: buf
539+
540+
tr.set_protocol(buf_proto)
541+
res = asyncio.Future(loop=self.loop)
542+
res.set_result(4)
543+
544+
tr._read_fut = res
545+
tr._loop_reading(res)
546+
self.loop._proactor.recv_into.assert_called_with(self.sock, buf)
547+
buf_proto.buffer_updated.assert_called_with(4)
548+
549+
def test_proto_buf_switch(self):
550+
tr = self.socket_transport()
551+
test_utils.run_briefly(self.loop)
552+
self.protocol.get_buffer.assert_called_with(-1)
553+
554+
# switch protocol to *another* BufferedProtocol
555+
556+
buf_proto = test_utils.make_test_protocol(asyncio.BufferedProtocol)
557+
buf = bytearray(4)
558+
buf_proto.get_buffer.side_effect = lambda hint: buf
559+
tr.set_protocol(buf_proto)
560+
self.assertFalse(buf_proto.get_buffer.called)
561+
test_utils.run_briefly(self.loop)
562+
buf_proto.get_buffer.assert_called_with(-1)
563+
508564
def test_buffer_updated_error(self):
509565
transport = self.socket_transport()
510566
transport._fatal_error = mock.Mock()

Lib/test/test_asyncio/test_selector_events.py

Lines changed: 39 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1279,8 +1279,8 @@ def setUp(self):
12791279
self.loop = self.new_test_loop()
12801280

12811281
self.protocol = test_utils.make_test_protocol(asyncio.BufferedProtocol)
1282-
self.buf = mock.Mock()
1283-
self.protocol.get_buffer.side_effect = lambda: self.buf
1282+
self.buf = bytearray(1)
1283+
self.protocol.get_buffer.side_effect = lambda hint: self.buf
12841284

12851285
self.sock = mock.Mock(socket.socket)
12861286
self.sock_fd = self.sock.fileno.return_value = 7
@@ -1313,6 +1313,42 @@ def test_get_buffer_error(self):
13131313
self.assertTrue(self.protocol.get_buffer.called)
13141314
self.assertFalse(self.protocol.buffer_updated.called)
13151315

1316+
def test_get_buffer_zerosized(self):
1317+
transport = self.socket_transport()
1318+
transport._fatal_error = mock.Mock()
1319+
1320+
self.loop.call_exception_handler = mock.Mock()
1321+
self.protocol.get_buffer.side_effect = lambda hint: bytearray(0)
1322+
1323+
transport._read_ready()
1324+
1325+
self.assertTrue(transport._fatal_error.called)
1326+
self.assertTrue(self.protocol.get_buffer.called)
1327+
self.assertFalse(self.protocol.buffer_updated.called)
1328+
1329+
def test_proto_type_switch(self):
1330+
self.protocol = test_utils.make_test_protocol(asyncio.Protocol)
1331+
transport = self.socket_transport()
1332+
1333+
self.sock.recv.return_value = b'data'
1334+
transport._read_ready()
1335+
1336+
self.protocol.data_received.assert_called_with(b'data')
1337+
1338+
# switch protocol to a BufferedProtocol
1339+
1340+
buf_proto = test_utils.make_test_protocol(asyncio.BufferedProtocol)
1341+
buf = bytearray(4)
1342+
buf_proto.get_buffer.side_effect = lambda hint: buf
1343+
1344+
transport.set_protocol(buf_proto)
1345+
1346+
self.sock.recv_into.return_value = 10
1347+
transport._read_ready()
1348+
1349+
buf_proto.get_buffer.assert_called_with(-1)
1350+
buf_proto.buffer_updated.assert_called_with(10)
1351+
13161352
def test_buffer_updated_error(self):
13171353
transport = self.socket_transport()
13181354
transport._fatal_error = mock.Mock()
@@ -1348,7 +1384,7 @@ def test_read_ready(self):
13481384
self.sock.recv_into.return_value = 10
13491385
transport._read_ready()
13501386

1351-
self.protocol.get_buffer.assert_called_with()
1387+
self.protocol.get_buffer.assert_called_with(-1)
13521388
self.protocol.buffer_updated.assert_called_with(10)
13531389

13541390
def test_read_ready_eof(self):

0 commit comments

Comments
 (0)