Skip to content

Commit f22ba21

Browse files
committed
GH-2302: Enable consumer seek only on matching group Id
1 parent 4de29ee commit f22ba21

File tree

4 files changed

+87
-34
lines changed

4 files changed

+87
-34
lines changed

spring-kafka/src/main/java/org/springframework/kafka/listener/AbstractConsumerSeekAware.java

Lines changed: 41 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2019-2023 the original author or authors.
2+
* Copyright 2019-2024 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -33,6 +33,7 @@
3333
* having to keep track of the callbacks itself.
3434
*
3535
* @author Gary Russell
36+
* @author Borahm Lee
3637
* @since 2.3
3738
*
3839
*/
@@ -46,43 +47,59 @@ public abstract class AbstractConsumerSeekAware implements ConsumerSeekAware {
4647

4748
@Override
4849
public void registerSeekCallback(ConsumerSeekCallback callback) {
49-
this.callbackForThread.put(Thread.currentThread(), callback);
50+
if (matchGroupId()) {
51+
this.callbackForThread.put(Thread.currentThread(), callback);
52+
}
5053
}
5154

5255
@Override
5356
public void onPartitionsAssigned(Map<TopicPartition, Long> assignments, ConsumerSeekCallback callback) {
54-
ConsumerSeekCallback threadCallback = this.callbackForThread.get(Thread.currentThread());
55-
if (threadCallback != null) {
56-
assignments.keySet().forEach(tp -> {
57-
this.callbacks.put(tp, threadCallback);
58-
this.callbacksToTopic.computeIfAbsent(threadCallback, key -> new LinkedList<>()).add(tp);
59-
});
57+
if (matchGroupId()) {
58+
ConsumerSeekCallback threadCallback = this.callbackForThread.get(Thread.currentThread());
59+
if (threadCallback != null) {
60+
assignments.keySet()
61+
.forEach(tp -> {
62+
this.callbacks.put(tp, threadCallback);
63+
this.callbacksToTopic.computeIfAbsent(threadCallback, key -> new LinkedList<>())
64+
.add(tp);
65+
});
66+
}
6067
}
6168
}
6269

6370
@Override
6471
public void onPartitionsRevoked(Collection<TopicPartition> partitions) {
65-
partitions.forEach(tp -> {
66-
ConsumerSeekCallback removed = this.callbacks.remove(tp);
67-
if (removed != null) {
68-
List<TopicPartition> topics = this.callbacksToTopic.get(removed);
69-
if (topics != null) {
70-
topics.remove(tp);
71-
if (topics.size() == 0) {
72-
this.callbacksToTopic.remove(removed);
72+
if (matchGroupId()) {
73+
partitions.forEach(tp -> {
74+
ConsumerSeekCallback removed = this.callbacks.remove(tp);
75+
if (removed != null) {
76+
List<TopicPartition> topics = this.callbacksToTopic.get(removed);
77+
if (topics != null) {
78+
topics.remove(tp);
79+
if (topics.size() == 0) {
80+
this.callbacksToTopic.remove(removed);
81+
}
7382
}
7483
}
75-
}
76-
});
84+
});
85+
}
7786
}
7887

7988
@Override
8089
public void unregisterSeekCallback() {
81-
this.callbackForThread.remove(Thread.currentThread());
90+
if (matchGroupId()) {
91+
this.callbackForThread.remove(Thread.currentThread());
92+
}
93+
}
94+
95+
@Override
96+
public boolean matchGroupId() {
97+
return true;
8298
}
8399

84100
/**
85101
* Return the callback for the specified topic/partition.
102+
*
86103
* @param topicPartition the topic/partition.
87104
* @return the callback (or null if there is no assignment).
88105
*/
@@ -93,6 +110,7 @@ protected ConsumerSeekCallback getSeekCallbackFor(TopicPartition topicPartition)
93110

94111
/**
95112
* The map of callbacks for all currently assigned partitions.
113+
*
96114
* @return the map.
97115
*/
98116
protected Map<TopicPartition, ConsumerSeekCallback> getSeekCallbacks() {
@@ -101,6 +119,7 @@ protected Map<TopicPartition, ConsumerSeekCallback> getSeekCallbacks() {
101119

102120
/**
103121
* Return the currently registered callbacks and their associated {@link TopicPartition}(s).
122+
*
104123
* @return the map of callbacks and partitions.
105124
* @since 2.6
106125
*/
@@ -110,6 +129,7 @@ protected Map<ConsumerSeekCallback, List<TopicPartition>> getCallbacksAndTopics(
110129

111130
/**
112131
* Seek all assigned partitions to the beginning.
132+
*
113133
* @since 2.6
114134
*/
115135
public void seekToBeginning() {
@@ -118,6 +138,7 @@ public void seekToBeginning() {
118138

119139
/**
120140
* Seek all assigned partitions to the end.
141+
*
121142
* @since 2.6
122143
*/
123144
public void seekToEnd() {
@@ -126,6 +147,7 @@ public void seekToEnd() {
126147

127148
/**
128149
* Seek all assigned partitions to the offset represented by the timestamp.
150+
*
129151
* @param time the time to seek to.
130152
* @since 2.6
131153
*/

spring-kafka/src/main/java/org/springframework/kafka/listener/ConsumerSeekAware.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
*
3030
* @author Gary Russell
3131
* @author Soby Chacko
32+
* @author Borahm Lee
3233
* @since 1.1
3334
*
3435
*/
@@ -88,6 +89,16 @@ default void onFirstPoll() {
8889
default void unregisterSeekCallback() {
8990
}
9091

92+
/**
93+
* Determine if the consumer group ID for seeking matches the expected value.
94+
*
95+
* @return true if the group ID matches, false otherwise.
96+
* @since 3.3
97+
*/
98+
default boolean matchGroupId() {
99+
return false;
100+
}
101+
91102
/**
92103
* A callback that a listener can invoke to seek to a specific offset.
93104
*/

spring-kafka/src/main/java/org/springframework/kafka/listener/KafkaMessageListenerContainer.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,7 @@
165165
* @author Raphael Rösch
166166
* @author Christian Mergenthaler
167167
* @author Mikael Carlstedt
168+
* @author Borahm Lee
168169
*/
169170
public class KafkaMessageListenerContainer<K, V> // NOSONAR line count
170171
extends AbstractMessageListenerContainer<K, V> implements ConsumerPauseResumeEventPublisher {
@@ -1362,8 +1363,8 @@ protected void initialize() {
13621363
}
13631364
publishConsumerStartingEvent();
13641365
this.consumerThread = Thread.currentThread();
1365-
setupSeeks();
13661366
KafkaUtils.setConsumerGroupId(this.consumerGroupId);
1367+
setupSeeks();
13671368
this.count = 0;
13681369
this.last = System.currentTimeMillis();
13691370
initAssignedPartitions();

spring-kafka/src/test/java/org/springframework/kafka/listener/ConsumerSeekAwareTests.java

Lines changed: 33 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -16,28 +16,27 @@
1616

1717
package org.springframework.kafka.listener;
1818

19-
import static org.assertj.core.api.Assertions.assertThat;
20-
import static org.mockito.Mockito.mock;
21-
import static org.mockito.Mockito.verify;
22-
2319
import java.util.Collections;
2420
import java.util.LinkedHashMap;
2521
import java.util.LinkedList;
2622
import java.util.Map;
2723
import java.util.concurrent.Callable;
24+
import java.util.concurrent.ExecutionException;
2825
import java.util.concurrent.Executors;
2926
import java.util.concurrent.atomic.AtomicBoolean;
30-
3127
import org.apache.kafka.common.TopicPartition;
3228
import org.junit.jupiter.api.Test;
33-
3429
import org.springframework.kafka.listener.ConsumerSeekAware.ConsumerSeekCallback;
3530
import org.springframework.kafka.test.utils.KafkaTestUtils;
3631

32+
import static org.assertj.core.api.Assertions.assertThat;
33+
import static org.mockito.Mockito.mock;
34+
import static org.mockito.Mockito.verify;
35+
3736
/**
3837
* @author Gary Russell
38+
* @author Borahm Lee
3939
* @since 2.6
40-
*
4140
*/
4241
public class ConsumerSeekAwareTests {
4342

@@ -51,16 +50,15 @@ class CSA extends AbstractConsumerSeekAware {
5150
var exec1 = Executors.newSingleThreadExecutor();
5251
var exec2 = Executors.newSingleThreadExecutor();
5352
var cb1 = mock(ConsumerSeekCallback.class);
54-
var cb2 = mock(ConsumerSeekCallback.class);
53+
var cb2 = mock(ConsumerSeekCallback.class);
5554
var first = new AtomicBoolean(true);
5655
var map1 = new LinkedHashMap<>(Map.of(new TopicPartition("foo", 0), 0L, new TopicPartition("foo", 1), 0L));
5756
var map2 = new LinkedHashMap<>(Map.of(new TopicPartition("foo", 2), 0L, new TopicPartition("foo", 3), 0L));
5857
var register = (Callable<Void>) () -> {
5958
if (first.getAndSet(false)) {
6059
csa.registerSeekCallback(cb1);
6160
csa.onPartitionsAssigned(map1, null);
62-
}
63-
else {
61+
} else {
6462
csa.registerSeekCallback(cb2);
6563
csa.onPartitionsAssigned(map2, null);
6664
}
@@ -80,8 +78,7 @@ class CSA extends AbstractConsumerSeekAware {
8078
var revoke1 = (Callable<Void>) () -> {
8179
if (!first.getAndSet(true)) {
8280
csa.onPartitionsRevoked(Collections.singletonList(map1.keySet().iterator().next()));
83-
}
84-
else {
81+
} else {
8582
csa.onPartitionsRevoked(Collections.singletonList(map2.keySet().iterator().next()));
8683
}
8784
return null;
@@ -96,8 +93,7 @@ class CSA extends AbstractConsumerSeekAware {
9693
var revoke2 = (Callable<Void>) () -> {
9794
if (first.getAndSet(false)) {
9895
csa.onPartitionsRevoked(Collections.singletonList(map1.keySet().iterator().next()));
99-
}
100-
else {
96+
} else {
10197
csa.onPartitionsRevoked(Collections.singletonList(map2.keySet().iterator().next()));
10298
}
10399
return null;
@@ -118,4 +114,27 @@ class CSA extends AbstractConsumerSeekAware {
118114
exec2.shutdown();
119115
}
120116

117+
@SuppressWarnings("unchecked")
118+
@Test
119+
void notMatchGroupId() throws ExecutionException, InterruptedException {
120+
class CSA extends AbstractConsumerSeekAware {
121+
@Override
122+
public boolean matchGroupId() {
123+
return false;
124+
}
125+
}
126+
127+
AbstractConsumerSeekAware csa = new CSA();
128+
var exec = Executors.newSingleThreadExecutor();
129+
var register = (Callable<Void>) () -> {
130+
csa.registerSeekCallback(mock(ConsumerSeekCallback.class));
131+
csa.onPartitionsAssigned(Map.of(new TopicPartition("baz", 0), 0L), null);
132+
return null;
133+
};
134+
exec.submit(register).get();
135+
assertThat(KafkaTestUtils.getPropertyValue(csa, "callbackForThread", Map.class)).isEmpty();
136+
assertThat(KafkaTestUtils.getPropertyValue(csa, "callbacks", Map.class)).isEmpty();
137+
assertThat(KafkaTestUtils.getPropertyValue(csa, "callbacksToTopic", Map.class)).isEmpty();
138+
exec.shutdown();
139+
}
121140
}

0 commit comments

Comments
 (0)