Skip to content

Commit b02a5f9

Browse files
authored
Fix regression in timeout handling. (#1373)
- Resolve regression where CSOT exception is exposed despite CSOT being disabled. - Correct premature decrease in connect timeout before connection initiation. - Encapsulate logic within TimeoutContext. JAVA-5439
1 parent 997e92f commit b02a5f9

File tree

7 files changed

+100
-58
lines changed

7 files changed

+100
-58
lines changed

driver-core/src/main/com/mongodb/internal/TimeoutContext.java

Lines changed: 26 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -185,25 +185,27 @@ public long getMaxAwaitTimeMS() {
185185
return timeoutSettings.getMaxAwaitTimeMS();
186186
}
187187

188-
public void runMaxTimeMSTimeout(final Runnable onInfinite, final LongConsumer onRemaining,
189-
final Runnable onExpired) {
188+
public void runMaxTimeMS(final LongConsumer onRemaining) {
190189
if (maxTimeSupplier != null) {
191-
runWithFixedTimout(maxTimeSupplier.get(), onInfinite, onRemaining);
190+
runWithFixedTimeout(maxTimeSupplier.get(), onRemaining);
192191
return;
193192
}
194-
195-
if (timeout != null) {
196-
timeout.shortenBy(minRoundTripTimeMS, MILLISECONDS)
197-
.run(MILLISECONDS, onInfinite, onRemaining, onExpired);
198-
} else {
199-
runWithFixedTimout(timeoutSettings.getMaxTimeMS(), onInfinite, onRemaining);
193+
if (timeout == null) {
194+
runWithFixedTimeout(timeoutSettings.getMaxTimeMS(), onRemaining);
195+
return;
200196
}
197+
timeout.shortenBy(minRoundTripTimeMS, MILLISECONDS)
198+
.run(MILLISECONDS,
199+
() -> {},
200+
onRemaining,
201+
() -> {
202+
throw createMongoRoundTripTimeoutException();
203+
});
204+
201205
}
202206

203-
private static void runWithFixedTimout(final long ms, final Runnable onInfinite, final LongConsumer onRemaining) {
204-
if (ms == 0) {
205-
onInfinite.run();
206-
} else {
207+
private static void runWithFixedTimeout(final long ms, final LongConsumer onRemaining) {
208+
if (ms != 0) {
207209
onRemaining.accept(ms);
208210
}
209211
}
@@ -214,15 +216,18 @@ public void resetToDefaultMaxTime() {
214216

215217
/**
216218
* The override will be provided as the remaining value in
217-
* {@link #runMaxTimeMSTimeout}, where 0 will invoke the onExpired path
219+
* {@link #runMaxTimeMS}, where 0 is ignored.
220+
* <p>
221+
* NOTE: Suitable for static user-defined values only (i.e MaxAwaitTimeMS),
222+
* not for running timeouts that adjust dynamically.
218223
*/
219224
public void setMaxTimeOverride(final long maxTimeMS) {
220225
this.maxTimeSupplier = () -> maxTimeMS;
221226
}
222227

223228
/**
224229
* The override will be provided as the remaining value in
225-
* {@link #runMaxTimeMSTimeout}, where 0 will invoke the onExpired path
230+
* {@link #runMaxTimeMS}, where 0 is ignored.
226231
*/
227232
public void setMaxTimeOverrideToMaxCommitTime() {
228233
this.maxTimeSupplier = () -> getMaxCommitTimeMS();
@@ -242,12 +247,12 @@ public long getWriteTimeoutMS() {
242247
return timeoutOrAlternative(0);
243248
}
244249

245-
public Timeout createConnectTimeoutMs() {
246-
// null timeout treated as infinite will be later than the other
247-
248-
return Timeout.earliest(
249-
Timeout.expiresIn(getTimeoutSettings().getConnectTimeoutMS(), MILLISECONDS, ZERO_DURATION_MEANS_INFINITE),
250-
Timeout.nullAsInfinite(timeout));
250+
public int getConnectTimeoutMs() {
251+
final long connectTimeoutMS = getTimeoutSettings().getConnectTimeoutMS();
252+
return Math.toIntExact(Timeout.nullAsInfinite(timeout).call(MILLISECONDS,
253+
() -> connectTimeoutMS,
254+
(ms) -> connectTimeoutMS == 0 ? ms : Math.min(ms, connectTimeoutMS),
255+
() -> throwMongoTimeoutException("The operation timeout has expired.")));
251256
}
252257

253258
public void resetTimeout() {

driver-core/src/main/com/mongodb/internal/connection/CommandMessage.java

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -222,12 +222,9 @@ private List<BsonElement> getExtraElements(final OperationContext operationConte
222222

223223
List<BsonElement> extraElements = new ArrayList<>();
224224
if (!getSettings().isCryptd()) {
225-
timeoutContext.runMaxTimeMSTimeout(
226-
() -> {},
227-
(ms) -> extraElements.add(new BsonElement("maxTimeMS", new BsonInt64(ms))),
228-
() -> {
229-
throw TimeoutContext.createMongoRoundTripTimeoutException();
230-
});
225+
timeoutContext.runMaxTimeMS(maxTimeMS ->
226+
extraElements.add(new BsonElement("maxTimeMS", new BsonInt64(maxTimeMS)))
227+
);
231228
}
232229
extraElements.add(new BsonElement("$db", new BsonString(new MongoNamespace(getCollectionName()).getDatabaseName())));
233230
if (sessionContext.getClusterTime() != null) {

driver-core/src/main/com/mongodb/internal/connection/SocketStream.java

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@
4646
import static com.mongodb.internal.connection.SocketStreamHelper.configureSocket;
4747
import static com.mongodb.internal.connection.SslHelper.configureSslSocket;
4848
import static com.mongodb.internal.thread.InterruptionUtil.translateInterruptedException;
49-
import static java.util.concurrent.TimeUnit.MILLISECONDS;
5049

5150
/**
5251
* <p>This class is not part of the public API and may be removed or changed at any time</p>
@@ -122,10 +121,7 @@ private SSLSocket initializeSslSocketOverSocksProxy(final OperationContext opera
122121
SocksSocket socksProxy = new SocksSocket(settings.getProxySettings());
123122
configureSocket(socksProxy, operationContext, settings);
124123
InetSocketAddress inetSocketAddress = toSocketAddress(serverHost, serverPort);
125-
operationContext.getTimeoutContext().createConnectTimeoutMs().checkedRun(MILLISECONDS,
126-
() -> socksProxy.connect(inetSocketAddress, 0),
127-
(ms) -> socksProxy.connect(inetSocketAddress, Math.toIntExact(ms)),
128-
() -> throwMongoTimeoutException("The operation timeout has expired."));
124+
socksProxy.connect(inetSocketAddress, operationContext.getTimeoutContext().getConnectTimeoutMs());
129125

130126
SSLSocket sslSocket = (SSLSocket) sslSocketFactory.createSocket(socksProxy, serverHost, serverPort, true);
131127
//Even though Socks proxy connection is already established, TLS handshake has not been performed yet.
@@ -153,11 +149,8 @@ private Socket initializeSocketOverSocksProxy(final OperationContext operationCo
153149
*/
154150
SocksSocket socksProxy = new SocksSocket(createdSocket, settings.getProxySettings());
155151

156-
InetSocketAddress inetSocketAddress = toSocketAddress(address.getHost(), address.getPort());
157-
operationContext.getTimeoutContext().createConnectTimeoutMs().checkedRun(MILLISECONDS,
158-
() -> socksProxy.connect(inetSocketAddress, 0),
159-
(ms) -> socksProxy.connect(inetSocketAddress, Math.toIntExact(ms)),
160-
() -> throwMongoTimeoutException("The operation timeout has expired."));
152+
socksProxy.connect(toSocketAddress(address.getHost(), address.getPort()),
153+
operationContext.getTimeoutContext().getConnectTimeoutMs());
161154
return socksProxy;
162155
}
163156

@@ -185,9 +178,7 @@ public ByteBuf read(final int numBytes, final OperationContext operationContext)
185178
byte[] bytes = buffer.array();
186179
while (totalBytesRead < buffer.limit()) {
187180
int readTimeoutMS = (int) operationContext.getTimeoutContext().getReadTimeoutMS();
188-
if (readTimeoutMS > 0) {
189-
socket.setSoTimeout(readTimeoutMS);
190-
}
181+
socket.setSoTimeout(readTimeoutMS);
191182
int bytesRead = inputStream.read(bytes, totalBytesRead, buffer.limit() - totalBytesRead);
192183
if (bytesRead == -1) {
193184
throw new MongoSocketReadException("Prematurely reached end of stream", getAddress());

driver-core/src/main/com/mongodb/internal/connection/SocketStreamHelper.java

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,7 @@
2727
import java.net.SocketException;
2828
import java.net.SocketOption;
2929

30-
import static com.mongodb.internal.TimeoutContext.throwMongoTimeoutException;
3130
import static com.mongodb.internal.connection.SslHelper.configureSslSocket;
32-
import static java.util.concurrent.TimeUnit.MILLISECONDS;
3331

3432
@SuppressWarnings({"unchecked", "rawtypes"})
3533
final class SocketStreamHelper {
@@ -75,10 +73,7 @@ static void initialize(final OperationContext operationContext, final Socket soc
7573
final SslSettings sslSettings) throws IOException {
7674
configureSocket(socket, operationContext, settings);
7775
configureSslSocket(socket, sslSettings, inetSocketAddress);
78-
operationContext.getTimeoutContext().createConnectTimeoutMs().checkedRun(MILLISECONDS,
79-
() -> socket.connect(inetSocketAddress, 0),
80-
(ms) -> socket.connect(inetSocketAddress, Math.toIntExact(ms)),
81-
() -> throwMongoTimeoutException("The operation timeout has expired."));
76+
socket.connect(inetSocketAddress, operationContext.getTimeoutContext().getConnectTimeoutMs());
8277
}
8378

8479
static void configureSocket(final Socket socket, final OperationContext operationContext, final SocketSettings settings) throws SocketException {

driver-core/src/main/com/mongodb/internal/connection/netty/NettyStream.java

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,6 @@
7171

7272
import static com.mongodb.assertions.Assertions.assertNotNull;
7373
import static com.mongodb.internal.Locks.withLock;
74-
import static com.mongodb.internal.TimeoutContext.throwMongoTimeoutException;
7574
import static com.mongodb.internal.connection.ServerAddressHelper.getSocketAddresses;
7675
import static com.mongodb.internal.connection.SslHelper.enableHostNameVerification;
7776
import static com.mongodb.internal.connection.SslHelper.enableSni;
@@ -192,10 +191,8 @@ private void initializeChannel(final OperationContext operationContext, final As
192191
Bootstrap bootstrap = new Bootstrap();
193192
bootstrap.group(workerGroup);
194193
bootstrap.channel(socketChannelClass);
195-
operationContext.getTimeoutContext().createConnectTimeoutMs().checkedRun(MILLISECONDS,
196-
() -> bootstrap.option(ChannelOption.CONNECT_TIMEOUT_MILLIS, 0),
197-
(ms) -> bootstrap.option(ChannelOption.CONNECT_TIMEOUT_MILLIS, Math.toIntExact(ms)),
198-
() -> throwMongoTimeoutException("The operation timeout has expired."));
194+
bootstrap.option(ChannelOption.CONNECT_TIMEOUT_MILLIS,
195+
operationContext.getTimeoutContext().getConnectTimeoutMs());
199196
bootstrap.option(ChannelOption.TCP_NODELAY, true);
200197
bootstrap.option(ChannelOption.SO_KEEPALIVE, true);
201198

driver-core/src/test/unit/com/mongodb/internal/TimeoutContextTest.java

Lines changed: 63 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,13 @@
1919
import com.mongodb.session.ClientSession;
2020
import org.junit.jupiter.api.DisplayName;
2121
import org.junit.jupiter.api.Test;
22+
import org.junit.jupiter.params.ParameterizedTest;
23+
import org.junit.jupiter.params.provider.Arguments;
24+
import org.junit.jupiter.params.provider.MethodSource;
2225
import org.mockito.Mockito;
2326

2427
import java.util.function.Supplier;
28+
import java.util.stream.Stream;
2529

2630
import static com.mongodb.ClusterFixture.TIMEOUT_SETTINGS;
2731
import static com.mongodb.ClusterFixture.TIMEOUT_SETTINGS_WITH_INFINITE_TIMEOUT;
@@ -42,12 +46,7 @@ final class TimeoutContextTest {
4246

4347
public static long getMaxTimeMS(final TimeoutContext timeoutContext) {
4448
long[] result = {0L};
45-
timeoutContext.runMaxTimeMSTimeout(
46-
() -> {},
47-
(ms) -> result[0] = ms,
48-
() -> {
49-
throw TimeoutContext.createMongoRoundTripTimeoutException();
50-
});
49+
timeoutContext.runMaxTimeMS((ms) -> result[0] = ms);
5150
return result[0];
5251
}
5352

@@ -198,16 +197,19 @@ void testThrowsWhenExpired() {
198197

199198
assertThrows(MongoOperationTimeoutException.class, smallTimeout::getReadTimeoutMS);
200199
assertThrows(MongoOperationTimeoutException.class, smallTimeout::getWriteTimeoutMS);
200+
assertThrows(MongoOperationTimeoutException.class, smallTimeout::getConnectTimeoutMs);
201201
assertThrows(MongoOperationTimeoutException.class, () -> getMaxTimeMS(smallTimeout));
202202
assertThrows(MongoOperationTimeoutException.class, smallTimeout::getMaxCommitTimeMS);
203203
assertThrows(MongoOperationTimeoutException.class, () -> smallTimeout.timeoutOrAlternative(1));
204204
assertDoesNotThrow(longTimeout::getReadTimeoutMS);
205205
assertDoesNotThrow(longTimeout::getWriteTimeoutMS);
206+
assertDoesNotThrow(longTimeout::getConnectTimeoutMs);
206207
assertDoesNotThrow(() -> getMaxTimeMS(longTimeout));
207208
assertDoesNotThrow(longTimeout::getMaxCommitTimeMS);
208209
assertDoesNotThrow(() -> longTimeout.timeoutOrAlternative(1));
209210
assertDoesNotThrow(noTimeout::getReadTimeoutMS);
210211
assertDoesNotThrow(noTimeout::getWriteTimeoutMS);
212+
assertDoesNotThrow(noTimeout::getConnectTimeoutMs);
211213
assertDoesNotThrow(() -> getMaxTimeMS(noTimeout));
212214
assertDoesNotThrow(noTimeout::getMaxCommitTimeMS);
213215
assertDoesNotThrow(() -> noTimeout.timeoutOrAlternative(1));
@@ -284,6 +286,61 @@ void shouldResetMaximeMS() {
284286
assertTrue(getMaxTimeMS(timeoutContext) > 1);
285287
}
286288

289+
static Stream<Arguments> shouldChooseConnectTimeoutWhenItIsLessThenTimeoutMs() {
290+
return Stream.of(
291+
//connectTimeoutMS, timeoutMS, expected
292+
Arguments.of(500L, 1000L, 500L),
293+
Arguments.of(0L, null, 0L),
294+
Arguments.of(1000L, null, 1000L),
295+
Arguments.of(1000L, 0L, 1000L),
296+
Arguments.of(0L, 0L, 0L)
297+
);
298+
}
299+
300+
@ParameterizedTest
301+
@MethodSource
302+
@DisplayName("should choose connectTimeoutMS when connectTimeoutMS is less than timeoutMS")
303+
void shouldChooseConnectTimeoutWhenItIsLessThenTimeoutMs(final Long connectTimeoutMS,
304+
final Long timeoutMS,
305+
final long expected) {
306+
TimeoutContext timeoutContext = new TimeoutContext(
307+
new TimeoutSettings(0,
308+
connectTimeoutMS,
309+
0,
310+
timeoutMS,
311+
0));
312+
313+
long calculatedTimeoutMS = timeoutContext.getConnectTimeoutMs();
314+
assertEquals(expected, calculatedTimeoutMS);
315+
}
316+
317+
318+
static Stream<Arguments> shouldChooseTimeoutMsWhenItIsLessThenConnectTimeoutMS() {
319+
return Stream.of(
320+
//connectTimeoutMS, timeoutMS, expected
321+
Arguments.of(1000L, 1000L, 999),
322+
Arguments.of(1000L, 500L, 499L),
323+
Arguments.of(0L, 1000L, 999L)
324+
);
325+
}
326+
327+
@ParameterizedTest
328+
@MethodSource
329+
@DisplayName("should choose timeoutMS when timeoutMS is less than connectTimeoutMS")
330+
void shouldChooseTimeoutMsWhenItIsLessThenConnectTimeoutMS(final Long connectTimeoutMS,
331+
final Long timeoutMS,
332+
final long expected) {
333+
TimeoutContext timeoutContext = new TimeoutContext(
334+
new TimeoutSettings(0,
335+
connectTimeoutMS,
336+
0,
337+
timeoutMS,
338+
0));
339+
340+
long calculatedTimeoutMS = timeoutContext.getConnectTimeoutMs();
341+
assertTrue(expected - calculatedTimeoutMS <= 1);
342+
}
343+
287344
private TimeoutContextTest() {
288345
}
289346
}

driver-core/src/test/unit/com/mongodb/internal/connection/CommandMessageTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ void encodeShouldThrowTimeoutExceptionWhenTimeoutContextIsCalled() {
6060
BasicOutputBuffer bsonOutput = new BasicOutputBuffer();
6161
SessionContext sessionContext = mock(SessionContext.class);
6262
TimeoutContext timeoutContext = mock(TimeoutContext.class, mock -> {
63-
doThrow(new MongoOperationTimeoutException("test")).when(mock).runMaxTimeMSTimeout(any(), any(), any());
63+
doThrow(new MongoOperationTimeoutException("test")).when(mock).runMaxTimeMS(any());
6464
});
6565
OperationContext operationContext = mock(OperationContext.class, mock -> {
6666
when(mock.getSessionContext()).thenReturn(sessionContext);

0 commit comments

Comments
 (0)