Skip to content

Commit 0d4b837

Browse files
committed
Split API_VERSIONS conn state to read/write
1 parent 40b6c9c commit 0d4b837

File tree

2 files changed

+21
-8
lines changed

2 files changed

+21
-8
lines changed

kafka/client_async.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -315,7 +315,13 @@ def _conn_state_change(self, node_id, sock, conn):
315315
if self.cluster.is_bootstrap(node_id):
316316
self._last_bootstrap = time.time()
317317

318-
elif conn.state in (ConnectionStates.API_VERSIONS, ConnectionStates.AUTHENTICATING):
318+
elif conn.state is ConnectionStates.API_VERSIONS_SEND:
319+
try:
320+
self._selector.register(sock, selectors.EVENT_WRITE, conn)
321+
except KeyError:
322+
self._selector.modify(sock, selectors.EVENT_WRITE, conn)
323+
324+
elif conn.state in (ConnectionStates.API_VERSIONS_RECV, ConnectionStates.AUTHENTICATING):
319325
try:
320326
self._selector.register(sock, selectors.EVENT_READ, conn)
321327
except KeyError:

kafka/conn.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,8 @@ class ConnectionStates(object):
100100
HANDSHAKE = '<handshake>'
101101
CONNECTED = '<connected>'
102102
AUTHENTICATING = '<authenticating>'
103-
API_VERSIONS = '<checking_api_versions>'
103+
API_VERSIONS_SEND = '<checking_api_versions_send>'
104+
API_VERSIONS_RECV = '<checking_api_versions_recv>'
104105

105106

106107
class BrokerConnection(object):
@@ -419,7 +420,7 @@ def connect(self):
419420
self._wrap_ssl()
420421
else:
421422
log.debug('%s: checking broker Api Versions', self)
422-
self.state = ConnectionStates.API_VERSIONS
423+
self.state = ConnectionStates.API_VERSIONS_SEND
423424
self.config['state_change_callback'](self.node_id, self._sock, self)
424425

425426
# Connection failed
@@ -439,13 +440,13 @@ def connect(self):
439440
if self._try_handshake():
440441
log.debug('%s: completed SSL handshake.', self)
441442
log.debug('%s: checking broker Api Versions', self)
442-
self.state = ConnectionStates.API_VERSIONS
443+
self.state = ConnectionStates.API_VERSIONS_SEND
443444
self.config['state_change_callback'](self.node_id, self._sock, self)
444445

445-
if self.state is ConnectionStates.API_VERSIONS:
446+
if self.state in (ConnectionStates.API_VERSIONS_SEND, ConnectionStates.API_VERSIONS_RECV):
446447
if self._try_api_versions_check():
447448
# _try_api_versions_check has side-effects: possibly disconnected on socket errors
448-
if self.state is ConnectionStates.API_VERSIONS:
449+
if self.state in (ConnectionStates.API_VERSIONS_SEND, ConnectionStates.API_VERSIONS_RECV):
449450
if self.config['security_protocol'] in ('SASL_PLAINTEXT', 'SASL_SSL'):
450451
log.debug('%s: initiating SASL authentication', self)
451452
self.state = ConnectionStates.AUTHENTICATING
@@ -555,13 +556,17 @@ def _try_api_versions_check(self):
555556
response.add_callback(self._handle_api_versions_response, future)
556557
response.add_errback(self._handle_api_versions_failure, future)
557558
self._api_versions_future = future
559+
self.state = ConnectionStates.API_VERSIONS_RECV
560+
self.config['state_change_callback'](self.node_id, self._sock, self)
558561
elif self._check_version_idx < len(self.VERSION_CHECKS):
559562
version, request = self.VERSION_CHECKS[self._check_version_idx]
560563
future = Future()
561564
response = self._send(request, blocking=True, request_timeout_ms=(self.config['api_version_auto_timeout_ms'] * 0.8))
562565
response.add_callback(self._handle_check_version_response, future, version)
563566
response.add_errback(self._handle_check_version_failure, future)
564567
self._api_versions_future = future
568+
self.state = ConnectionStates.API_VERSIONS_RECV
569+
self.config['state_change_callback'](self.node_id, self._sock, self)
565570
else:
566571
raise 'Unable to determine broker version.'
567572

@@ -991,14 +996,16 @@ def connecting(self):
991996
return self.state in (ConnectionStates.CONNECTING,
992997
ConnectionStates.HANDSHAKE,
993998
ConnectionStates.AUTHENTICATING,
994-
ConnectionStates.API_VERSIONS)
999+
ConnectionStates.API_VERSIONS_SEND,
1000+
ConnectionStates.API_VERSIONS_RECV)
9951001

9961002
def initializing(self):
9971003
"""Returns True if socket is connected but full connection is not complete.
9981004
During this time the connection may send api requests to the broker to
9991005
check api versions and perform SASL authentication."""
10001006
return self.state in (ConnectionStates.AUTHENTICATING,
1001-
ConnectionStates.API_VERSIONS)
1007+
ConnectionStates.API_VERSIONS_SEND,
1008+
ConnectionStates.API_VERSIONS_RECV)
10021009

10031010
def disconnected(self):
10041011
"""Return True iff socket is closed"""

0 commit comments

Comments
 (0)