@@ -587,6 +587,8 @@ def _try_authenticate_plain(self, future):
587
587
size = Int32 .encode (len (msg ))
588
588
try :
589
589
with self ._lock :
590
+ if not self ._can_send_recv ():
591
+ return future .failure (Errors .NodeNotReadyError (str (self )))
590
592
self ._send_bytes_blocking (size + msg )
591
593
592
594
# The server will send a zero sized message (that is Int32(0)) on success.
@@ -616,6 +618,8 @@ def _try_authenticate_gssapi(self, future):
616
618
log .debug ('%s: GSSAPI name: %s' , self , gssapi_name )
617
619
618
620
self ._lock .acquire ()
621
+ if not self ._can_send_recv ():
622
+ return future .failure (Errors .NodeNotReadyError (str (self )))
619
623
# Establish security context and negotiate protection level
620
624
# For reference RFC 2222, section 7.2.1
621
625
try :
@@ -677,6 +681,8 @@ def _try_authenticate_oauth(self, future):
677
681
msg = bytes (self ._build_oauth_client_request ().encode ("utf-8" ))
678
682
size = Int32 .encode (len (msg ))
679
683
self ._lock .acquire ()
684
+ if not self ._can_send_recv ():
685
+ return future .failure (Errors .NodeNotReadyError (str (self )))
680
686
try :
681
687
# Send SASL OAuthBearer request with OAuth token
682
688
self ._send_bytes_blocking (size + msg )
@@ -816,6 +822,11 @@ def close(self, error=None):
816
822
for (_correlation_id , (future , _timestamp )) in ifrs :
817
823
future .failure (error )
818
824
825
+ def _can_send_recv (self ):
826
+ """Return True iff socket is ready for requests / responses"""
827
+ return self .state in (ConnectionStates .AUTHENTICATING ,
828
+ ConnectionStates .CONNECTED )
829
+
819
830
def send (self , request , blocking = True ):
820
831
"""Queue request for async network send, return Future()"""
821
832
future = Future ()
@@ -830,8 +841,7 @@ def send(self, request, blocking=True):
830
841
def _send (self , request , blocking = True ):
831
842
future = Future ()
832
843
with self ._lock :
833
- if self .state not in (ConnectionStates .AUTHENTICATING ,
834
- ConnectionStates .CONNECTED ):
844
+ if not self ._can_send_recv ():
835
845
return future .failure (Errors .NodeNotReadyError (str (self )))
836
846
837
847
correlation_id = self ._protocol .send_request (request )
@@ -855,8 +865,7 @@ def send_pending_requests(self):
855
865
"""Can block on network if request is larger than send_buffer_bytes"""
856
866
try :
857
867
with self ._lock :
858
- if self .state not in (ConnectionStates .AUTHENTICATING ,
859
- ConnectionStates .CONNECTED ):
868
+ if not self ._can_send_recv ():
860
869
return Errors .NodeNotReadyError (str (self ))
861
870
# In the future we might manage an internal write buffer
862
871
# and send bytes asynchronously. For now, just block
@@ -882,19 +891,6 @@ def recv(self):
882
891
883
892
Return list of (response, future) tuples
884
893
"""
885
- if self .state not in (ConnectionStates .AUTHENTICATING ,
886
- ConnectionStates .CONNECTED ):
887
- log .warning ('%s cannot recv: socket not connected' , self )
888
- # If requests are pending, we should close the socket and
889
- # fail all the pending request futures
890
- if self .in_flight_requests :
891
- self .close (Errors .KafkaConnectionError ('Socket not connected during recv with in-flight-requests' ))
892
- return ()
893
-
894
- elif not self .in_flight_requests :
895
- log .warning ('%s: No in-flight-requests to recv' , self )
896
- return ()
897
-
898
894
responses = self ._recv ()
899
895
if not responses and self .requests_timed_out ():
900
896
log .warning ('%s timed out after %s ms. Closing connection.' ,
@@ -925,6 +921,11 @@ def _recv(self):
925
921
"""Take all available bytes from socket, return list of any responses from parser"""
926
922
recvd = []
927
923
self ._lock .acquire ()
924
+ if not self ._can_send_recv ():
925
+ log .warning ('%s cannot recv: socket not connected' , self )
926
+ self ._lock .release ()
927
+ return ()
928
+
928
929
while len (recvd ) < self .config ['sock_chunk_buffer_count' ]:
929
930
try :
930
931
data = self ._sock .recv (self .config ['sock_chunk_bytes' ])
0 commit comments