Skip to content

Commit ebb9ad2

Browse files
authored
Client to Server Streaming with IAsyncEnumerable (#9310)
1 parent 6074daa commit ebb9ad2

File tree

6 files changed

+222
-19
lines changed

6 files changed

+222
-19
lines changed

src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs

Lines changed: 49 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,9 @@ public partial class HubConnection
4242
private readonly SemaphoreSlim _connectionLock = new SemaphoreSlim(1, 1);
4343

4444
private static readonly MethodInfo _sendStreamItemsMethod = typeof(HubConnection).GetMethods(BindingFlags.NonPublic | BindingFlags.Instance).Single(m => m.Name.Equals("SendStreamItems"));
45-
45+
#if NETCOREAPP3_0
46+
private static readonly MethodInfo _sendIAsyncStreamItemsMethod = typeof(HubConnection).GetMethods(BindingFlags.NonPublic | BindingFlags.Instance).Single(m => m.Name.Equals("SendIAsyncEnumerableStreamItems"));
47+
#endif
4648
// Persistent across all connections
4749
private readonly ILoggerFactory _loggerFactory;
4850
private readonly ILogger _logger;
@@ -533,13 +535,11 @@ async Task OnStreamCanceled(InvocationRequest irq)
533535
}
534536

535537
LaunchStreams(readers, cancellationToken);
536-
537538
return channel;
538539
}
539540

540541
private Dictionary<string, object> PackageStreamingParams(ref object[] args, out List<string> streamIds)
541542
{
542-
// lazy initialized, to avoid allocating unecessary dictionaries
543543
Dictionary<string, object> readers = null;
544544
streamIds = null;
545545
var newArgs = new List<object>(args.Length);
@@ -572,7 +572,6 @@ private Dictionary<string, object> PackageStreamingParams(ref object[] args, out
572572
}
573573

574574
args = newArgs.ToArray();
575-
576575
return readers;
577576
}
578577

@@ -590,31 +589,68 @@ private void LaunchStreams(Dictionary<string, object> readers, CancellationToken
590589
// For each stream that needs to be sent, run a "send items" task in the background.
591590
// This reads from the channel, attaches streamId, and sends to server.
592591
// A single background thread here quickly gets messy.
592+
#if NETCOREAPP3_0
593+
if (ReflectionHelper.IsIAsyncEnumerable(reader.GetType()))
594+
{
595+
_ = _sendIAsyncStreamItemsMethod
596+
.MakeGenericMethod(reader.GetType().GetInterface("IAsyncEnumerable`1").GetGenericArguments())
597+
.Invoke(this, new object[] { kvp.Key.ToString(), reader, cancellationToken });
598+
continue;
599+
}
600+
#endif
593601
_ = _sendStreamItemsMethod
594602
.MakeGenericMethod(reader.GetType().GetGenericArguments())
595603
.Invoke(this, new object[] { kvp.Key.ToString(), reader, cancellationToken });
596604
}
597605
}
598606

599607
// this is called via reflection using the `_sendStreamItems` field
600-
private async Task SendStreamItems<T>(string streamId, ChannelReader<T> reader, CancellationToken token)
608+
private Task SendStreamItems<T>(string streamId, ChannelReader<T> reader, CancellationToken token)
601609
{
602-
Log.StartingStream(_logger, streamId);
603-
604-
var combinedToken = CancellationTokenSource.CreateLinkedTokenSource(_uploadStreamToken, token).Token;
605-
606-
string responseError = null;
607-
try
610+
async Task ReadChannelStream(CancellationTokenSource tokenSource)
608611
{
609-
while (await reader.WaitToReadAsync(combinedToken))
612+
while (await reader.WaitToReadAsync(tokenSource.Token))
610613
{
611-
while (!combinedToken.IsCancellationRequested && reader.TryRead(out var item))
614+
while (!tokenSource.Token.IsCancellationRequested && reader.TryRead(out var item))
612615
{
613616
await SendWithLock(new StreamItemMessage(streamId, item));
614617
Log.SendingStreamItem(_logger, streamId);
615618
}
616619
}
617620
}
621+
622+
return CommonStreaming(streamId, token, ReadChannelStream);
623+
}
624+
625+
#if NETCOREAPP3_0
626+
// this is called via reflection using the `_sendIAsyncStreamItemsMethod` field
627+
private Task SendIAsyncEnumerableStreamItems<T>(string streamId, IAsyncEnumerable<T> stream, CancellationToken token)
628+
{
629+
async Task ReadAsyncEnumerableStream(CancellationTokenSource tokenSource)
630+
{
631+
var streamValues = AsyncEnumerableAdapters.MakeCancelableTypedAsyncEnumerable(stream, tokenSource);
632+
633+
await foreach (var streamValue in streamValues)
634+
{
635+
await SendWithLock(new StreamItemMessage(streamId, streamValue));
636+
Log.SendingStreamItem(_logger, streamId);
637+
}
638+
}
639+
640+
return CommonStreaming(streamId, token, ReadAsyncEnumerableStream);
641+
}
642+
#endif
643+
644+
private async Task CommonStreaming(string streamId, CancellationToken token, Func<CancellationTokenSource, Task> createAndConsumeStream)
645+
{
646+
var cts = CancellationTokenSource.CreateLinkedTokenSource(_uploadStreamToken, token);
647+
648+
Log.StartingStream(_logger, streamId);
649+
string responseError = null;
650+
try
651+
{
652+
await createAndConsumeStream(cts);
653+
}
618654
catch (OperationCanceledException)
619655
{
620656
Log.CancelingStream(_logger, streamId);

src/SignalR/clients/csharp/Client/test/FunctionalTests/HubConnectionTests.cs

Lines changed: 101 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -661,6 +661,106 @@ public async Task CanStreamToAndFromClientInSameInvocation(string protocolName,
661661
}
662662
}
663663

664+
[Theory]
665+
[MemberData(nameof(HubProtocolsAndTransportsAndHubPaths))]
666+
[LogLevel(LogLevel.Trace)]
667+
public async Task CanStreamToServerWithIAsyncEnumerable(string protocolName, HttpTransportType transportType, string path)
668+
{
669+
var protocol = HubProtocols[protocolName];
670+
using (StartServer<Startup>(out var server))
671+
{
672+
var connection = CreateHubConnection(server.Url, path, transportType, protocol, LoggerFactory);
673+
try
674+
{
675+
async IAsyncEnumerable<string> clientStreamData()
676+
{
677+
var items = new string[] { "A", "B", "C", "D" };
678+
foreach (var item in items)
679+
{
680+
await Task.Delay(10);
681+
yield return item;
682+
}
683+
}
684+
685+
await connection.StartAsync().OrTimeout();
686+
687+
var stream = clientStreamData();
688+
689+
var channel = await connection.StreamAsChannelAsync<string>("StreamEcho", stream).OrTimeout();
690+
691+
Assert.Equal("A", await channel.ReadAsync().AsTask().OrTimeout());
692+
Assert.Equal("B", await channel.ReadAsync().AsTask().OrTimeout());
693+
Assert.Equal("C", await channel.ReadAsync().AsTask().OrTimeout());
694+
Assert.Equal("D", await channel.ReadAsync().AsTask().OrTimeout());
695+
696+
var results = await channel.ReadAndCollectAllAsync().OrTimeout();
697+
Assert.Empty(results);
698+
}
699+
catch (Exception ex)
700+
{
701+
LoggerFactory.CreateLogger<HubConnectionTests>().LogError(ex, "{ExceptionType} from test", ex.GetType().FullName);
702+
throw;
703+
}
704+
finally
705+
{
706+
await connection.DisposeAsync().OrTimeout();
707+
}
708+
}
709+
}
710+
711+
[Theory]
712+
[MemberData(nameof(HubProtocolsAndTransportsAndHubPaths))]
713+
[LogLevel(LogLevel.Trace)]
714+
public async Task CanCancelIAsyncEnumerableClientToServerUpload(string protocolName, HttpTransportType transportType, string path)
715+
{
716+
var protocol = HubProtocols[protocolName];
717+
using (StartServer<Startup>(out var server))
718+
{
719+
var connection = CreateHubConnection(server.Url, path, transportType, protocol, LoggerFactory);
720+
try
721+
{
722+
async IAsyncEnumerable<int> clientStreamData()
723+
{
724+
for (var i = 0; i < 1000; i++)
725+
{
726+
yield return i;
727+
await Task.Delay(10);
728+
}
729+
}
730+
731+
await connection.StartAsync().OrTimeout();
732+
var results = new List<int>();
733+
var stream = clientStreamData();
734+
var cts = new CancellationTokenSource();
735+
var ex = await Assert.ThrowsAsync<OperationCanceledException>(async () =>
736+
{
737+
var channel = await connection.StreamAsChannelAsync<int>("StreamEchoInt", stream, cts.Token).OrTimeout();
738+
739+
while (await channel.WaitToReadAsync())
740+
{
741+
while (channel.TryRead(out var item))
742+
{
743+
results.Add(item);
744+
cts.Cancel();
745+
}
746+
}
747+
});
748+
749+
Assert.True(results.Count > 0 && results.Count < 1000);
750+
Assert.True(cts.IsCancellationRequested);
751+
}
752+
catch (Exception ex)
753+
{
754+
LoggerFactory.CreateLogger<HubConnectionTests>().LogError(ex, "{ExceptionType} from test", ex.GetType().FullName);
755+
throw;
756+
}
757+
finally
758+
{
759+
await connection.DisposeAsync().OrTimeout();
760+
}
761+
}
762+
}
763+
664764
[Theory]
665765
[MemberData(nameof(HubProtocolsAndTransportsAndHubPaths))]
666766
[LogLevel(LogLevel.Trace)]
@@ -673,7 +773,7 @@ public async Task StreamAsyncCanBeCanceledThroughGetAsyncEnumerator(string proto
673773
try
674774
{
675775
await connection.StartAsync().OrTimeout();
676-
var stream = connection.StreamAsync<int>("Stream", 1000 );
776+
var stream = connection.StreamAsync<int>("Stream", 1000);
677777
var results = new List<int>();
678778

679779
var cts = new CancellationTokenSource();

src/SignalR/clients/csharp/Client/test/FunctionalTests/Hubs.cs

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ public string GetCallerConnectionId()
4343

4444
public ChannelReader<string> StreamEcho(ChannelReader<string> source) => TestHubMethodsImpl.StreamEcho(source);
4545

46+
public ChannelReader<int> StreamEchoInt(ChannelReader<int> source) => TestHubMethodsImpl.StreamEchoInt(source);
47+
4648
public string GetUserIdentifier()
4749
{
4850
return Context.UserIdentifier;
@@ -121,6 +123,8 @@ public string GetCallerConnectionId()
121123
}
122124

123125
public ChannelReader<string> StreamEcho(ChannelReader<string> source) => TestHubMethodsImpl.StreamEcho(source);
126+
127+
public ChannelReader<int> StreamEchoInt(ChannelReader<int> source) => TestHubMethodsImpl.StreamEchoInt(source);
124128
}
125129

126130
public class TestHubT : Hub<ITestHub>
@@ -151,6 +155,8 @@ public string GetCallerConnectionId()
151155
}
152156

153157
public ChannelReader<string> StreamEcho(ChannelReader<string> source) => TestHubMethodsImpl.StreamEcho(source);
158+
159+
public ChannelReader<int> StreamEchoInt(ChannelReader<int> source) => TestHubMethodsImpl.StreamEchoInt(source);
154160
}
155161

156162
internal static class TestHubMethodsImpl
@@ -212,6 +218,30 @@ public static ChannelReader<string> StreamEcho(ChannelReader<string> source)
212218

213219
return output.Reader;
214220
}
221+
222+
public static ChannelReader<int> StreamEchoInt(ChannelReader<int> source)
223+
{
224+
var output = Channel.CreateUnbounded<int>();
225+
_ = Task.Run(async () =>
226+
{
227+
try
228+
{
229+
while (await source.WaitToReadAsync())
230+
{
231+
while (source.TryRead(out var item))
232+
{
233+
await output.Writer.WriteAsync(item);
234+
}
235+
}
236+
}
237+
finally
238+
{
239+
output.Writer.TryComplete();
240+
}
241+
});
242+
243+
return output.Reader;
244+
}
215245
}
216246

217247
public interface ITestHub

src/SignalR/common/Protocols.MessagePack/src/Protocol/MessagePackHubProtocol.cs

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -210,12 +210,21 @@ private static HubMessage CreateStreamInvocationMessage(byte[] input, ref int of
210210
return ApplyHeaders(headers, new StreamInvocationMessage(invocationId, target, arguments, streams));
211211
}
212212

213-
private static StreamItemMessage CreateStreamItemMessage(byte[] input, ref int offset, IInvocationBinder binder, IFormatterResolver resolver)
213+
private static HubMessage CreateStreamItemMessage(byte[] input, ref int offset, IInvocationBinder binder, IFormatterResolver resolver)
214214
{
215215
var headers = ReadHeaders(input, ref offset);
216216
var invocationId = ReadInvocationId(input, ref offset);
217-
var itemType = binder.GetStreamItemType(invocationId);
218-
var value = DeserializeObject(input, ref offset, itemType, "item", resolver);
217+
object value;
218+
try
219+
{
220+
var itemType = binder.GetStreamItemType(invocationId);
221+
value = DeserializeObject(input, ref offset, itemType, "item", resolver);
222+
}
223+
catch (Exception ex)
224+
{
225+
return new StreamBindingFailureMessage(invocationId, ExceptionDispatchInfo.Capture(ex));
226+
}
227+
219228
return ApplyHeaders(headers, new StreamItemMessage(invocationId, value));
220229
}
221230

src/SignalR/common/Shared/ReflectionHelper.cs

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
33

44
using System;
5+
using System.Collections.Generic;
6+
using System.Linq;
57
using System.Threading.Channels;
68

79
namespace Microsoft.AspNetCore.SignalR
@@ -13,6 +15,13 @@ internal static class ReflectionHelper
1315
public static bool IsStreamingType(Type type, bool mustBeDirectType = false)
1416
{
1517
// TODO #2594 - add Streams here, to make sending files easy
18+
19+
#if NETCOREAPP3_0
20+
if (IsIAsyncEnumerable(type))
21+
{
22+
return true;
23+
}
24+
#endif
1625
do
1726
{
1827
if (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(ChannelReader<>))
@@ -25,5 +34,22 @@ public static bool IsStreamingType(Type type, bool mustBeDirectType = false)
2534

2635
return false;
2736
}
37+
38+
#if NETCOREAPP3_0
39+
public static bool IsIAsyncEnumerable(Type type)
40+
{
41+
return type.GetInterfaces().Any(t =>
42+
{
43+
if (t.IsGenericType)
44+
{
45+
return t.GetGenericTypeDefinition() == typeof(IAsyncEnumerable<>);
46+
}
47+
else
48+
{
49+
return false;
50+
}
51+
});
52+
}
53+
#endif
2854
}
2955
}

src/SignalR/common/SignalR.Common/test/Internal/Protocol/MessagePackHubProtocolTestBase.cs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -277,8 +277,10 @@ protected void TestWriteMessages(ProtocolTestData testData)
277277
// StreamItemMessage
278278
new InvalidMessageData("StreamItemMissingId", new byte[] { 0x92, 2, 0x80 }, "Reading 'invocationId' as String failed."),
279279
new InvalidMessageData("StreamItemInvocationIdBoolean", new byte[] { 0x93, 2, 0x80, 0xc2 }, "Reading 'invocationId' as String failed."),
280-
new InvalidMessageData("StreamItemMissing", new byte[] { 0x93, 2, 0x80, 0xa3, (byte)'x', (byte)'y', (byte)'z' }, "Deserializing object of the `String` type for 'item' failed."),
281-
new InvalidMessageData("StreamItemTypeMismatch", new byte[] { 0x94, 2, 0x80, 0xa3, (byte)'x', (byte)'y', (byte)'z', 42 }, "Deserializing object of the `String` type for 'item' failed."),
280+
281+
// These now trigger StreamBindingInvocationFailureMessages
282+
//new InvalidMessageData("StreamItemMissing", new byte[] { 0x93, 2, 0x80, 0xa3, (byte)'x', (byte)'y', (byte)'z' }, "Deserializing object of the `String` type for 'item' failed."),
283+
//new InvalidMessageData("StreamItemTypeMismatch", new byte[] { 0x94, 2, 0x80, 0xa3, (byte)'x', (byte)'y', (byte)'z', 42 }, "Deserializing object of the `String` type for 'item' failed."),
282284

283285
// CompletionMessage
284286
new InvalidMessageData("CompletionMissingId", new byte[] { 0x92, 3, 0x80 }, "Reading 'invocationId' as String failed."),

0 commit comments

Comments
 (0)