Skip to content

Add synchronized decorator; add lock to subscription state #2636

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jun 1, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 34 additions & 1 deletion kafka/consumer/subscription_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,15 @@
import logging
import random
import re
import threading
import time

from kafka.vendor import six

import kafka.errors as Errors
from kafka.protocol.list_offsets import OffsetResetStrategy
from kafka.structs import OffsetAndMetadata
from kafka.util import ensure_valid_topic_name
from kafka.util import ensure_valid_topic_name, synchronized

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -84,6 +85,7 @@ def __init__(self, offset_reset_strategy='earliest'):
self.assignment = OrderedDict()
self.rebalance_listener = None
self.listeners = []
self._lock = threading.RLock()

def _set_subscription_type(self, subscription_type):
if not isinstance(subscription_type, SubscriptionType):
Expand All @@ -93,6 +95,7 @@ def _set_subscription_type(self, subscription_type):
elif self.subscription_type != subscription_type:
raise Errors.IllegalStateError(self._SUBSCRIPTION_EXCEPTION_MESSAGE)

@synchronized
def subscribe(self, topics=(), pattern=None, listener=None):
"""Subscribe to a list of topics, or a topic regex pattern.

Expand Down Expand Up @@ -147,6 +150,7 @@ def subscribe(self, topics=(), pattern=None, listener=None):
raise TypeError('listener must be a ConsumerRebalanceListener')
self.rebalance_listener = listener

@synchronized
def change_subscription(self, topics):
"""Change the topic subscription.

Expand Down Expand Up @@ -178,6 +182,7 @@ def change_subscription(self, topics):
self.subscription = set(topics)
self._group_subscription.update(topics)

@synchronized
def group_subscribe(self, topics):
"""Add topics to the current group subscription.

Expand All @@ -191,13 +196,15 @@ def group_subscribe(self, topics):
raise Errors.IllegalStateError(self._SUBSCRIPTION_EXCEPTION_MESSAGE)
self._group_subscription.update(topics)

@synchronized
def reset_group_subscription(self):
    """Shrink the group subscription to the topics this consumer subscribes to.

    Raises:
        IllegalStateError: if partitions are not auto-assigned.
    """
    if self.partitions_auto_assigned():
        assert self.subscription is not None, 'Subscription required'
        # Keep only the group topics that are also in our own subscription.
        self._group_subscription.intersection_update(self.subscription)
        return
    raise Errors.IllegalStateError(self._SUBSCRIPTION_EXCEPTION_MESSAGE)

@synchronized
def assign_from_user(self, partitions):
"""Manually assign a list of TopicPartitions to this consumer.

Expand All @@ -222,6 +229,7 @@ def assign_from_user(self, partitions):
self._set_assignment({partition: self.assignment.get(partition, TopicPartitionState())
for partition in partitions})

@synchronized
def assign_from_subscribed(self, assignments):
"""Update the assignment to the specified partitions

Expand Down Expand Up @@ -258,6 +266,7 @@ def _set_assignment(self, partition_states, randomize=False):
for tp in topic_partitions[topic]:
self.assignment[tp] = partition_states[tp]

@synchronized
def unsubscribe(self):
"""Clear all topic subscriptions and partition assignments"""
self.subscription = None
Expand All @@ -266,6 +275,7 @@ def unsubscribe(self):
self.subscribed_pattern = None
self.subscription_type = SubscriptionType.NONE

@synchronized
def group_subscription(self):
"""Get the topic subscription for the group.

Expand All @@ -281,6 +291,7 @@ def group_subscription(self):
"""
return self._group_subscription

@synchronized
def seek(self, partition, offset):
"""Manually specify the fetch offset for a TopicPartition.

Expand All @@ -298,15 +309,18 @@ def seek(self, partition, offset):
raise TypeError("offset must be type in or OffsetAndMetadata")
self.assignment[partition].seek(offset)

@synchronized
def assigned_partitions(self):
    """Return a new set holding every TopicPartition key of the assignment."""
    # Iterating the assignment mapping yields its keys; the set is a copy.
    return set(self.assignment)

@synchronized
def paused_partitions(self):
    """Return the set of assigned TopicPartitions that are currently paused."""
    paused = set()
    for tp in self.assignment:
        if self.is_paused(tp):
            paused.add(tp)
    return paused

@synchronized
def fetchable_partitions(self):
"""Return ordered list of TopicPartitions that should be Fetched."""
fetchable = list()
Expand All @@ -315,10 +329,12 @@ def fetchable_partitions(self):
fetchable.append(partition)
return fetchable

@synchronized
def partitions_auto_assigned(self):
    """Return True when the subscription type is AUTO_TOPICS or AUTO_PATTERN."""
    auto_types = (SubscriptionType.AUTO_TOPICS, SubscriptionType.AUTO_PATTERN)
    return self.subscription_type in auto_types

@synchronized
def all_consumed_offsets(self):
"""Returns consumed offsets as {TopicPartition: OffsetAndMetadata}"""
all_consumed = {}
Expand All @@ -327,6 +343,7 @@ def all_consumed_offsets(self):
all_consumed[partition] = state.position
return all_consumed

@synchronized
def request_offset_reset(self, partition, offset_reset_strategy=None):
"""Mark partition for offset reset using specified or default strategy.

Expand All @@ -338,33 +355,40 @@ def request_offset_reset(self, partition, offset_reset_strategy=None):
offset_reset_strategy = self._default_offset_reset_strategy
self.assignment[partition].reset(offset_reset_strategy)

@synchronized
def set_reset_pending(self, partitions, next_allowed_reset_time):
    """Call ``set_reset_pending`` on the state of each given partition.

    Arguments:
        partitions: iterable of partitions; each is looked up in
            ``self.assignment`` by subscript.
        next_allowed_reset_time: value forwarded unchanged to each
            partition state.
    """
    states = self.assignment
    for tp in partitions:
        states[tp].set_reset_pending(next_allowed_reset_time)

@synchronized
def has_default_offset_reset_policy(self):
    """Return True if the default reset strategy is not OffsetResetStrategy.NONE."""
    default_strategy = self._default_offset_reset_strategy
    return default_strategy != OffsetResetStrategy.NONE

@synchronized
def is_offset_reset_needed(self, partition):
    """Return the ``awaiting_reset`` flag of the state assigned to ``partition``."""
    state = self.assignment[partition]
    return state.awaiting_reset

@synchronized
def has_all_fetch_positions(self):
    """Return True if every assigned partition state has a valid position.

    An empty assignment yields True.
    """
    return all(state.has_valid_position
               for state in six.itervalues(self.assignment))

@synchronized
def missing_fetch_positions(self):
    """Return the set of partitions whose state reports a missing position."""
    return {tp
            for tp, state in six.iteritems(self.assignment)
            if state.is_missing_position()}

@synchronized
def has_valid_position(self, partition):
    """Return whether ``partition`` is assigned and its state has a valid position.

    Unassigned partitions yield False; otherwise the state's
    ``has_valid_position`` value is returned as-is.
    """
    if partition not in self.assignment:
        return False
    return self.assignment[partition].has_valid_position

@synchronized
def reset_missing_positions(self):
partitions_with_no_offsets = set()
for tp, state in six.iteritems(self.assignment):
Expand All @@ -377,32 +401,40 @@ def reset_missing_positions(self):
if partitions_with_no_offsets:
raise Errors.NoOffsetForPartitionError(partitions_with_no_offsets)

@synchronized
def partitions_needing_reset(self):
    """Return partitions that await a reset and are currently allowed to reset."""
    return {tp
            for tp, state in six.iteritems(self.assignment)
            if state.awaiting_reset and state.is_reset_allowed()}

@synchronized
def is_assigned(self, partition):
    """Return True if ``partition`` is a key of the current assignment."""
    current = self.assignment
    return partition in current

@synchronized
def is_paused(self, partition):
    """Return whether ``partition`` is assigned and its state is paused.

    Unassigned partitions yield False; otherwise the state's ``paused``
    value is returned as-is.
    """
    if partition not in self.assignment:
        return False
    return self.assignment[partition].paused

@synchronized
def is_fetchable(self, partition):
    """Return whether ``partition`` is assigned and its state is fetchable.

    Unassigned partitions yield False; otherwise the result of the state's
    ``is_fetchable()`` is returned as-is.
    """
    if partition not in self.assignment:
        return False
    return self.assignment[partition].is_fetchable()

@synchronized
def pause(self, partition):
    """Call ``pause()`` on the state assigned to ``partition``.

    The partition is looked up by subscript, so it must be present in
    ``self.assignment``.
    """
    state = self.assignment[partition]
    state.pause()

@synchronized
def resume(self, partition):
    """Call ``resume()`` on the state assigned to ``partition``.

    The partition is looked up by subscript, so it must be present in
    ``self.assignment``.
    """
    state = self.assignment[partition]
    state.resume()

@synchronized
def reset_failed(self, partitions, next_retry_time):
    """Call ``reset_failed`` on the state of each given partition.

    Arguments:
        partitions: iterable of partitions; each is looked up in
            ``self.assignment`` by subscript.
        next_retry_time: value forwarded unchanged to each partition state.
    """
    states = self.assignment
    for tp in partitions:
        states[tp].reset_failed(next_retry_time)

@synchronized
def move_partition_to_end(self, partition):
if partition in self.assignment:
try:
Expand All @@ -411,6 +443,7 @@ def move_partition_to_end(self, partition):
state = self.assignment.pop(partition)
self.assignment[partition] = state

@synchronized
def position(self, partition):
    """Return the ``position`` attribute of the state assigned to ``partition``."""
    state = self.assignment[partition]
    return state.position

Expand Down
9 changes: 9 additions & 0 deletions kafka/util.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import absolute_import, division

import binascii
import functools
import re
import time
import weakref
Expand Down Expand Up @@ -129,3 +130,11 @@ class Dict(dict):
See: https://docs.python.org/2/library/weakref.html
"""
pass


def synchronized(func):
    """Decorate a method so that it runs while holding ``self._lock``.

    The instance the method is bound to must expose a ``_lock`` attribute
    usable as a context manager (for example ``threading.RLock()``). A
    reentrant lock is required if decorated methods call one another on
    the same instance.

    Arguments:
        func: the method to wrap; its first positional argument is the
            instance.

    Returns:
        A wrapper with the metadata of ``func`` (name, docstring, ...)
        that acquires ``self._lock``, calls ``func`` and returns its
        result. The lock is released even if ``func`` raises.
    """
    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        with self._lock:
            return func(self, *args, **kwargs)
    return wrapper
Loading