Skip to content

Commit e44e5c7

Browse files
committed
Add throwTranslatedWriteException, refactoring, async helper
1 parent 330d3b1 commit e44e5c7

File tree

3 files changed

+110
-39
lines changed

3 files changed

+110
-39
lines changed

driver-core/src/main/com/mongodb/internal/async/AsyncRunnable.java

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,42 @@ default AsyncRunnable thenRun(final AsyncRunnable runnable) {
178178
};
179179
}
180180

181+
/**
182+
* The error check checks if the exception is an instance of the provided class.
183+
* @see #thenRunTryCatchAsyncBlocks(AsyncRunnable, java.util.function.Predicate, AsyncFunction)
184+
*/
185+
default <T extends Throwable> AsyncRunnable thenRunTryCatchAsyncBlocks(
186+
final AsyncRunnable runnable,
187+
final Class<T> exceptionClass,
188+
final AsyncFunction<Throwable, Void> errorFunction) {
189+
return thenRunTryCatchAsyncBlocks(runnable, e -> exceptionClass.isInstance(e), errorFunction);
190+
}
191+
192+
/**
193+
* Convenience method corresponding to a try-catch block in sync code.
194+
* This MUST be used to properly handle cases where there is code above
195+
* the block, whose errors must not be caught by an ensuing
196+
* {@link #onErrorIf(java.util.function.Predicate, AsyncFunction)}.
197+
*
198+
* @param runnable corresponds to the contents of the try block
199+
* @param errorCheck for matching on an error (or, a more complex condition)
200+
* @param errorFunction corresponds to the contents of the catch block
201+
* @return the composition of this runnable, a runnable that runs the
202+
* provided runnable, followed by (composed with) the error function, which
203+
* is conditional on there being an exception meeting the error check.
204+
*/
205+
default AsyncRunnable thenRunTryCatchAsyncBlocks(
206+
final AsyncRunnable runnable,
207+
final Predicate<Throwable> errorCheck,
208+
final AsyncFunction<Throwable, Void> errorFunction) {
209+
return this.thenRun(c -> {
210+
beginAsync()
211+
.thenRun(runnable)
212+
.onErrorIf(errorCheck, errorFunction)
213+
.finish(c);
214+
});
215+
}
216+
181217
/**
182218
* @param condition the condition to check
183219
* @param runnable The async runnable to run after this runnable,

driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnection.java

Lines changed: 22 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,6 @@
7373
import static com.mongodb.assertions.Assertions.isTrue;
7474
import static com.mongodb.assertions.Assertions.notNull;
7575
import static com.mongodb.internal.async.AsyncRunnable.beginAsync;
76-
import static com.mongodb.internal.async.ErrorHandlingResultCallback.errorHandlingCallback;
7776
import static com.mongodb.internal.connection.Authenticator.shouldAuthenticate;
7877
import static com.mongodb.internal.connection.CommandHelper.HELLO;
7978
import static com.mongodb.internal.connection.CommandHelper.LEGACY_HELLO;
@@ -633,19 +632,34 @@ private <T> T getCommandResult(final Decoder<T> decoder, final ResponseBuffers r
633632
@Override
634633
public void sendMessage(final List<ByteBuf> byteBuffers, final int lastRequestId) {
635634
notNull("stream is open", stream);
636-
637635
if (isClosed()) {
638636
throw new MongoSocketClosedException("Cannot write to a closed stream", getServerAddress());
639637
}
640-
641638
try {
642639
stream.write(byteBuffers);
643640
} catch (Exception e) {
644641
close();
645-
throw translateWriteException(e);
642+
throwTranslatedWriteException(e);
646643
}
647644
}
648645

646+
@Override
647+
public void sendMessageAsync(final List<ByteBuf> byteBuffers, final int lastRequestId,
648+
final SingleResultCallback<Void> callback) {
649+
beginAsync().thenRun((c) -> {
650+
notNull("stream is open", stream);
651+
if (isClosed()) {
652+
throw new MongoSocketClosedException("Cannot write to a closed stream", getServerAddress());
653+
}
654+
c.complete(c);
655+
}).thenRunTryCatchAsyncBlocks(c -> {
656+
stream.writeAsync(byteBuffers, c.asHandler());
657+
}, Exception.class, (e, c) -> {
658+
close();
659+
throwTranslatedWriteException(e);
660+
}).finish(callback);
661+
}
662+
649663
@Override
650664
public ResponseBuffers receiveMessage(final int responseTo) {
651665
assertNotNull(stream);
@@ -665,39 +679,6 @@ private ResponseBuffers receiveMessageWithAdditionalTimeout(final int additional
665679
}
666680
}
667681

668-
@Override
669-
public void sendMessageAsync(final List<ByteBuf> byteBuffers, final int lastRequestId,
670-
final SingleResultCallback<Void> callback) {
671-
assertNotNull(stream);
672-
673-
if (isClosed()) {
674-
callback.onResult(null, new MongoSocketClosedException("Can not read from a closed socket", getServerAddress()));
675-
return;
676-
}
677-
678-
writeAsync(byteBuffers, errorHandlingCallback(callback, LOGGER));
679-
}
680-
681-
private void writeAsync(final List<ByteBuf> byteBuffers, final SingleResultCallback<Void> callback) {
682-
try {
683-
stream.writeAsync(byteBuffers, new AsyncCompletionHandler<Void>() {
684-
@Override
685-
public void completed(@Nullable final Void v) {
686-
callback.onResult(null, null);
687-
}
688-
689-
@Override
690-
public void failed(final Throwable t) {
691-
close();
692-
callback.onResult(null, translateWriteException(t));
693-
}
694-
});
695-
} catch (Throwable t) {
696-
close();
697-
callback.onResult(null, t);
698-
}
699-
}
700-
701682
@Override
702683
public void receiveMessageAsync(final int responseTo, final SingleResultCallback<ResponseBuffers> callback) {
703684
assertNotNull(stream);
@@ -762,6 +743,10 @@ private void updateSessionContext(final SessionContext sessionContext, final Res
762743
}
763744
}
764745

746+
private void throwTranslatedWriteException(final Throwable e) {
747+
throw translateWriteException(e);
748+
}
749+
765750
private MongoException translateWriteException(final Throwable e) {
766751
if (e instanceof MongoException) {
767752
return (MongoException) e;

driver-core/src/test/unit/com/mongodb/internal/async/AsyncFunctionsTest.java

Lines changed: 52 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -393,8 +393,8 @@ void testTryCatch() {
393393
// chain of 2 in try.
394394
// WARNING: "onErrorIf" will consider everything in
395395
// the preceding chain to be part of the try.
396-
// Use nested async chains to define the beginning
397-
// of the "try".
396+
// Use nested async chains, or convenience methods,
397+
// to define the beginning of the try.
398398
assertBehavesSameVariations(5,
399399
() -> {
400400
try {
@@ -491,6 +491,56 @@ void testTryCatch() {
491491
});
492492
}
493493

494+
@Test
495+
void testTryCatchHelper() {
496+
assertBehavesSameVariations(4,
497+
() -> {
498+
plain(0);
499+
try {
500+
sync(1);
501+
} catch (Throwable t) {
502+
plain(2);
503+
throw t;
504+
}
505+
},
506+
(callback) -> {
507+
beginAsync().thenRun(c -> {
508+
plain(0);
509+
c.complete(c);
510+
}).thenRunTryCatchAsyncBlocks(c -> {
511+
async(1, c);
512+
}, Throwable.class, (t, c) -> {
513+
plain(2);
514+
c.completeExceptionally(t);
515+
}).finish(callback);
516+
});
517+
518+
assertBehavesSameVariations(5,
519+
() -> {
520+
plain(0);
521+
try {
522+
sync(1);
523+
} catch (Throwable t) {
524+
plain(2);
525+
throw t;
526+
}
527+
sync(4);
528+
},
529+
(callback) -> {
530+
beginAsync().thenRun(c -> {
531+
plain(0);
532+
c.complete(c);
533+
}).thenRunTryCatchAsyncBlocks(c -> {
534+
async(1, c);
535+
}, Throwable.class, (t, c) -> {
536+
plain(2);
537+
c.completeExceptionally(t);
538+
}).thenRun(c -> {
539+
async(4, c);
540+
}).finish(callback);
541+
});
542+
}
543+
494544
@Test
495545
void testTryCatchWithVariables() {
496546
// using supply etc.

0 commit comments

Comments
 (0)