Skip to content

Commit 103da60

Browse files
dpkp88manpreet
authored andcommitted
BrokerConnection receive bytes pipe (dpkp#1032)
1 parent 4a0cc37 commit 103da60

File tree

4 files changed

+123
-93
lines changed

4 files changed

+123
-93
lines changed

kafka/client_async.py

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -598,28 +598,18 @@ def _poll(self, timeout):
598598
continue
599599

600600
self._idle_expiry_manager.update(conn.node_id)
601-
602-
# Accumulate as many responses as the connection has pending
603-
while conn.in_flight_requests:
604-
response = conn.recv() # Note: conn.recv runs callbacks / errbacks
605-
606-
# Incomplete responses are buffered internally
607-
# while conn.in_flight_requests retains the request
608-
if not response:
609-
break
610-
responses.append(response)
601+
responses.extend(conn.recv()) # Note: conn.recv runs callbacks / errbacks
611602

612603
# Check for additional pending SSL bytes
613604
if self.config['security_protocol'] in ('SSL', 'SASL_SSL'):
614605
# TODO: optimize
615606
for conn in self._conns.values():
616607
if conn not in processed and conn.connected() and conn._sock.pending():
617-
response = conn.recv()
618-
if response:
619-
responses.append(response)
608+
responses.extend(conn.recv())
620609

621610
if self._sensors:
622611
self._sensors.io_time.record((time.time() - end_select) * 1000000000)
612+
623613
self._maybe_close_oldest_connection()
624614
return responses
625615

kafka/conn.py

Lines changed: 86 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import copy
55
import errno
66
import logging
7-
import io
87
from random import shuffle, uniform
98
import socket
109
import time
@@ -18,6 +17,7 @@
1817
from kafka.protocol.api import RequestHeader
1918
from kafka.protocol.admin import SaslHandShakeRequest
2019
from kafka.protocol.commit import GroupCoordinatorResponse, OffsetFetchRequest
20+
from kafka.protocol.frame import KafkaBytes
2121
from kafka.protocol.metadata import MetadataRequest
2222
from kafka.protocol.fetch import FetchRequest
2323
from kafka.protocol.types import Int32
@@ -231,9 +231,9 @@ def __init__(self, host, port, afi, **configs):
231231
if self.config['ssl_context'] is not None:
232232
self._ssl_context = self.config['ssl_context']
233233
self._sasl_auth_future = None
234-
self._rbuffer = io.BytesIO()
234+
self._header = KafkaBytes(4)
235+
self._rbuffer = None
235236
self._receiving = False
236-
self._next_payload_bytes = 0
237237
self.last_attempt = 0
238238
self._processing = False
239239
self._correlation_id = 0
@@ -637,17 +637,19 @@ def close(self, error=None):
637637
self.state = ConnectionStates.DISCONNECTED
638638
self.last_attempt = time.time()
639639
self._sasl_auth_future = None
640-
self._receiving = False
641-
self._next_payload_bytes = 0
642-
self._rbuffer.seek(0)
643-
self._rbuffer.truncate()
640+
self._reset_buffer()
644641
if error is None:
645642
error = Errors.Cancelled(str(self))
646643
while self.in_flight_requests:
647644
ifr = self.in_flight_requests.popleft()
648645
ifr.future.failure(error)
649646
self.config['state_change_callback'](self)
650647

648+
def _reset_buffer(self):
649+
self._receiving = False
650+
self._header.seek(0)
651+
self._rbuffer = None
652+
651653
def send(self, request):
652654
"""send request, return Future()
653655
@@ -721,116 +723,123 @@ def recv(self):
721723
# fail all the pending request futures
722724
if self.in_flight_requests:
723725
self.close(Errors.ConnectionError('Socket not connected during recv with in-flight-requests'))
724-
return None
726+
return ()
725727

726728
elif not self.in_flight_requests:
727729
log.warning('%s: No in-flight-requests to recv', self)
728-
return None
730+
return ()
729731

730732
elif self._requests_timed_out():
731733
log.warning('%s timed out after %s ms. Closing connection.',
732734
self, self.config['request_timeout_ms'])
733735
self.close(error=Errors.RequestTimedOutError(
734736
'Request timed out after %s ms' %
735737
self.config['request_timeout_ms']))
736-
return None
738+
return ()
737739

740+
# TODO: manpreet: Decide to return response/None
741+
# return response
738742
return self._recv()
739743

740744
def _recv(self):
741-
# Not receiving is the state of reading the payload header
742-
if not self._receiving:
745+
responses = []
746+
SOCK_CHUNK_BYTES = 4096
747+
while True:
743748
try:
744-
bytes_to_read = 4 - self._rbuffer.tell()
745-
data = self._sock.recv(bytes_to_read)
749+
data = self._sock.recv(SOCK_CHUNK_BYTES)
746750
# We expect socket.recv to raise an exception if there is not
747751
# enough data to read the full bytes_to_read
748752
# but if the socket is disconnected, we will get empty data
749753
# without an exception raised
750754
if not data:
751755
log.error('%s: socket disconnected', self)
752756
self.close(error=Errors.ConnectionError('socket disconnected'))
753-
return None
754-
self._rbuffer.write(data)
757+
break
758+
else:
759+
responses.extend(self.receive_bytes(data))
760+
if len(data) < SOCK_CHUNK_BYTES:
761+
break
755762
except SSLWantReadError:
756-
return None
763+
break
757764
except ConnectionError as e:
758765
if six.PY2 and e.errno == errno.EWOULDBLOCK:
759-
return None
760-
log.exception('%s: Error receiving 4-byte payload header -'
766+
break
767+
log.exception('%s: Error receiving network data'
761768
' closing socket', self)
762769
self.close(error=Errors.ConnectionError(e))
763-
return None
764-
except BlockingIOError:
765-
if six.PY3:
766-
return None
767-
raise
768-
769-
if self._rbuffer.tell() == 4:
770-
self._rbuffer.seek(0)
771-
self._next_payload_bytes = Int32.decode(self._rbuffer)
772-
# reset buffer and switch state to receiving payload bytes
773-
self._rbuffer.seek(0)
774-
self._rbuffer.truncate()
775-
self._receiving = True
776-
elif self._rbuffer.tell() > 4:
777-
raise Errors.KafkaError('this should not happen - are you threading?')
778-
779-
if self._receiving:
780-
staged_bytes = self._rbuffer.tell()
781-
try:
782-
bytes_to_read = self._next_payload_bytes - staged_bytes
783-
data = self._sock.recv(bytes_to_read)
784-
# We expect socket.recv to raise an exception if there is not
785-
# enough data to read the full bytes_to_read
786-
# but if the socket is disconnected, we will get empty data
787-
# without an exception raised
788-
if bytes_to_read and not data:
789-
log.error('%s: socket disconnected', self)
790-
self.close(error=Errors.ConnectionError('socket disconnected'))
791-
return None
792-
self._rbuffer.write(data)
793-
except SSLWantReadError:
794-
return None
795-
except ConnectionError as e:
796-
# Extremely small chance that we have exactly 4 bytes for a
797-
# header, but nothing to read in the body yet
798-
if six.PY2 and e.errno == errno.EWOULDBLOCK:
799-
return None
800-
log.exception('%s: Error in recv', self)
801-
self.close(error=Errors.ConnectionError(e))
802-
return None
770+
break
803771
except BlockingIOError:
804772
if six.PY3:
805-
return None
773+
break
806774
raise
775+
return responses
807776

808-
staged_bytes = self._rbuffer.tell()
809-
if staged_bytes > self._next_payload_bytes:
810-
self.close(error=Errors.KafkaError('Receive buffer has more bytes than expected?'))
811-
812-
if staged_bytes != self._next_payload_bytes:
813-
return None
777+
def receive_bytes(self, data):
778+
i = 0
779+
n = len(data)
780+
responses = []
781+
if self._sensors:
782+
self._sensors.bytes_received.record(n)
783+
while i < n:
784+
785+
# Not receiving is the state of reading the payload header
786+
if not self._receiving:
787+
bytes_to_read = min(4 - self._header.tell(), n - i)
788+
self._header.write(data[i:i+bytes_to_read])
789+
i += bytes_to_read
790+
791+
if self._header.tell() == 4:
792+
self._header.seek(0)
793+
nbytes = Int32.decode(self._header)
794+
# reset buffer and switch state to receiving payload bytes
795+
self._rbuffer = KafkaBytes(nbytes)
796+
self._receiving = True
797+
elif self._header.tell() > 4:
798+
raise Errors.KafkaError('this should not happen - are you threading?')
799+
800+
801+
if self._receiving:
802+
total_bytes = len(self._rbuffer)
803+
staged_bytes = self._rbuffer.tell()
804+
bytes_to_read = min(total_bytes - staged_bytes, n - i)
805+
self._rbuffer.write(data[i:i+bytes_to_read])
806+
i += bytes_to_read
807+
808+
staged_bytes = self._rbuffer.tell()
809+
if staged_bytes > total_bytes:
810+
self.close(error=Errors.KafkaError('Receive buffer has more bytes than expected?'))
811+
812+
if staged_bytes != total_bytes:
813+
break
814814

815-
self._receiving = False
816-
self._next_payload_bytes = 0
817-
if self._sensors:
818-
self._sensors.bytes_received.record(4 + self._rbuffer.tell())
819-
self._rbuffer.seek(0)
820-
response = self._process_response(self._rbuffer)
821-
self._rbuffer.seek(0)
822-
self._rbuffer.truncate()
823-
return response
815+
self._receiving = False
816+
self._rbuffer.seek(0)
817+
resp = self._process_response(self._rbuffer)
818+
if resp is not None:
819+
responses.append(resp)
820+
self._reset_buffer()
821+
return responses
824822

825823
def _process_response(self, read_buffer):
826824
assert not self._processing, 'Recursion not supported'
827825
self._processing = True
828-
ifr = self.in_flight_requests.popleft()
826+
recv_correlation_id = Int32.decode(read_buffer)
827+
828+
if not self.in_flight_requests:
829+
error = Errors.CorrelationIdError(
830+
'%s: No in-flight-request found for server response'
831+
' with correlation ID %d'
832+
% (self, recv_correlation_id))
833+
self.close(error)
834+
self._processing = False
835+
return None
836+
else:
837+
ifr = self.in_flight_requests.popleft()
838+
829839
if self._sensors:
830840
self._sensors.request_time.record((time.time() - ifr.timestamp) * 1000)
831841

832842
# verify send/recv correlation ids match
833-
recv_correlation_id = Int32.decode(read_buffer)
834843

835844
# 0.8.2 quirk
836845
if (self.config['api_version'] == (0, 8, 2) and

kafka/protocol/frame.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
class KafkaBytes(bytearray):
2+
def __init__(self, size):
3+
super(KafkaBytes, self).__init__(size)
4+
self._idx = 0
5+
6+
def read(self, nbytes=None):
7+
if nbytes is None:
8+
nbytes = len(self) - self._idx
9+
start = self._idx
10+
self._idx += nbytes
11+
if self._idx > len(self):
12+
self._idx = len(self)
13+
return bytes(self[start:self._idx])
14+
15+
def write(self, data):
16+
start = self._idx
17+
self._idx += len(data)
18+
self[start:self._idx] = data
19+
20+
def seek(self, idx):
21+
self._idx = idx
22+
23+
def tell(self):
24+
return self._idx
25+
26+
def __str__(self):
27+
return 'KafkaBytes(%d)' % len(self)
28+
29+
def __repr__(self):
30+
return str(self)

kafka/protocol/message.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from ..codec import (has_gzip, has_snappy, has_lz4,
77
gzip_decode, snappy_decode,
88
lz4_decode, lz4_decode_old_kafka)
9+
from .frame import KafkaBytes
910
from .struct import Struct
1011
from .types import (
1112
Int8, Int32, Int64, Bytes, Schema, AbstractType
@@ -155,10 +156,10 @@ class MessageSet(AbstractType):
155156
@classmethod
156157
def encode(cls, items):
157158
# RecordAccumulator encodes messagesets internally
158-
if isinstance(items, io.BytesIO):
159+
if isinstance(items, (io.BytesIO, KafkaBytes)):
159160
size = Int32.decode(items)
160161
# rewind and return all the bytes
161-
items.seek(-4, 1)
162+
items.seek(items.tell() - 4)
162163
return items.read(size + 4)
163164

164165
encoded_values = []
@@ -198,7 +199,7 @@ def decode(cls, data, bytes_to_read=None):
198199

199200
@classmethod
200201
def repr(cls, messages):
201-
if isinstance(messages, io.BytesIO):
202+
if isinstance(messages, (KafkaBytes, io.BytesIO)):
202203
offset = messages.tell()
203204
decoded = cls.decode(messages)
204205
messages.seek(offset)

0 commit comments

Comments
 (0)