Skip to content

Add throwTranslatedWriteException, refactoring, async helper #1379

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions driver-core/src/main/com/mongodb/internal/async/AsyncRunnable.java
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,42 @@ default AsyncRunnable thenRun(final AsyncRunnable runnable) {
};
}

/**
* The error check checks if the exception is an instance of the provided class.
* @see #thenRunTryCatchAsyncBlocks(AsyncRunnable, java.util.function.Predicate, AsyncFunction)
*/
default <T extends Throwable> AsyncRunnable thenRunTryCatchAsyncBlocks(
final AsyncRunnable runnable,
final Class<T> exceptionClass,
final AsyncFunction<Throwable, Void> errorFunction) {
return thenRunTryCatchAsyncBlocks(runnable, e -> exceptionClass.isInstance(e), errorFunction);
}

/**
* Convenience method corresponding to a try-catch block in sync code.
* This MUST be used to properly handle cases where there is code above
* the block, whose errors must not be caught by an ensuing
* {@link #onErrorIf(java.util.function.Predicate, AsyncFunction)}.
*
* @param runnable corresponds to the contents of the try block
* @param errorCheck for matching on an error (or, a more complex condition)
* @param errorFunction corresponds to the contents of the catch block
* @return the composition of this runnable, a runnable that runs the
* provided runnable, followed by (composed with) the error function, which
* is conditional on there being an exception meeting the error check.
*/
default AsyncRunnable thenRunTryCatchAsyncBlocks(
final AsyncRunnable runnable,
final Predicate<Throwable> errorCheck,
final AsyncFunction<Throwable, Void> errorFunction) {
return this.thenRun(c -> {
beginAsync()
.thenRun(runnable)
.onErrorIf(errorCheck, errorFunction)
.finish(c);
});
}

/**
* @param condition the condition to check
* @param runnable The async runnable to run after this runnable,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@ static BsonDocument executeCommandWithoutCheckingForFailure(final String databas
static void executeCommandAsync(final String database, final BsonDocument command, final ClusterConnectionMode clusterConnectionMode,
@Nullable final ServerApi serverApi, final InternalConnection internalConnection,
final SingleResultCallback<BsonDocument> callback) {
internalConnection.sendAndReceiveAsync(getCommandMessage(database, command, internalConnection, clusterConnectionMode, serverApi),
internalConnection.sendAndReceiveAsync(
getCommandMessage(database, command, internalConnection, clusterConnectionMode, serverApi),
new BsonDocumentCodec(),
NoOpSessionContext.INSTANCE, IgnorableRequestContext.INSTANCE, new OperationContext(), (result, t) -> {
if (t != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -633,19 +633,34 @@ private <T> T getCommandResult(final Decoder<T> decoder, final ResponseBuffers r
@Override
public void sendMessage(final List<ByteBuf> byteBuffers, final int lastRequestId) {
notNull("stream is open", stream);

if (isClosed()) {
throw new MongoSocketClosedException("Cannot write to a closed stream", getServerAddress());
}

try {
stream.write(byteBuffers);
} catch (Exception e) {
close();
throw translateWriteException(e);
throwTranslatedWriteException(e);
}
}

@Override
public void sendMessageAsync(final List<ByteBuf> byteBuffers, final int lastRequestId,
final SingleResultCallback<Void> callback) {
beginAsync().thenRun((c) -> {
notNull("stream is open", stream);
if (isClosed()) {
throw new MongoSocketClosedException("Cannot write to a closed stream", getServerAddress());
}
c.complete(c);
Copy link
Member

@vbabanin vbabanin May 8, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The current implementation of the complete method requires the callback object to be passed to itself c.complete(c);. As i see, this pattern is intended to prevent accidental misuse of complete(T) when a result is not meant to be returned, as the method asserts that the callback is the same as the instance (Assertions.assertTrue(callback == this);). However, this pattern can be confusing and counterintuitive.

Wouldn't it be more straightforward to simplify this method to a parameterless complete()? This would make the method's usage intuitive c.complete();, eliminating the awkwardness of having to pass the callback to itself.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The purpose is actually the opposite, to prevent misuse of the void complete() when a result should be returned. This is a workaround to force a typecheck on the return type. It is very easy to make mistakes in the boilerplate (calling the wrong callback, passing the wrong value, and so on), whether in the old or new async code. I think it is important to prevent this (which is a correctness issue, and constant problem), even if it makes the API unusual (which is only ugly but not a correctness issue, and we only have to pay the price of getting used to this a few times).

}).thenRunTryCatchAsyncBlocks(c -> {
stream.writeAsync(byteBuffers, c.asHandler());
}, Exception.class, (e, c) -> {
close();
throwTranslatedWriteException(e);
}).finish(errorHandlingCallback(callback, LOGGER));
}

@Override
public ResponseBuffers receiveMessage(final int responseTo) {
assertNotNull(stream);
Expand All @@ -665,39 +680,6 @@ private ResponseBuffers receiveMessageWithAdditionalTimeout(final int additional
}
}

@Override
public void sendMessageAsync(final List<ByteBuf> byteBuffers, final int lastRequestId,
final SingleResultCallback<Void> callback) {
assertNotNull(stream);

if (isClosed()) {
callback.onResult(null, new MongoSocketClosedException("Can not read from a closed socket", getServerAddress()));
return;
}

writeAsync(byteBuffers, errorHandlingCallback(callback, LOGGER));
}

private void writeAsync(final List<ByteBuf> byteBuffers, final SingleResultCallback<Void> callback) {
try {
stream.writeAsync(byteBuffers, new AsyncCompletionHandler<Void>() {
@Override
public void completed(@Nullable final Void v) {
callback.onResult(null, null);
}

@Override
public void failed(final Throwable t) {
close();
callback.onResult(null, translateWriteException(t));
}
});
} catch (Throwable t) {
close();
callback.onResult(null, t);
}
}

@Override
public void receiveMessageAsync(final int responseTo, final SingleResultCallback<ResponseBuffers> callback) {
assertNotNull(stream);
Expand Down Expand Up @@ -762,6 +744,10 @@ private void updateSessionContext(final SessionContext sessionContext, final Res
}
}

private void throwTranslatedWriteException(final Throwable e) {
throw translateWriteException(e);
}

private MongoException translateWriteException(final Throwable e) {
if (e instanceof MongoException) {
return (MongoException) e;
Expand Down
Loading