Skip to content

Commit 08400df

Browse files
committed
Fix some race conditions in tests
* `SftpSessionFactoryTests.concurrentGetSessionDoesntCauseFailure()` may not report properly into an `ArrayList` from another thread. Use `asyncTaskExecutor.submitCompletable()` instead and deal with their result afterward. * Use `CompletableFuture` for the `TcpListener` logic in the `FailoverClientConnectionFactoryTests.failoverAllDeadAfterSuccess()`
1 parent d73ded5 commit 08400df

File tree

3 files changed

+74
-71
lines changed

3 files changed

+74
-71
lines changed

spring-integration-ip/src/main/java/org/springframework/integration/ip/tcp/connection/FailoverClientConnectionFactory.java

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2023 the original author or authors.
2+
* Copyright 2002-2025 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -35,11 +35,12 @@
3535

3636
/**
3737
* Given a list of connection factories, serves up {@link TcpConnection}s
38-
* that can iterate over a connection from each factory until the write
38+
* that can iterate over a connection from each factory until the {@code write}
3939
* succeeds or the list is exhausted.
4040
*
4141
* @author Gary Russell
4242
* @author Christian Tzolov
43+
* @author Artem Bilan
4344
*
4445
* @since 2.2
4546
*
@@ -163,8 +164,9 @@ protected TcpConnectionSupport obtainConnection() throws InterruptedException {
163164
return sharedConnection;
164165
}
165166
FailoverTcpConnection failoverTcpConnection = new FailoverTcpConnection(this.factories);
166-
if (getListener() != null) {
167-
failoverTcpConnection.registerListener(getListener());
167+
TcpListener listener = getListener();
168+
if (listener != null) {
169+
failoverTcpConnection.registerListener(listener);
168170
}
169171
failoverTcpConnection.incrementEpoch();
170172
if (shared) {
@@ -286,9 +288,7 @@ private void findAConnection() throws InterruptedException {
286288
}
287289
catch (RuntimeException e) {
288290
if (logger.isDebugEnabled()) {
289-
logger.debug(nextFactory + " failed with "
290-
+ e.toString()
291-
+ ", trying another");
291+
logger.debug(nextFactory + " failed with " + e + ", trying another");
292292
}
293293
if (restartedList && (lastFactoryToTry == null || lastFactoryToTry.equals(nextFactory))) {
294294
logger.debug("Failover failed to find a connection");

spring-integration-ip/src/test/java/org/springframework/integration/ip/tcp/connection/FailoverClientConnectionFactoryTests.java

Lines changed: 50 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2024 the original author or authors.
2+
* Copyright 2002-2025 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -25,6 +25,7 @@
2525
import java.nio.channels.SocketChannel;
2626
import java.util.ArrayList;
2727
import java.util.List;
28+
import java.util.concurrent.CompletableFuture;
2829
import java.util.concurrent.CountDownLatch;
2930
import java.util.concurrent.Executor;
3031
import java.util.concurrent.TimeUnit;
@@ -39,8 +40,6 @@
3940
import org.mockito.Mockito;
4041

4142
import org.springframework.beans.factory.BeanFactory;
42-
import org.springframework.context.ApplicationEvent;
43-
import org.springframework.context.ApplicationEventPublisher;
4443
import org.springframework.core.task.SimpleAsyncTaskExecutor;
4544
import org.springframework.integration.channel.DirectChannel;
4645
import org.springframework.integration.channel.QueueChannel;
@@ -79,34 +78,21 @@
7978
*/
8079
public class FailoverClientConnectionFactoryTests {
8180

82-
private static final ApplicationEventPublisher NULL_PUBLISHER = new ApplicationEventPublisher() {
83-
84-
@Override
85-
public void publishEvent(ApplicationEvent event) {
86-
}
87-
88-
@Override
89-
public void publishEvent(Object event) {
90-
91-
}
92-
93-
};
94-
9581
@Test
9682
public void testFailoverGood() throws Exception {
9783
TcpConnectionSupport conn1 = makeMockConnection();
9884
TcpConnectionSupport conn2 = makeMockConnection();
9985
AbstractClientConnectionFactory factory1 = createFactoryWithMockConnection(conn1);
10086
AbstractClientConnectionFactory factory2 = createFactoryWithMockConnection(conn2);
101-
List<AbstractClientConnectionFactory> factories = new ArrayList<AbstractClientConnectionFactory>();
87+
List<AbstractClientConnectionFactory> factories = new ArrayList<>();
10288
factories.add(factory1);
10389
factories.add(factory2);
10490
doThrow(new UncheckedIOException(new IOException("fail")))
10591
.when(conn1).send(Mockito.any(Message.class));
10692
doAnswer(invocation -> null).when(conn2).send(Mockito.any(Message.class));
10793
FailoverClientConnectionFactory failoverFactory = new FailoverClientConnectionFactory(factories);
10894
failoverFactory.start();
109-
GenericMessage<String> message = new GenericMessage<String>("foo");
95+
GenericMessage<String> message = new GenericMessage<>("foo");
11096
failoverFactory.getConnection().send(message);
11197
Mockito.verify(conn2).send(message);
11298
}
@@ -129,7 +115,7 @@ public void testRefreshSharedInfinite() throws Exception {
129115
private void testRefreshShared(boolean closeOnRefresh, long interval) throws Exception {
130116
AbstractClientConnectionFactory factory1 = mock(AbstractClientConnectionFactory.class);
131117
AbstractClientConnectionFactory factory2 = mock(AbstractClientConnectionFactory.class);
132-
List<AbstractClientConnectionFactory> factories = new ArrayList<AbstractClientConnectionFactory>();
118+
List<AbstractClientConnectionFactory> factories = new ArrayList<>();
133119
factories.add(factory1);
134120
factories.add(factory2);
135121
TcpConnectionSupport conn1 = makeMockConnection();
@@ -182,7 +168,7 @@ public void testFailoverAllDead() throws Exception {
182168
TcpConnectionSupport conn2 = makeMockConnection();
183169
AbstractClientConnectionFactory factory1 = createFactoryWithMockConnection(conn1);
184170
AbstractClientConnectionFactory factory2 = createFactoryWithMockConnection(conn2);
185-
List<AbstractClientConnectionFactory> factories = new ArrayList<AbstractClientConnectionFactory>();
171+
List<AbstractClientConnectionFactory> factories = new ArrayList<>();
186172
factories.add(factory1);
187173
factories.add(factory2);
188174
doThrow(new UncheckedIOException(new IOException("fail")))
@@ -191,7 +177,7 @@ public void testFailoverAllDead() throws Exception {
191177
.when(conn2).send(Mockito.any(Message.class));
192178
FailoverClientConnectionFactory failoverFactory = new FailoverClientConnectionFactory(factories);
193179
failoverFactory.start();
194-
GenericMessage<String> message = new GenericMessage<String>("foo");
180+
GenericMessage<String> message = new GenericMessage<>("foo");
195181
assertThatExceptionOfType(UncheckedIOException.class).isThrownBy(() ->
196182
failoverFactory.getConnection().send(message));
197183
Mockito.verify(conn2).send(message);
@@ -214,7 +200,7 @@ void failoverAllDeadAfterSuccess() throws Exception {
214200
TcpNetClientConnectionFactory cf1 = new TcpNetClientConnectionFactory("localhost", ss1.getLocalPort());
215201
AbstractClientConnectionFactory cf2 = mock(AbstractClientConnectionFactory.class);
216202
doThrow(new UncheckedIOException(new IOException("fail"))).when(cf2).getConnection();
217-
CountDownLatch latch = new CountDownLatch(2);
203+
CountDownLatch latch = new CountDownLatch(1);
218204
cf1.setApplicationEventPublisher(event -> {
219205
if (event instanceof TcpConnectionCloseEvent) {
220206
latch.countDown();
@@ -223,12 +209,16 @@ void failoverAllDeadAfterSuccess() throws Exception {
223209
cf2.setApplicationEventPublisher(event -> {
224210
});
225211
FailoverClientConnectionFactory fccf = new FailoverClientConnectionFactory(List.of(cf1, cf2));
226-
fccf.registerListener(msf -> {
227-
latch.countDown();
228-
return false;
229-
});
212+
213+
CompletableFuture<Message<?>> messageCompletableFuture = new CompletableFuture<>();
214+
fccf.registerListener(messageCompletableFuture::complete);
215+
230216
fccf.start();
231217
fccf.getConnection().send(new GenericMessage<>("test"));
218+
assertThat(messageCompletableFuture)
219+
.succeedsWithin(10, TimeUnit.SECONDS)
220+
.extracting(Message::getPayload)
221+
.isEqualTo("ok".getBytes());
232222
assertThat(latch.await(10, TimeUnit.SECONDS)).isTrue();
233223
assertThatExceptionOfType(UncheckedIOException.class).isThrownBy(() ->
234224
fccf.getConnection().send(new GenericMessage<>("test")));
@@ -240,7 +230,7 @@ public void testFailoverAllDeadButOriginalOkAgain() throws Exception {
240230
TcpConnectionSupport conn2 = makeMockConnection();
241231
AbstractClientConnectionFactory factory1 = createFactoryWithMockConnection(conn1);
242232
AbstractClientConnectionFactory factory2 = createFactoryWithMockConnection(conn2);
243-
List<AbstractClientConnectionFactory> factories = new ArrayList<AbstractClientConnectionFactory>();
233+
List<AbstractClientConnectionFactory> factories = new ArrayList<>();
244234
factories.add(factory1);
245235
factories.add(factory2);
246236
final AtomicBoolean failedOnce = new AtomicBoolean();
@@ -255,7 +245,7 @@ public void testFailoverAllDeadButOriginalOkAgain() throws Exception {
255245
.when(conn2).send(Mockito.any(Message.class));
256246
FailoverClientConnectionFactory failoverFactory = new FailoverClientConnectionFactory(factories);
257247
failoverFactory.start();
258-
GenericMessage<String> message = new GenericMessage<String>("foo");
248+
GenericMessage<String> message = new GenericMessage<>("foo");
259249
failoverFactory.getConnection().send(message);
260250
Mockito.verify(conn2).send(message);
261251
Mockito.verify(conn1, times(2)).send(message);
@@ -265,7 +255,7 @@ public void testFailoverAllDeadButOriginalOkAgain() throws Exception {
265255
public void testFailoverConnectNone() throws Exception {
266256
AbstractClientConnectionFactory factory1 = mock(AbstractClientConnectionFactory.class);
267257
AbstractClientConnectionFactory factory2 = mock(AbstractClientConnectionFactory.class);
268-
List<AbstractClientConnectionFactory> factories = new ArrayList<AbstractClientConnectionFactory>();
258+
List<AbstractClientConnectionFactory> factories = new ArrayList<>();
269259
factories.add(factory1);
270260
factories.add(factory2);
271261
when(factory1.getConnection()).thenThrow(new UncheckedIOException(new IOException("fail")));
@@ -274,7 +264,7 @@ public void testFailoverConnectNone() throws Exception {
274264
when(factory2.isActive()).thenReturn(true);
275265
FailoverClientConnectionFactory failoverFactory = new FailoverClientConnectionFactory(factories);
276266
failoverFactory.start();
277-
GenericMessage<String> message = new GenericMessage<String>("foo");
267+
GenericMessage<String> message = new GenericMessage<>("foo");
278268
assertThatExceptionOfType(UncheckedIOException.class).isThrownBy(() ->
279269
failoverFactory.getConnection().send(message));
280270
}
@@ -283,7 +273,7 @@ public void testFailoverConnectNone() throws Exception {
283273
public void testFailoverConnectToFirstAfterTriedAll() throws Exception {
284274
AbstractClientConnectionFactory factory1 = mock(AbstractClientConnectionFactory.class);
285275
AbstractClientConnectionFactory factory2 = mock(AbstractClientConnectionFactory.class);
286-
List<AbstractClientConnectionFactory> factories = new ArrayList<AbstractClientConnectionFactory>();
276+
List<AbstractClientConnectionFactory> factories = new ArrayList<>();
287277
factories.add(factory1);
288278
factories.add(factory2);
289279
TcpConnectionSupport conn1 = makeMockConnection();
@@ -308,7 +298,7 @@ public void testOkAgainAfterCompleteFailure() throws Exception {
308298
TcpConnectionSupport conn2 = makeMockConnection();
309299
AbstractClientConnectionFactory factory1 = createFactoryWithMockConnection(conn1);
310300
AbstractClientConnectionFactory factory2 = createFactoryWithMockConnection(conn2);
311-
List<AbstractClientConnectionFactory> factories = new ArrayList<AbstractClientConnectionFactory>();
301+
List<AbstractClientConnectionFactory> factories = new ArrayList<>();
312302
factories.add(factory1);
313303
factories.add(factory2);
314304
final AtomicInteger failCount = new AtomicInteger();
@@ -322,7 +312,7 @@ public void testOkAgainAfterCompleteFailure() throws Exception {
322312
.when(conn2).send(Mockito.any(Message.class));
323313
FailoverClientConnectionFactory failoverFactory = new FailoverClientConnectionFactory(factories);
324314
failoverFactory.start();
325-
GenericMessage<String> message = new GenericMessage<String>("foo");
315+
GenericMessage<String> message = new GenericMessage<>("foo");
326316
assertThatExceptionOfType(UncheckedIOException.class)
327317
.isThrownBy(() -> failoverFactory.getConnection().send(message));
328318
failoverFactory.getConnection().send(message);
@@ -426,27 +416,27 @@ public void testFailoverCachedRealClose() throws Exception {
426416
cachingFactory2.setBeanName("cache2");
427417

428418
// Failover
429-
List<AbstractClientConnectionFactory> factories = new ArrayList<AbstractClientConnectionFactory>();
419+
List<AbstractClientConnectionFactory> factories = new ArrayList<>();
430420
factories.add(cachingFactory1);
431421
factories.add(cachingFactory2);
432422
FailoverClientConnectionFactory failoverFactory = new FailoverClientConnectionFactory(factories);
433423

434424
failoverFactory.start();
435425
TcpConnection conn1 = failoverFactory.getConnection();
436-
conn1.send(new GenericMessage<String>("foo1"));
426+
conn1.send(new GenericMessage<>("foo1"));
437427
conn1.close();
438428
TcpConnection conn2 = failoverFactory.getConnection();
439429
assertThat((TestUtils.getPropertyValue(conn2, "delegate", TcpConnectionInterceptorSupport.class))
440430
.getTheConnection())
441431
.isSameAs((TestUtils.getPropertyValue(conn1, "delegate", TcpConnectionInterceptorSupport.class))
442432
.getTheConnection());
443-
conn2.send(new GenericMessage<String>("foo2"));
433+
conn2.send(new GenericMessage<>("foo2"));
444434
conn1 = failoverFactory.getConnection();
445435
assertThat((TestUtils.getPropertyValue(conn2, "delegate", TcpConnectionInterceptorSupport.class))
446436
.getTheConnection())
447437
.isNotSameAs((TestUtils.getPropertyValue(conn1, "delegate", TcpConnectionInterceptorSupport.class))
448438
.getTheConnection());
449-
conn1.send(new GenericMessage<String>("foo3"));
439+
conn1.send(new GenericMessage<>("foo3"));
450440
conn1.close();
451441
conn2.close();
452442
assertThat(latch1.await(10, TimeUnit.SECONDS)).isTrue();
@@ -455,8 +445,8 @@ public void testFailoverCachedRealClose() throws Exception {
455445
TestingUtilities.waitUntilFactoryHasThisNumberOfConnections(factory1, 0);
456446
conn1 = failoverFactory.getConnection();
457447
conn2 = failoverFactory.getConnection();
458-
conn1.send(new GenericMessage<String>("foo4"));
459-
conn2.send(new GenericMessage<String>("foo5"));
448+
conn1.send(new GenericMessage<>("foo4"));
449+
conn2.send(new GenericMessage<>("foo5"));
460450
conn1.close();
461451
conn2.close();
462452
assertThat(latch2.await(10, TimeUnit.SECONDS)).isTrue();
@@ -467,7 +457,7 @@ public void testFailoverCachedRealClose() throws Exception {
467457

468458
@SuppressWarnings("unchecked")
469459
@Test
470-
public void testFailoverCachedWithGateway() throws Exception {
460+
public void testFailoverCachedWithGateway() {
471461
final TcpNetServerConnectionFactory server = new TcpNetServerConnectionFactory(0);
472462
server.setBeanName("server");
473463
server.afterPropertiesSet();
@@ -490,7 +480,7 @@ public void testFailoverCachedWithGateway() throws Exception {
490480
cachingClient.afterPropertiesSet();
491481

492482
// Failover
493-
List<AbstractClientConnectionFactory> clientFactories = new ArrayList<AbstractClientConnectionFactory>();
483+
List<AbstractClientConnectionFactory> clientFactories = new ArrayList<>();
494484
clientFactories.add(cachingClient);
495485
FailoverClientConnectionFactory failoverClient = new FailoverClientConnectionFactory(clientFactories);
496486
failoverClient.setSingleUse(true);
@@ -505,13 +495,13 @@ public void testFailoverCachedWithGateway() throws Exception {
505495
outbound.afterPropertiesSet();
506496
outbound.start();
507497

508-
outbound.handleMessage(new GenericMessage<String>("foo"));
498+
outbound.handleMessage(new GenericMessage<>("foo"));
509499
Message<byte[]> result = (Message<byte[]>) replyChannel.receive(10000);
510500
assertThat(result).isNotNull();
511501
assertThat(new String(result.getPayload())).isEqualTo("foo");
512502

513503
// INT-4024 - second reply had bad connection id
514-
outbound.handleMessage(new GenericMessage<String>("foo"));
504+
outbound.handleMessage(new GenericMessage<>("foo"));
515505
result = (Message<byte[]>) replyChannel.receive(10000);
516506
assertThat(result).isNotNull();
517507
assertThat(new String(result.getPayload())).isEqualTo("foo");
@@ -557,13 +547,13 @@ public void testFailoverCachedRealBadHost() throws Exception {
557547
cachingFactory2.setBeanName("cache2");
558548

559549
// Failover
560-
List<AbstractClientConnectionFactory> factories = new ArrayList<AbstractClientConnectionFactory>();
550+
List<AbstractClientConnectionFactory> factories = new ArrayList<>();
561551
factories.add(cachingFactory1);
562552
factories.add(cachingFactory2);
563553
FailoverClientConnectionFactory failoverFactory = new FailoverClientConnectionFactory(factories);
564554
failoverFactory.start();
565555
TcpConnection conn1 = failoverFactory.getConnection();
566-
GenericMessage<String> message = new GenericMessage<String>("foo");
556+
GenericMessage<String> message = new GenericMessage<>("foo");
567557
conn1.send(message);
568558
conn1.close();
569559
TcpConnection conn2 = failoverFactory.getConnection();
@@ -595,9 +585,11 @@ private void testRealGuts(AbstractClientConnectionFactory client1, AbstractClien
595585
client2.setTaskExecutor(holder.exec);
596586
client1.setBeanName("client1");
597587
client2.setBeanName("client2");
598-
client1.setApplicationEventPublisher(NULL_PUBLISHER);
599-
client2.setApplicationEventPublisher(NULL_PUBLISHER);
600-
List<AbstractClientConnectionFactory> factories = new ArrayList<AbstractClientConnectionFactory>();
588+
client1.setApplicationEventPublisher(event -> {
589+
});
590+
client2.setApplicationEventPublisher(event -> {
591+
});
592+
List<AbstractClientConnectionFactory> factories = new ArrayList<>();
601593
factories.add(client1);
602594
factories.add(client2);
603595
FailoverClientConnectionFactory failFactory = new FailoverClientConnectionFactory(factories);
@@ -610,10 +602,10 @@ private void testRealGuts(AbstractClientConnectionFactory client1, AbstractClien
610602
outGateway.start();
611603
QueueChannel replyChannel = new QueueChannel();
612604
outGateway.setReplyChannel(replyChannel);
613-
Message<String> message = new GenericMessage<String>("foo");
605+
Message<String> message = new GenericMessage<>("foo");
614606
outGateway.setRemoteTimeout(120000);
615607
outGateway.handleMessage(message);
616-
Socket socket = null;
608+
Socket socket;
617609
if (!singleUse) {
618610
socket = getSocket(client1);
619611
port1 = socket.getLocalPort();
@@ -644,12 +636,14 @@ private Holder setupAndStartServers(AbstractServerConnectionFactory server1,
644636
server2.setTaskExecutor(exec);
645637
server1.setBeanName("server1");
646638
server2.setBeanName("server2");
647-
server1.setApplicationEventPublisher(NULL_PUBLISHER);
648-
server2.setApplicationEventPublisher(NULL_PUBLISHER);
639+
server1.setApplicationEventPublisher(event -> {
640+
});
641+
server2.setApplicationEventPublisher(event -> {
642+
});
649643
TcpInboundGateway gateway1 = new TcpInboundGateway();
650644
gateway1.setConnectionFactory(server1);
651645
SubscribableChannel channel = new DirectChannel();
652-
final AtomicReference<String> connectionId = new AtomicReference<String>();
646+
final AtomicReference<String> connectionId = new AtomicReference<>();
653647
channel.subscribe(message -> {
654648
connectionId.set((String) message.getHeaders().get(IpHeaders.CONNECTION_ID));
655649
((MessageChannel) message.getHeaders().getReplyChannel()).send(message);
@@ -695,7 +689,9 @@ private static class Holder {
695689

696690
}
697691

698-
private static AbstractClientConnectionFactory createFactoryWithMockConnection(TcpConnectionSupport mockConn) throws Exception {
692+
private static AbstractClientConnectionFactory createFactoryWithMockConnection(TcpConnectionSupport mockConn)
693+
throws Exception {
694+
699695
AbstractClientConnectionFactory factory = mock(AbstractClientConnectionFactory.class);
700696
when(factory.getConnection()).thenReturn(mockConn);
701697
when(factory.isActive()).thenReturn(true);

0 commit comments

Comments
 (0)