Skip to content

Commit 2387c21

Browse files
Send stream completion when client errors stream (#34147)
1 parent 15949fa commit 2387c21

File tree

3 files changed

+54
-15
lines changed

3 files changed

+54
-15
lines changed

src/SignalR/clients/csharp/Client.Core/src/HubConnection.Log.cs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,9 @@ private static class Log
249249
private static readonly Action<ILogger, Exception> _errorHandshakeCanceled =
250250
LoggerMessage.Define(LogLevel.Error, new EventId(83, "ErrorHandshakeCanceled"), "The handshake was canceled by the client.");
251251

252+
private static readonly Action<ILogger, string, Exception?> _erroredStream =
253+
LoggerMessage.Define<string>(LogLevel.Trace, new EventId(84, "ErroredStream"), "Client threw an error for stream '{StreamId}'.");
254+
252255
public static void PreparingNonBlockingInvocation(ILogger logger, string target, int count)
253256
{
254257
_preparingNonBlockingInvocation(logger, target, count, null);
@@ -664,6 +667,11 @@ public static void ErrorHandshakeCanceled(ILogger logger, Exception exception)
664667
{
665668
_errorHandshakeCanceled(logger, exception);
666669
}
670+
671+
public static void ErroredStream(ILogger logger, string streamId, Exception exception)
672+
{
673+
_erroredStream(logger, streamId, exception);
674+
}
667675
}
668676
}
669677
}

src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -697,6 +697,10 @@ private void LaunchStreams(ConnectionState connectionState, Dictionary<string, o
697697
return;
698698
}
699699

700+
_state.AssertInConnectionLock();
701+
// It's safe to access connectionState.UploadStreamToken as we still have the connection lock
702+
var cts = CancellationTokenSource.CreateLinkedTokenSource(connectionState.UploadStreamToken, cancellationToken);
703+
700704
foreach (var kvp in readers)
701705
{
702706
var reader = kvp.Value;
@@ -708,19 +712,19 @@ private void LaunchStreams(ConnectionState connectionState, Dictionary<string, o
708712
{
709713
_ = _sendIAsyncStreamItemsMethod
710714
.MakeGenericMethod(reader.GetType().GetInterface("IAsyncEnumerable`1")!.GetGenericArguments())
711-
.Invoke(this, new object[] { connectionState, kvp.Key.ToString(), reader, cancellationToken });
715+
.Invoke(this, new object[] { connectionState, kvp.Key.ToString(), reader, cts });
712716
continue;
713717
}
714718
_ = _sendStreamItemsMethod
715719
.MakeGenericMethod(reader.GetType().GetGenericArguments())
716-
.Invoke(this, new object[] { connectionState, kvp.Key.ToString(), reader, cancellationToken });
720+
.Invoke(this, new object[] { connectionState, kvp.Key.ToString(), reader, cts });
717721
}
718722
}
719723

720724
// this is called via reflection using the `_sendStreamItems` field
721-
private Task SendStreamItems<T>(ConnectionState connectionState, string streamId, ChannelReader<T> reader, CancellationToken token)
725+
private Task SendStreamItems<T>(ConnectionState connectionState, string streamId, ChannelReader<T> reader, CancellationTokenSource tokenSource)
722726
{
723-
async Task ReadChannelStream(CancellationTokenSource tokenSource)
727+
async Task ReadChannelStream()
724728
{
725729
while (await reader.WaitToReadAsync(tokenSource.Token))
726730
{
@@ -732,13 +736,13 @@ async Task ReadChannelStream(CancellationTokenSource tokenSource)
732736
}
733737
}
734738

735-
return CommonStreaming(connectionState, streamId, token, ReadChannelStream);
739+
return CommonStreaming(connectionState, streamId, ReadChannelStream);
736740
}
737741

738742
// this is called via reflection using the `_sendIAsyncStreamItemsMethod` field
739-
private Task SendIAsyncEnumerableStreamItems<T>(ConnectionState connectionState, string streamId, IAsyncEnumerable<T> stream, CancellationToken token)
743+
private Task SendIAsyncEnumerableStreamItems<T>(ConnectionState connectionState, string streamId, IAsyncEnumerable<T> stream, CancellationTokenSource tokenSource)
740744
{
741-
async Task ReadAsyncEnumerableStream(CancellationTokenSource tokenSource)
745+
async Task ReadAsyncEnumerableStream()
742746
{
743747
var streamValues = AsyncEnumerableAdapters.MakeCancelableTypedAsyncEnumerable(stream, tokenSource);
744748

@@ -749,25 +753,26 @@ async Task ReadAsyncEnumerableStream(CancellationTokenSource tokenSource)
749753
}
750754
}
751755

752-
return CommonStreaming(connectionState, streamId, token, ReadAsyncEnumerableStream);
756+
return CommonStreaming(connectionState, streamId, ReadAsyncEnumerableStream);
753757
}
754758

755-
private async Task CommonStreaming(ConnectionState connectionState, string streamId, CancellationToken token, Func<CancellationTokenSource, Task> createAndConsumeStream)
759+
private async Task CommonStreaming(ConnectionState connectionState, string streamId, Func<Task> createAndConsumeStream)
756760
{
757-
// It's safe to access connectionState.UploadStreamToken as we still have the connection lock
758-
_state.AssertInConnectionLock();
759-
var cts = CancellationTokenSource.CreateLinkedTokenSource(connectionState.UploadStreamToken, token);
760-
761761
Log.StartingStream(_logger, streamId);
762762
string? responseError = null;
763763
try
764764
{
765-
await createAndConsumeStream(cts);
765+
await createAndConsumeStream();
766766
}
767767
catch (OperationCanceledException)
768768
{
769769
Log.CancelingStream(_logger, streamId);
770-
responseError = $"Stream canceled by client.";
770+
responseError = "Stream canceled by client.";
771+
}
772+
catch (Exception ex)
773+
{
774+
Log.ErroredStream(_logger, streamId, ex);
775+
responseError = $"Stream errored by client: '{ex}'";
771776
}
772777

773778
Log.CompletingStream(_logger, streamId);

src/SignalR/clients/csharp/Client/test/UnitTests/HubConnectionTests.cs

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -517,6 +517,32 @@ public async Task UploadStreamCancellationSendsStreamComplete()
517517
}
518518
}
519519

520+
[Fact]
521+
[LogLevel(LogLevel.Trace)]
522+
public async Task UploadStreamErrorSendsStreamComplete()
523+
{
524+
using (StartVerifiableLog())
525+
{
526+
var connection = new TestConnection();
527+
var hubConnection = CreateHubConnection(connection, loggerFactory: LoggerFactory);
528+
await hubConnection.StartAsync().DefaultTimeout();
529+
530+
var cts = new CancellationTokenSource();
531+
var channel = Channel.CreateUnbounded<int>();
532+
var invokeTask = hubConnection.InvokeAsync<object>("UploadMethod", channel.Reader, cts.Token);
533+
534+
var invokeMessage = await connection.ReadSentJsonAsync().DefaultTimeout();
535+
Assert.Equal(HubProtocolConstants.InvocationMessageType, invokeMessage["type"]);
536+
537+
channel.Writer.Complete(new Exception("error from client"));
538+
539+
// the next sent message should be a completion message
540+
var complete = await connection.ReadSentJsonAsync().DefaultTimeout();
541+
Assert.Equal(HubProtocolConstants.CompletionMessageType, complete["type"]);
542+
Assert.StartsWith("Stream errored by client: 'System.Exception: error from client", ((string)complete["error"]));
543+
}
544+
}
545+
520546
[Fact]
521547
[LogLevel(LogLevel.Trace)]
522548
public async Task InvocationCanCompleteBeforeStreamCompletes()

0 commit comments

Comments
 (0)