Commit 7ceaad5

BrokerConnection.receive_bytes(data) -> response events
1 parent 2ca7e77 commit 7ceaad5

File tree: 3 files changed, +78 -91 lines changed

kafka/client_async.py

Lines changed: 3 additions & 13 deletions
@@ -585,25 +585,14 @@ def _poll(self, timeout, sleep=True):
                 continue
 
             self._idle_expiry_manager.update(conn.node_id)
-
-            # Accumulate as many responses as the connection has pending
-            while conn.in_flight_requests:
-                response = conn.recv() # Note: conn.recv runs callbacks / errbacks
-
-                # Incomplete responses are buffered internally
-                # while conn.in_flight_requests retains the request
-                if not response:
-                    break
-                responses.append(response)
+            responses.extend(conn.recv()) # Note: conn.recv runs callbacks / errbacks
 
         # Check for additional pending SSL bytes
         if self.config['security_protocol'] in ('SSL', 'SASL_SSL'):
             # TODO: optimize
             for conn in self._conns.values():
                 if conn not in processed and conn.connected() and conn._sock.pending():
-                    response = conn.recv()
-                    if response:
-                        responses.append(response)
+                    responses.extend(conn.recv())
 
         for conn in six.itervalues(self._conns):
             if conn.requests_timed_out():
@@ -615,6 +604,7 @@ def _poll(self, timeout, sleep=True):
 
         if self._sensors:
             self._sensors.io_time.record((time.time() - end_select) * 1000000000)
+
         self._maybe_close_oldest_connection()
         return responses
 
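The change above means conn.recv() now returns a (possibly empty) list of decoded responses rather than a single response or None, so the poll loop simply extends its result list instead of looping over in_flight_requests and checking for partial reads. A minimal sketch of that caller-side contract (the drain_responses helper is hypothetical, not part of kafka-python):

def drain_responses(connections):
    # Sketch only: assumes each conn follows the new list-returning recv() contract above.
    responses = []
    for conn in connections:
        # recv() returns every complete response currently readable; partial
        # frames stay buffered inside the connection until the next poll.
        responses.extend(conn.recv())  # Note: runs callbacks / errbacks
    return responses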

kafka/conn.py

Lines changed: 71 additions & 75 deletions
@@ -4,7 +4,6 @@
 import copy
 import errno
 import logging
-import io
 from random import shuffle
 import socket
 import time
@@ -18,6 +17,7 @@
 from kafka.protocol.api import RequestHeader
 from kafka.protocol.admin import SaslHandShakeRequest
 from kafka.protocol.commit import GroupCoordinatorResponse
+from kafka.protocol.frame import KafkaBytes
 from kafka.protocol.metadata import MetadataRequest
 from kafka.protocol.types import Int32
 from kafka.version import __version__
@@ -204,9 +204,9 @@ def __init__(self, host, port, afi, **configs):
         if self.config['ssl_context'] is not None:
             self._ssl_context = self.config['ssl_context']
         self._sasl_auth_future = None
-        self._rbuffer = io.BytesIO()
+        self._header = KafkaBytes(4)
+        self._rbuffer = None
         self._receiving = False
-        self._next_payload_bytes = 0
         self.last_attempt = 0
         self._processing = False
         self._correlation_id = 0
@@ -518,17 +518,19 @@ def close(self, error=None):
         self.state = ConnectionStates.DISCONNECTED
         self.last_attempt = time.time()
         self._sasl_auth_future = None
-        self._receiving = False
-        self._next_payload_bytes = 0
-        self._rbuffer.seek(0)
-        self._rbuffer.truncate()
+        self._reset_buffer()
         if error is None:
             error = Errors.Cancelled(str(self))
         while self.in_flight_requests:
             ifr = self.in_flight_requests.popleft()
             ifr.future.failure(error)
         self.config['state_change_callback'](self)
 
+    def _reset_buffer(self):
+        self._receiving = False
+        self._header.seek(0)
+        self._rbuffer = None
+
     def send(self, request):
         """send request, return Future()
 
@@ -602,11 +604,11 @@ def recv(self):
         # fail all the pending request futures
         if self.in_flight_requests:
             self.close(Errors.ConnectionError('Socket not connected during recv with in-flight-requests'))
-            return None
+            return ()
 
         elif not self.in_flight_requests:
             log.warning('%s: No in-flight-requests to recv', self)
-            return None
+            return ()
 
         response = self._recv()
         if not response and self.requests_timed_out():
@@ -615,93 +617,87 @@ def recv(self):
             self.close(error=Errors.RequestTimedOutError(
                 'Request timed out after %s ms' %
                 self.config['request_timeout_ms']))
-            return None
+            return ()
         return response
 
     def _recv(self):
-        # Not receiving is the state of reading the payload header
-        if not self._receiving:
+        responses = []
+        SOCK_CHUNK_BYTES = 4096
+        while True:
             try:
-                bytes_to_read = 4 - self._rbuffer.tell()
-                data = self._sock.recv(bytes_to_read)
+                data = self._sock.recv(SOCK_CHUNK_BYTES)
                 # We expect socket.recv to raise an exception if there is not
                 # enough data to read the full bytes_to_read
                 # but if the socket is disconnected, we will get empty data
                 # without an exception raised
                 if not data:
                     log.error('%s: socket disconnected', self)
                     self.close(error=Errors.ConnectionError('socket disconnected'))
-                    return None
-                self._rbuffer.write(data)
+                    break
+                else:
+                    responses.extend(self.receive_bytes(data))
+                    if len(data) < SOCK_CHUNK_BYTES:
+                        break
             except SSLWantReadError:
-                return None
+                break
            except ConnectionError as e:
                 if six.PY2 and e.errno == errno.EWOULDBLOCK:
-                    return None
-                log.exception('%s: Error receiving 4-byte payload header -'
+                    break
+                log.exception('%s: Error receiving network data'
                               ' closing socket', self)
                 self.close(error=Errors.ConnectionError(e))
-                return None
-            except BlockingIOError:
-                if six.PY3:
-                    return None
-                raise
-
-            if self._rbuffer.tell() == 4:
-                self._rbuffer.seek(0)
-                self._next_payload_bytes = Int32.decode(self._rbuffer)
-                # reset buffer and switch state to receiving payload bytes
-                self._rbuffer.seek(0)
-                self._rbuffer.truncate()
-                self._receiving = True
-            elif self._rbuffer.tell() > 4:
-                raise Errors.KafkaError('this should not happen - are you threading?')
-
-        if self._receiving:
-            staged_bytes = self._rbuffer.tell()
-            try:
-                bytes_to_read = self._next_payload_bytes - staged_bytes
-                data = self._sock.recv(bytes_to_read)
-                # We expect socket.recv to raise an exception if there is not
-                # enough data to read the full bytes_to_read
-                # but if the socket is disconnected, we will get empty data
-                # without an exception raised
-                if bytes_to_read and not data:
-                    log.error('%s: socket disconnected', self)
-                    self.close(error=Errors.ConnectionError('socket disconnected'))
-                    return None
-                self._rbuffer.write(data)
-            except SSLWantReadError:
-                return None
-            except ConnectionError as e:
-                # Extremely small chance that we have exactly 4 bytes for a
-                # header, but nothing to read in the body yet
-                if six.PY2 and e.errno == errno.EWOULDBLOCK:
-                    return None
-                log.exception('%s: Error in recv', self)
-                self.close(error=Errors.ConnectionError(e))
-                return None
+                break
             except BlockingIOError:
                 if six.PY3:
-                    return None
+                    break
                 raise
+        return responses
 
-        staged_bytes = self._rbuffer.tell()
-        if staged_bytes > self._next_payload_bytes:
-            self.close(error=Errors.KafkaError('Receive buffer has more bytes than expected?'))
-
-        if staged_bytes != self._next_payload_bytes:
-            return None
+    def receive_bytes(self, data):
+        i = 0
+        n = len(data)
+        responses = []
+        if self._sensors:
+            self._sensors.bytes_received.record(n)
+        while i < n:
+
+            # Not receiving is the state of reading the payload header
+            if not self._receiving:
+                bytes_to_read = min(4 - self._header.tell(), n - i)
+                self._header.write(data[i:i+bytes_to_read])
+                i += bytes_to_read
+
+                if self._header.tell() == 4:
+                    self._header.seek(0)
+                    nbytes = Int32.decode(self._header)
+                    # reset buffer and switch state to receiving payload bytes
+                    self._rbuffer = KafkaBytes(nbytes)
+                    self._receiving = True
+                elif self._header.tell() > 4:
+                    raise Errors.KafkaError('this should not happen - are you threading?')
+
+
+            if self._receiving:
+                total_bytes = len(self._rbuffer)
+                staged_bytes = self._rbuffer.tell()
+                bytes_to_read = min(total_bytes - staged_bytes, n - i)
+                self._rbuffer.write(data[i:i+bytes_to_read])
+                i += bytes_to_read
+
+                staged_bytes = self._rbuffer.tell()
+                if staged_bytes > total_bytes:
+                    self.close(error=Errors.KafkaError('Receive buffer has more bytes than expected?'))
+
+                if staged_bytes != total_bytes:
+                    break
 
-        self._receiving = False
-        self._next_payload_bytes = 0
-        if self._sensors:
-            self._sensors.bytes_received.record(4 + self._rbuffer.tell())
-        self._rbuffer.seek(0)
-        response = self._process_response(self._rbuffer)
-        self._rbuffer.seek(0)
-        self._rbuffer.truncate()
-        return response
+            self._receiving = False
+            self._rbuffer.seek(0)
+            resp = self._process_response(self._rbuffer)
+            if resp is not None:
+                responses.append(resp)
+            self._reset_buffer()
+        return responses
 
     def _process_response(self, read_buffer):
         assert not self._processing, 'Recursion not supported'
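
receive_bytes() above parses the Kafka wire framing incrementally: a 4-byte big-endian size header is staged in self._header, then that many payload bytes are staged in self._rbuffer, and a single chunk of socket data may contain a partial frame or several complete frames. A self-contained sketch of the same length-prefix framing idea, independent of KafkaBytes and using hypothetical names:

import struct

class FrameAccumulator(object):
    # Illustrative only: 4-byte length-prefixed framing like receive_bytes().
    def __init__(self):
        self._buf = b''

    def feed(self, data):
        # Append raw socket bytes and return any complete payloads.
        self._buf += data
        frames = []
        while True:
            if len(self._buf) < 4:
                break  # still waiting for a full size header
            size = struct.unpack('>i', self._buf[:4])[0]
            if len(self._buf) < 4 + size:
                break  # payload incomplete; keep buffering
            frames.append(self._buf[4:4 + size])
            self._buf = self._buf[4 + size:]
        return frames

For example, feeding b'\x00\x00\x00\x03abc\x00\x00' returns [b'abc'] and keeps the trailing two bytes as the start of the next header, mirroring how the _header/_rbuffer state above survives across calls.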

kafka/protocol/message.py

Lines changed: 4 additions & 3 deletions
@@ -6,6 +6,7 @@
 from ..codec import (has_gzip, has_snappy, has_lz4,
                      gzip_decode, snappy_decode,
                      lz4_decode, lz4_decode_old_kafka)
+from .frame import KafkaBytes
 from .struct import Struct
 from .types import (
     Int8, Int32, Int64, Bytes, Schema, AbstractType
@@ -155,10 +156,10 @@ class MessageSet(AbstractType):
     @classmethod
     def encode(cls, items):
         # RecordAccumulator encodes messagesets internally
-        if isinstance(items, io.BytesIO):
+        if isinstance(items, (io.BytesIO, KafkaBytes)):
             size = Int32.decode(items)
             # rewind and return all the bytes
-            items.seek(-4, 1)
+            items.seek(items.tell() - 4)
             return items.read(size + 4)
 
         encoded_values = []
@@ -198,7 +199,7 @@ def decode(cls, data, bytes_to_read=None):
 
     @classmethod
     def repr(cls, messages):
-        if isinstance(messages, io.BytesIO):
+        if isinstance(messages, (KafkaBytes, io.BytesIO)):
             offset = messages.tell()
             decoded = cls.decode(messages)
             messages.seek(offset)
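
The seek change above swaps a relative seek (items.seek(-4, 1), i.e. four bytes back from the current position) for an absolute one, presumably because the new KafkaBytes buffer only supports a single-argument, absolute seek(); io.BytesIO accepts both forms. A quick sketch of the equivalence on a plain io.BytesIO (the 9-byte buffer here is made up for illustration):

import io
import struct

buf = io.BytesIO(b'\x00\x00\x00\x05hello')      # 4-byte size header + 5 payload bytes
size = struct.unpack('>i', buf.read(4))[0]      # reading the size advances the position to 4
buf.seek(buf.tell() - 4)                        # absolute equivalent of the old buf.seek(-4, 1)
assert buf.read(size + 4) == b'\x00\x00\x00\x05hello'  # "rewind and return all the bytes"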

0 commit comments