Skip to content

Commit c1c71f7

Browse files
committed
Check _can_send_recv with lock to verify state
1 parent 47510f5 commit c1c71f7

File tree

1 file changed

+18
-17
lines changed

1 file changed

+18
-17
lines changed

kafka/conn.py

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -587,6 +587,8 @@ def _try_authenticate_plain(self, future):
587587
size = Int32.encode(len(msg))
588588
try:
589589
with self._lock:
590+
if not self._can_send_recv():
591+
return future.failure(Errors.NodeNotReadyError(str(self)))
590592
self._send_bytes_blocking(size + msg)
591593

592594
# The server will send a zero sized message (that is Int32(0)) on success.
@@ -616,6 +618,8 @@ def _try_authenticate_gssapi(self, future):
616618
log.debug('%s: GSSAPI name: %s', self, gssapi_name)
617619

618620
self._lock.acquire()
621+
if not self._can_send_recv():
622+
return future.failure(Errors.NodeNotReadyError(str(self)))
619623
# Establish security context and negotiate protection level
620624
# For reference RFC 2222, section 7.2.1
621625
try:
@@ -677,6 +681,8 @@ def _try_authenticate_oauth(self, future):
677681
msg = bytes(self._build_oauth_client_request().encode("utf-8"))
678682
size = Int32.encode(len(msg))
679683
self._lock.acquire()
684+
if not self._can_send_recv():
685+
return future.failure(Errors.NodeNotReadyError(str(self)))
680686
try:
681687
# Send SASL OAuthBearer request with OAuth token
682688
self._send_bytes_blocking(size + msg)
@@ -816,6 +822,11 @@ def close(self, error=None):
816822
for (_correlation_id, (future, _timestamp)) in ifrs:
817823
future.failure(error)
818824

825+
def _can_send_recv(self):
826+
"""Return True iff socket is ready for requests / responses"""
827+
return self.state in (ConnectionStates.AUTHENTICATING,
828+
ConnectionStates.CONNECTED)
829+
819830
def send(self, request, blocking=True):
820831
"""Queue request for async network send, return Future()"""
821832
future = Future()
@@ -830,8 +841,7 @@ def send(self, request, blocking=True):
830841
def _send(self, request, blocking=True):
831842
future = Future()
832843
with self._lock:
833-
if self.state not in (ConnectionStates.AUTHENTICATING,
834-
ConnectionStates.CONNECTED):
844+
if not self._can_send_recv():
835845
return future.failure(Errors.NodeNotReadyError(str(self)))
836846

837847
correlation_id = self._protocol.send_request(request)
@@ -855,8 +865,7 @@ def send_pending_requests(self):
855865
"""Can block on network if request is larger than send_buffer_bytes"""
856866
try:
857867
with self._lock:
858-
if self.state not in (ConnectionStates.AUTHENTICATING,
859-
ConnectionStates.CONNECTED):
868+
if not self._can_send_recv():
860869
return Errors.NodeNotReadyError(str(self))
861870
# In the future we might manage an internal write buffer
862871
# and send bytes asynchronously. For now, just block
@@ -882,19 +891,6 @@ def recv(self):
882891
883892
Return list of (response, future) tuples
884893
"""
885-
if self.state not in (ConnectionStates.AUTHENTICATING,
886-
ConnectionStates.CONNECTED):
887-
log.warning('%s cannot recv: socket not connected', self)
888-
# If requests are pending, we should close the socket and
889-
# fail all the pending request futures
890-
if self.in_flight_requests:
891-
self.close(Errors.KafkaConnectionError('Socket not connected during recv with in-flight-requests'))
892-
return ()
893-
894-
elif not self.in_flight_requests:
895-
log.warning('%s: No in-flight-requests to recv', self)
896-
return ()
897-
898894
responses = self._recv()
899895
if not responses and self.requests_timed_out():
900896
log.warning('%s timed out after %s ms. Closing connection.',
@@ -925,6 +921,11 @@ def _recv(self):
925921
"""Take all available bytes from socket, return list of any responses from parser"""
926922
recvd = []
927923
self._lock.acquire()
924+
if not self._can_send_recv():
925+
log.warning('%s cannot recv: socket not connected', self)
926+
self._lock.release()
927+
return ()
928+
928929
while len(recvd) < self.config['sock_chunk_buffer_count']:
929930
try:
930931
data = self._sock.recv(self.config['sock_chunk_bytes'])

0 commit comments

Comments
 (0)