Fix KafkaConsumer.poll() with zero timeout #2613

Merged 3 commits on May 4, 2025
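
These diffs replace the timeout_ms_fn closure helper with a Timer object imported from kafka.util. For orientation, the following is a minimal sketch of the Timer interface as the diffs below use it; this is a reconstruction from usage alone, not the actual kafka/util.py implementation, which may differ in detail.

import time

from kafka.errors import KafkaTimeoutError


class Timer(object):
    """Minimal sketch: tracks a deadline given in milliseconds.

    timeout_ms=None means no deadline: the timer never expires and
    its timeout_ms property stays None.
    """

    def __init__(self, timeout_ms, error_message=None):
        self._error_message = error_message
        self._no_deadline = timeout_ms is None
        if self._no_deadline:
            self._expire_at = float('inf')
        else:
            self._expire_at = time.time() + timeout_ms / 1000.0

    @property
    def expired(self):
        # Never True when constructed with timeout_ms=None
        return time.time() >= self._expire_at

    @property
    def timeout_ms(self):
        # Remaining milliseconds, floored at 0; None if no deadline was set
        if self._no_deadline:
            return None
        return max(0, int((self._expire_at - time.time()) * 1000))

    def maybe_raise(self):
        # Surface an elapsed deadline as an exception, using the stored message
        if self.expired:
            raise KafkaTimeoutError(self._error_message or 'Timer expired')
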
18 changes: 7 additions & 11 deletions kafka/client_async.py
@@ -27,7 +27,7 @@
from kafka.metrics.stats.rate import TimeUnit
from kafka.protocol.broker_api_versions import BROKER_API_VERSIONS
from kafka.protocol.metadata import MetadataRequest
from kafka.util import Dict, WeakMethod, ensure_valid_topic_name, timeout_ms_fn
from kafka.util import Dict, Timer, WeakMethod, ensure_valid_topic_name
# Although this looks unused, it actually monkey-patches socket.socketpair()
# and should be left in as long as we're using socket.socketpair() in this file
from kafka.vendor import socketpair # noqa: F401
@@ -645,12 +645,8 @@ def poll(self, timeout_ms=None, future=None):
"""
if not isinstance(timeout_ms, (int, float, type(None))):
raise TypeError('Invalid type for timeout: %s' % type(timeout_ms))
timer = Timer(timeout_ms)

begin = time.time()
if timeout_ms is not None:
timeout_at = begin + (timeout_ms / 1000)
else:
timeout_at = begin + (self.config['request_timeout_ms'] / 1000)
# Loop for futures, break after first loop if None
responses = []
while True:
@@ -675,7 +671,7 @@
if future is not None and future.is_done:
timeout = 0
else:
user_timeout_ms = 1000 * max(0, timeout_at - time.time())
user_timeout_ms = timer.timeout_ms if timeout_ms is not None else self.config['request_timeout_ms']
idle_connection_timeout_ms = self._idle_expiry_manager.next_check_ms()
request_timeout_ms = self._next_ifr_request_timeout_ms()
log.debug("Timeouts: user %f, metadata %f, idle connection %f, request %f", user_timeout_ms, metadata_timeout_ms, idle_connection_timeout_ms, request_timeout_ms)
@@ -698,7 +694,7 @@
break
elif future.is_done:
break
elif timeout_ms is not None and time.time() >= timeout_at:
elif timeout_ms is not None and timer.expired:
break

return responses
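
This hunk is the core of the fix for poll(timeout_ms=0): the select timeout is now derived from the timer's remaining time, so a zero timeout produces one non-blocking round instead of silently falling back to request_timeout_ms. A small illustration of the selection logic, assuming the Timer sketch above; effective_user_timeout_ms is a hypothetical standalone function, and request_timeout_ms stands in for self.config['request_timeout_ms']:

def effective_user_timeout_ms(timeout_ms, request_timeout_ms=30000):
    # Mirrors the new user_timeout_ms computation in poll() above
    timer = Timer(timeout_ms)
    return timer.timeout_ms if timeout_ms is not None else request_timeout_ms

print(effective_user_timeout_ms(0))     # 0: poll returns without blocking
print(effective_user_timeout_ms(500))   # ~500: time remaining on the deadline
print(effective_user_timeout_ms(None))  # 30000: falls back to request_timeout_ms
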
@@ -1175,16 +1171,16 @@ def await_ready(self, node_id, timeout_ms=30000):
This method is useful for implementing blocking behaviour on top of the
non-blocking `NetworkClient`; use it with care.
"""
inner_timeout_ms = timeout_ms_fn(timeout_ms, None)
timer = Timer(timeout_ms)
self.poll(timeout_ms=0)
if self.is_ready(node_id):
return True

while not self.is_ready(node_id) and inner_timeout_ms() > 0:
while not self.is_ready(node_id) and not timer.expired:
if self.connection_failed(node_id):
raise Errors.KafkaConnectionError("Connection to %s failed." % (node_id,))
self.maybe_connect(node_id)
self.poll(timeout_ms=inner_timeout_ms())
self.poll(timeout_ms=timer.timeout_ms)
return self.is_ready(node_id)

def send_and_receive(self, node_id, request):
15 changes: 10 additions & 5 deletions kafka/consumer/fetcher.py
@@ -19,7 +19,7 @@
from kafka.record import MemoryRecords
from kafka.serializer import Deserializer
from kafka.structs import TopicPartition, OffsetAndMetadata, OffsetAndTimestamp
from kafka.util import timeout_ms_fn
from kafka.util import Timer

log = logging.getLogger(__name__)

@@ -230,15 +230,15 @@ def _fetch_offsets_by_times(self, timestamps, timeout_ms=None):
if not timestamps:
return {}

inner_timeout_ms = timeout_ms_fn(timeout_ms, 'Timeout fetching offsets')
timer = Timer(timeout_ms, "Failed to get offsets by timestamps in %s ms" % (timeout_ms,))
timestamps = copy.copy(timestamps)
fetched_offsets = dict()
while True:
if not timestamps:
return {}

future = self._send_list_offsets_requests(timestamps)
self._client.poll(future=future, timeout_ms=inner_timeout_ms())
self._client.poll(future=future, timeout_ms=timer.timeout_ms)

# Timeout w/o future completion
if not future.is_done:
@@ -256,12 +256,17 @@ def _fetch_offsets_by_times(self, timestamps, timeout_ms=None):

if future.exception.invalid_metadata or self._client.cluster.need_update:
refresh_future = self._client.cluster.request_update()
self._client.poll(future=refresh_future, timeout_ms=inner_timeout_ms())
self._client.poll(future=refresh_future, timeout_ms=timer.timeout_ms)

if not future.is_done:
break
else:
time.sleep(inner_timeout_ms(self.config['retry_backoff_ms']) / 1000)
if timer.timeout_ms is None or timer.timeout_ms > self.config['retry_backoff_ms']:
time.sleep(self.config['retry_backoff_ms'] / 1000)
else:
time.sleep(timer.timeout_ms / 1000)

timer.maybe_raise()

raise Errors.KafkaTimeoutError(
"Failed to get offsets by timestamps in %s ms" % (timeout_ms,))
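
The retry path above now bounds each backoff sleep by the time left on the timer, then lets timer.maybe_raise() surface the timeout with the message supplied at construction. The same pattern in isolation, as a hypothetical helper built on the Timer sketch above:

import time

def backoff_sleep(timer, retry_backoff_ms):
    # Sleep for the configured backoff, but never past the timer's deadline
    if timer.timeout_ms is None or timer.timeout_ms > retry_backoff_ms:
        time.sleep(retry_backoff_ms / 1000)
    else:
        time.sleep(timer.timeout_ms / 1000)
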
36 changes: 17 additions & 19 deletions kafka/consumer/group.py
@@ -18,7 +18,7 @@
from kafka.metrics import MetricConfig, Metrics
from kafka.protocol.list_offsets import OffsetResetStrategy
from kafka.structs import OffsetAndMetadata, TopicPartition
from kafka.util import timeout_ms_fn
from kafka.util import Timer
from kafka.version import __version__

log = logging.getLogger(__name__)
@@ -679,41 +679,40 @@ def poll(self, timeout_ms=0, max_records=None, update_offsets=True):
assert not self._closed, 'KafkaConsumer is closed'

# Poll for new data until the timeout expires
inner_timeout_ms = timeout_ms_fn(timeout_ms, None)
timer = Timer(timeout_ms)
while not self._closed:
records = self._poll_once(inner_timeout_ms(), max_records, update_offsets=update_offsets)
records = self._poll_once(timer, max_records, update_offsets=update_offsets)
if records:
return records

if inner_timeout_ms() <= 0:
elif timer.expired:
break

return {}

def _poll_once(self, timeout_ms, max_records, update_offsets=True):
def _poll_once(self, timer, max_records, update_offsets=True):
"""Do one round of polling. In addition to checking for new data, this does
any needed heart-beating, auto-commits, and offset updates.

Arguments:
timeout_ms (int): The maximum time in milliseconds to block.
timer (Timer): Timer bounding the maximum time to block.

Returns:
dict: Map of topic to list of records (may be empty).
"""
inner_timeout_ms = timeout_ms_fn(timeout_ms, None)
if not self._coordinator.poll(timeout_ms=inner_timeout_ms()):
if not self._coordinator.poll(timeout_ms=timer.timeout_ms):
return {}

has_all_fetch_positions = self._update_fetch_positions(timeout_ms=inner_timeout_ms())
has_all_fetch_positions = self._update_fetch_positions(timeout_ms=timer.timeout_ms)

# If data is available already, e.g. from a previous network client
# poll() call to commit, then just return it immediately
records, partial = self._fetcher.fetched_records(max_records, update_offsets=update_offsets)
log.debug('Fetched records: %s, %s', records, partial)
# Before returning the fetched records, we can send off the
# next round of fetches and avoid block waiting for their
# responses to enable pipelining while the user is handling the
# fetched records.
if not partial:
log.debug("Sending fetches")
futures = self._fetcher.send_fetches()
if len(futures):
self._client.poll(timeout_ms=0)
@@ -723,7 +722,7 @@ def _poll_once(self, timeout_ms, max_records, update_offsets=True):

# We do not want to be stuck blocking in poll if we are missing some positions
# since the offset lookup may be backing off after a failure
poll_timeout_ms = inner_timeout_ms(self._coordinator.time_to_next_poll() * 1000)
poll_timeout_ms = min(timer.timeout_ms, self._coordinator.time_to_next_poll() * 1000)
if not has_all_fetch_positions:
poll_timeout_ms = min(poll_timeout_ms, self.config['retry_backoff_ms'])

@@ -749,15 +748,14 @@ def position(self, partition, timeout_ms=None):
raise TypeError('partition must be a TopicPartition namedtuple')
assert self._subscription.is_assigned(partition), 'Partition is not assigned'

inner_timeout_ms = timeout_ms_fn(timeout_ms, 'Timeout retrieving partition position')
timer = Timer(timeout_ms)
position = self._subscription.assignment[partition].position
try:
while position is None:
# batch update fetch positions for any partitions without a valid position
self._update_fetch_positions(timeout_ms=inner_timeout_ms())
while position is None:
# batch update fetch positions for any partitions without a valid position
if self._update_fetch_positions(timeout_ms=timer.timeout_ms):
position = self._subscription.assignment[partition].position
except KafkaTimeoutError:
return None
elif timer.expired:
return None
else:
return position.offset

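
With _poll_once driven by a single Timer, KafkaConsumer.poll(timeout_ms=0) now makes exactly one pass and returns whatever records have already been fetched. A hedged usage example; the broker address and topic name are placeholders:

from kafka import KafkaConsumer

consumer = KafkaConsumer('my-topic', bootstrap_servers='localhost:9092')

# A zero timeout completes one round of polling and returns immediately,
# rather than blocking inside coordinator or fetch-position waits.
records = consumer.poll(timeout_ms=0)
for tp, messages in records.items():
    for message in messages:
        print(tp.topic, tp.partition, message.offset, message.value)
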
62 changes: 42 additions & 20 deletions kafka/coordinator/base.py
@@ -16,7 +16,7 @@
from kafka.metrics.stats import Avg, Count, Max, Rate
from kafka.protocol.find_coordinator import FindCoordinatorRequest
from kafka.protocol.group import HeartbeatRequest, JoinGroupRequest, LeaveGroupRequest, SyncGroupRequest, DEFAULT_GENERATION_ID, UNKNOWN_MEMBER_ID
from kafka.util import timeout_ms_fn
from kafka.util import Timer

log = logging.getLogger('kafka.coordinator')

@@ -256,9 +256,9 @@ def ensure_coordinator_ready(self, timeout_ms=None):
timeout_ms (numeric, optional): Maximum number of milliseconds to
block waiting to find coordinator. Default: None.

Raises: KafkaTimeoutError if timeout_ms is not None
Returns: True if coordinator found before timeout_ms, else False
"""
inner_timeout_ms = timeout_ms_fn(timeout_ms, 'Timeout attempting to find group coordinator')
timer = Timer(timeout_ms)
with self._client._lock, self._lock:
while self.coordinator_unknown():

@@ -272,27 +272,37 @@
else:
self.coordinator_id = maybe_coordinator_id
self._client.maybe_connect(self.coordinator_id)
continue
if timer.expired:
return False
else:
continue
else:
future = self.lookup_coordinator()

self._client.poll(future=future, timeout_ms=inner_timeout_ms())
self._client.poll(future=future, timeout_ms=timer.timeout_ms)

if not future.is_done:
raise Errors.KafkaTimeoutError()
return False

if future.failed():
if future.retriable():
if getattr(future.exception, 'invalid_metadata', False):
log.debug('Requesting metadata for group coordinator request: %s', future.exception)
metadata_update = self._client.cluster.request_update()
self._client.poll(future=metadata_update, timeout_ms=inner_timeout_ms())
self._client.poll(future=metadata_update, timeout_ms=timer.timeout_ms)
if not metadata_update.is_done:
raise Errors.KafkaTimeoutError()
return False
else:
time.sleep(inner_timeout_ms(self.config['retry_backoff_ms']) / 1000)
if timeout_ms is None or timer.timeout_ms > self.config['retry_backoff_ms']:
time.sleep(self.config['retry_backoff_ms'] / 1000)
else:
time.sleep(timer.timeout_ms / 1000)
else:
raise future.exception # pylint: disable-msg=raising-bad-type
if timer.expired:
return False
else:
return True

def _reset_find_coordinator_future(self, result):
self._find_coordinator_future = None
@@ -407,21 +417,23 @@ def ensure_active_group(self, timeout_ms=None):
timeout_ms (numeric, optional): Maximum number of milliseconds to
block waiting to join group. Default: None.

Raises: KafkaTimeoutError if timeout_ms is not None
Returns: True if group initialized before timeout_ms, else False
"""
if self.config['api_version'] < (0, 9):
raise Errors.UnsupportedVersionError('Group Coordinator APIs require 0.9+ broker')
inner_timeout_ms = timeout_ms_fn(timeout_ms, 'Timeout attempting to join consumer group')
self.ensure_coordinator_ready(timeout_ms=inner_timeout_ms())
timer = Timer(timeout_ms)
if not self.ensure_coordinator_ready(timeout_ms=timer.timeout_ms):
return False
self._start_heartbeat_thread()
self.join_group(timeout_ms=inner_timeout_ms())
return self.join_group(timeout_ms=timer.timeout_ms)

def join_group(self, timeout_ms=None):
if self.config['api_version'] < (0, 9):
raise Errors.UnsupportedVersionError('Group Coordinator APIs require 0.9+ broker')
inner_timeout_ms = timeout_ms_fn(timeout_ms, 'Timeout attempting to join consumer group')
timer = Timer(timeout_ms)
while self.need_rejoin():
self.ensure_coordinator_ready(timeout_ms=inner_timeout_ms())
if not self.ensure_coordinator_ready(timeout_ms=timer.timeout_ms):
return False

# call on_join_prepare if needed. We set a flag
# to make sure that we do not call it a second
@@ -434,7 +446,7 @@ def join_group(self, timeout_ms=None):
if not self.rejoining:
self._on_join_prepare(self._generation.generation_id,
self._generation.member_id,
timeout_ms=inner_timeout_ms())
timeout_ms=timer.timeout_ms)
self.rejoining = True

# fence off the heartbeat thread explicitly so that it cannot
@@ -449,16 +461,19 @@
while not self.coordinator_unknown():
if not self._client.in_flight_request_count(self.coordinator_id):
break
self._client.poll(timeout_ms=inner_timeout_ms(200))
poll_timeout_ms = 200 if timer.timeout_ms is None or timer.timeout_ms > 200 else timer.timeout_ms
self._client.poll(timeout_ms=poll_timeout_ms)
if timer.expired:
return False
else:
continue

future = self._initiate_join_group()
self._client.poll(future=future, timeout_ms=inner_timeout_ms())
self._client.poll(future=future, timeout_ms=timer.timeout_ms)
if future.is_done:
self._reset_join_group_future()
else:
raise Errors.KafkaTimeoutError()
return False

if future.succeeded():
self.rejoining = False
@@ -467,6 +482,7 @@
self._generation.member_id,
self._generation.protocol,
future.value)
return True
else:
exception = future.exception
if isinstance(exception, (Errors.UnknownMemberIdError,
@@ -476,7 +492,13 @@
continue
elif not future.retriable():
raise exception # pylint: disable-msg=raising-bad-type
time.sleep(inner_timeout_ms(self.config['retry_backoff_ms']) / 1000)
elif timer.expired:
return False
else:
if timer.timeout_ms is None or timer.timeout_ms > self.config['retry_backoff_ms']:
time.sleep(self.config['retry_backoff_ms'] / 1000)
else:
time.sleep(timer.timeout_ms / 1000)

def _send_join_group_request(self):
"""Join the group and return the assignment for the next generation.
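
Note the contract change in this file: ensure_coordinator_ready, ensure_active_group, and join_group now report an elapsed timeout by returning False instead of raising KafkaTimeoutError. A sketch of the new calling pattern (hypothetical caller code; coordinator stands in for a coordinator instance):

# Old contract: wrap the call in try/except Errors.KafkaTimeoutError.
# New contract: branch on the boolean result.
if not coordinator.ensure_active_group(timeout_ms=1000):
    # Timed out before the group was ready; retry, back off, or give up
    handle_timeout()  # hypothetical application-level handler
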