Skip to content

Send stream completion when client errors stream #34147

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jul 9, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,9 @@ private static class Log
private static readonly Action<ILogger, Exception> _errorHandshakeCanceled =
LoggerMessage.Define(LogLevel.Error, new EventId(83, "ErrorHandshakeCanceled"), "The handshake was canceled by the client.");

private static readonly Action<ILogger, string, Exception?> _erroredStream =
LoggerMessage.Define<string>(LogLevel.Trace, new EventId(84, "ErroredStream"), "Client threw an error for stream '{StreamId}'.");

public static void PreparingNonBlockingInvocation(ILogger logger, string target, int count)
{
_preparingNonBlockingInvocation(logger, target, count, null);
Expand Down Expand Up @@ -664,6 +667,11 @@ public static void ErrorHandshakeCanceled(ILogger logger, Exception exception)
{
_errorHandshakeCanceled(logger, exception);
}

public static void ErroredStream(ILogger logger, string streamId, Exception exception)
{
_erroredStream(logger, streamId, exception);
}
}
}
}
Expand Down
35 changes: 20 additions & 15 deletions src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -697,6 +697,10 @@ private void LaunchStreams(ConnectionState connectionState, Dictionary<string, o
return;
}

_state.AssertInConnectionLock();
// It's safe to access connectionState.UploadStreamToken as we still have the connection lock
var cts = CancellationTokenSource.CreateLinkedTokenSource(connectionState.UploadStreamToken, cancellationToken);

foreach (var kvp in readers)
{
var reader = kvp.Value;
Expand All @@ -708,19 +712,19 @@ private void LaunchStreams(ConnectionState connectionState, Dictionary<string, o
{
_ = _sendIAsyncStreamItemsMethod
.MakeGenericMethod(reader.GetType().GetInterface("IAsyncEnumerable`1")!.GetGenericArguments())
.Invoke(this, new object[] { connectionState, kvp.Key.ToString(), reader, cancellationToken });
.Invoke(this, new object[] { connectionState, kvp.Key.ToString(), reader, cts });
continue;
}
_ = _sendStreamItemsMethod
.MakeGenericMethod(reader.GetType().GetGenericArguments())
.Invoke(this, new object[] { connectionState, kvp.Key.ToString(), reader, cancellationToken });
.Invoke(this, new object[] { connectionState, kvp.Key.ToString(), reader, cts });
}
}

// this is called via reflection using the `_sendStreamItems` field
private Task SendStreamItems<T>(ConnectionState connectionState, string streamId, ChannelReader<T> reader, CancellationToken token)
private Task SendStreamItems<T>(ConnectionState connectionState, string streamId, ChannelReader<T> reader, CancellationTokenSource tokenSource)
{
async Task ReadChannelStream(CancellationTokenSource tokenSource)
async Task ReadChannelStream()
{
while (await reader.WaitToReadAsync(tokenSource.Token))
{
Expand All @@ -732,13 +736,13 @@ async Task ReadChannelStream(CancellationTokenSource tokenSource)
}
}

return CommonStreaming(connectionState, streamId, token, ReadChannelStream);
return CommonStreaming(connectionState, streamId, ReadChannelStream);
}

// this is called via reflection using the `_sendIAsyncStreamItemsMethod` field
private Task SendIAsyncEnumerableStreamItems<T>(ConnectionState connectionState, string streamId, IAsyncEnumerable<T> stream, CancellationToken token)
private Task SendIAsyncEnumerableStreamItems<T>(ConnectionState connectionState, string streamId, IAsyncEnumerable<T> stream, CancellationTokenSource tokenSource)
{
async Task ReadAsyncEnumerableStream(CancellationTokenSource tokenSource)
async Task ReadAsyncEnumerableStream()
{
var streamValues = AsyncEnumerableAdapters.MakeCancelableTypedAsyncEnumerable(stream, tokenSource);

Expand All @@ -749,25 +753,26 @@ async Task ReadAsyncEnumerableStream(CancellationTokenSource tokenSource)
}
}

return CommonStreaming(connectionState, streamId, token, ReadAsyncEnumerableStream);
return CommonStreaming(connectionState, streamId, ReadAsyncEnumerableStream);
}

private async Task CommonStreaming(ConnectionState connectionState, string streamId, CancellationToken token, Func<CancellationTokenSource, Task> createAndConsumeStream)
private async Task CommonStreaming(ConnectionState connectionState, string streamId, Func<Task> createAndConsumeStream)
{
// It's safe to access connectionState.UploadStreamToken as we still have the connection lock
_state.AssertInConnectionLock();
var cts = CancellationTokenSource.CreateLinkedTokenSource(connectionState.UploadStreamToken, token);

Log.StartingStream(_logger, streamId);
string? responseError = null;
try
{
await createAndConsumeStream(cts);
await createAndConsumeStream();
}
catch (OperationCanceledException)
{
Log.CancelingStream(_logger, streamId);
responseError = $"Stream canceled by client.";
responseError = "Stream canceled by client.";
}
catch (Exception ex)
{
Log.ErroredStream(_logger, streamId, ex);
responseError = $"Stream errored by client: '{ex}'";
}

Log.CompletingStream(_logger, streamId);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -517,6 +517,32 @@ public async Task UploadStreamCancellationSendsStreamComplete()
}
}

[Fact]
[LogLevel(LogLevel.Trace)]
public async Task UploadStreamErrorSendsStreamComplete()
{
using (StartVerifiableLog())
{
var connection = new TestConnection();
var hubConnection = CreateHubConnection(connection, loggerFactory: LoggerFactory);
await hubConnection.StartAsync().DefaultTimeout();

var cts = new CancellationTokenSource();
var channel = Channel.CreateUnbounded<int>();
var invokeTask = hubConnection.InvokeAsync<object>("UploadMethod", channel.Reader, cts.Token);

var invokeMessage = await connection.ReadSentJsonAsync().DefaultTimeout();
Assert.Equal(HubProtocolConstants.InvocationMessageType, invokeMessage["type"]);

channel.Writer.Complete(new Exception("error from client"));

// the next sent message should be a completion message
var complete = await connection.ReadSentJsonAsync().DefaultTimeout();
Assert.Equal(HubProtocolConstants.CompletionMessageType, complete["type"]);
Assert.StartsWith("Stream errored by client: 'System.Exception: error from client", ((string)complete["error"]));
}
}

[Fact]
[LogLevel(LogLevel.Trace)]
public async Task InvocationCanCompleteBeforeStreamCompletes()
Expand Down