Skip to content

Commit a959f27

Browse files
committed
Add support for BufferedProtocol
1 parent 77ee4f9 commit a959f27

File tree

9 files changed

+339
-46
lines changed

9 files changed

+339
-46
lines changed

tests/test_tcp.py

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -629,6 +629,154 @@ async def runner():
629629

630630
class Test_UV_TCP(_TestTCP, tb.UVTestCase):
631631

632+
def test_create_server_buffered_1(self):
633+
SIZE = 123123
634+
635+
class Proto(asyncio.BaseProtocol):
636+
def connection_made(self, tr):
637+
self.tr = tr
638+
self.recvd = b''
639+
self.data = bytearray(50)
640+
self.buf = memoryview(self.data)
641+
642+
def get_buffer(self):
643+
return self.buf
644+
645+
def buffer_updated(self, nbytes):
646+
self.recvd += self.buf[:nbytes]
647+
if self.recvd == b'a' * SIZE:
648+
self.tr.write(b'hello')
649+
650+
def eof_received(self):
651+
pass
652+
653+
async def test():
654+
port = tb.find_free_port()
655+
srv = await self.loop.create_server(Proto, '127.0.0.1', port)
656+
657+
s = socket.socket(socket.AF_INET)
658+
with s:
659+
s.setblocking(False)
660+
await self.loop.sock_connect(s, ('127.0.0.1', port))
661+
await self.loop.sock_sendall(s, b'a' * SIZE)
662+
d = await self.loop.sock_recv(s, 100)
663+
self.assertEqual(d, b'hello')
664+
665+
srv.close()
666+
await srv.wait_closed()
667+
668+
self.loop.run_until_complete(test())
669+
670+
def test_create_server_buffered_2(self):
671+
class ProtoExc(asyncio.BaseProtocol):
672+
def __init__(self):
673+
self._lost_exc = None
674+
675+
def get_buffer(self):
676+
1 / 0
677+
678+
def buffer_updated(self, nbytes):
679+
pass
680+
681+
def connection_lost(self, exc):
682+
self._lost_exc = exc
683+
684+
def eof_received(self):
685+
pass
686+
687+
class ProtoZeroBuf1(asyncio.BaseProtocol):
688+
def __init__(self):
689+
self._lost_exc = None
690+
691+
def get_buffer(self):
692+
return bytearray(0)
693+
694+
def buffer_updated(self, nbytes):
695+
pass
696+
697+
def connection_lost(self, exc):
698+
self._lost_exc = exc
699+
700+
def eof_received(self):
701+
pass
702+
703+
class ProtoZeroBuf2(asyncio.BaseProtocol):
704+
def __init__(self):
705+
self._lost_exc = None
706+
707+
def get_buffer(self):
708+
return memoryview(bytearray(0))
709+
710+
def buffer_updated(self, nbytes):
711+
pass
712+
713+
def connection_lost(self, exc):
714+
self._lost_exc = exc
715+
716+
def eof_received(self):
717+
pass
718+
719+
class ProtoUpdatedError(asyncio.BaseProtocol):
720+
def __init__(self):
721+
self._lost_exc = None
722+
723+
def get_buffer(self):
724+
return memoryview(bytearray(100))
725+
726+
def buffer_updated(self, nbytes):
727+
raise RuntimeError('oups')
728+
729+
def connection_lost(self, exc):
730+
self._lost_exc = exc
731+
732+
def eof_received(self):
733+
pass
734+
735+
async def test(proto_factory, exc_type, exc_re):
736+
port = tb.find_free_port()
737+
proto = proto_factory()
738+
srv = await self.loop.create_server(
739+
lambda: proto, '127.0.0.1', port)
740+
741+
try:
742+
s = socket.socket(socket.AF_INET)
743+
with s:
744+
s.setblocking(False)
745+
await self.loop.sock_connect(s, ('127.0.0.1', port))
746+
await self.loop.sock_sendall(s, b'a')
747+
d = await self.loop.sock_recv(s, 100)
748+
if not d:
749+
raise ConnectionResetError
750+
except ConnectionResetError:
751+
pass
752+
else:
753+
self.fail("server didn't abort the connection")
754+
return
755+
finally:
756+
srv.close()
757+
await srv.wait_closed()
758+
759+
if proto._lost_exc is None:
760+
self.fail("connection_lost() was not called")
761+
return
762+
763+
with self.assertRaisesRegex(exc_type, exc_re):
764+
raise proto._lost_exc
765+
766+
self.loop.set_exception_handler(lambda loop, ctx: None)
767+
768+
self.loop.run_until_complete(
769+
test(ProtoExc, RuntimeError, 'unhandled error .* get_buffer'))
770+
771+
self.loop.run_until_complete(
772+
test(ProtoZeroBuf1, RuntimeError, 'unhandled error .* get_buffer'))
773+
774+
self.loop.run_until_complete(
775+
test(ProtoZeroBuf2, RuntimeError, 'unhandled error .* get_buffer'))
776+
777+
self.loop.run_until_complete(
778+
test(ProtoUpdatedError, RuntimeError, r'^oups$'))
779+
632780
def test_transport_get_extra_info(self):
633781
# This tests is only for uvloop. asyncio should pass it
634782
# too in Python 3.6.

uvloop/handles/basetransport.pxd

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,9 @@ cdef class UVBaseTransport(UVSocketHandle):
3838

3939
cdef inline _set_server(self, Server server)
4040
cdef inline _set_waiter(self, object waiter)
41-
cdef inline _set_protocol(self, object protocol)
41+
42+
cdef _set_protocol(self, object protocol)
43+
cdef _clear_protocol(self)
4244

4345
cdef inline _init_protocol(self)
4446
cdef inline _add_extra_info(self, str name, object obj)

uvloop/handles/basetransport.pyx

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -201,8 +201,7 @@ cdef class UVBaseTransport(UVSocketHandle):
201201
if self._protocol_connected:
202202
self._protocol.connection_lost(exc)
203203
finally:
204-
self._protocol = None
205-
self._protocol_data_received = None
204+
self._clear_protocol()
206205

207206
self._close()
208207

@@ -223,14 +222,18 @@ cdef class UVBaseTransport(UVSocketHandle):
223222

224223
self._waiter = waiter
225224

226-
cdef inline _set_protocol(self, object protocol):
225+
cdef _set_protocol(self, object protocol):
227226
self._protocol = protocol
228227
# Store a reference to the bound method directly
229228
try:
230229
self._protocol_data_received = protocol.data_received
231230
except AttributeError:
232231
pass
233232

233+
cdef _clear_protocol(self):
234+
self._protocol = None
235+
self._protocol_data_received = None
236+
234237
cdef inline _init_protocol(self):
235238
self._loop._track_transport(self)
236239
if self._protocol is None:
@@ -263,6 +266,9 @@ cdef class UVBaseTransport(UVSocketHandle):
263266

264267
def set_protocol(self, protocol):
265268
self._set_protocol(protocol)
269+
if self._is_reading():
270+
self._stop_reading()
271+
self._start_reading()
266272

267273
def _force_close(self, exc):
268274
# Used by SSLProto. Might be removed in the future.

uvloop/handles/pipe.pyx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ cdef void __pipe_connect_callback(uv.uv_connect_t* req, int status) with gil:
207207
try:
208208
transport._on_connect(exc)
209209
except BaseException as ex:
210-
wrapper.transport._error(ex, False)
210+
wrapper.transport._fatal_error(ex, False)
211211
finally:
212212
wrapper.on_done()
213213

uvloop/handles/stream.pxd

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,18 @@ cdef class UVStream(UVBaseTransport):
44
bint __shutting_down
55
bint __reading
66
bint __read_error_close
7+
8+
bint __buffered
9+
object _protocol_get_buffer
10+
object _protocol_buffer_updated
11+
712
bint _eof
813
list _buffer
914
size_t _buffer_size
1015

16+
Py_buffer _read_pybuf
17+
bint _read_pybuf_acquired
18+
1119
# All "inline" methods are final
1220

1321
cdef inline _init(self, Loop loop, object protocol, Server server,
@@ -29,7 +37,6 @@ cdef class UVStream(UVBaseTransport):
2937
cdef _close(self)
3038

3139
cdef inline _on_accept(self)
32-
cdef inline _on_read(self, bytes buf)
3340
cdef inline _on_eof(self)
3441
cdef inline _on_write(self)
3542
cdef inline _on_connect(self, object exc)

0 commit comments

Comments
 (0)