Skip to content

Client to Server Streaming with IAsyncEnumerable #9310

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Apr 18, 2019
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 49 additions & 12 deletions src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,9 @@ public partial class HubConnection
private readonly SemaphoreSlim _connectionLock = new SemaphoreSlim(1, 1);

private static readonly MethodInfo _sendStreamItemsMethod = typeof(HubConnection).GetMethods(BindingFlags.NonPublic | BindingFlags.Instance).Single(m => m.Name.Equals("SendStreamItems"));

#if NETCOREAPP3_0
private static readonly MethodInfo _sendIAsyncStreamItemsMethod = typeof(HubConnection).GetMethods(BindingFlags.NonPublic | BindingFlags.Instance).Single(m => m.Name.Equals("SendIAsyncEnumerableStreamItems"));
#endif
// Persistent across all connections
private readonly ILoggerFactory _loggerFactory;
private readonly ILogger _logger;
Expand Down Expand Up @@ -533,13 +535,11 @@ async Task OnStreamCanceled(InvocationRequest irq)
}

LaunchStreams(readers, cancellationToken);

return channel;
}

private Dictionary<string, object> PackageStreamingParams(ref object[] args, out List<string> streamIds)
{
// lazy initialized, to avoid allocating unecessary dictionaries
Dictionary<string, object> readers = null;
streamIds = null;
var newArgs = new List<object>(args.Length);
Expand Down Expand Up @@ -572,7 +572,6 @@ private Dictionary<string, object> PackageStreamingParams(ref object[] args, out
}

args = newArgs.ToArray();

return readers;
}

Expand All @@ -590,31 +589,69 @@ private void LaunchStreams(Dictionary<string, object> readers, CancellationToken
// For each stream that needs to be sent, run a "send items" task in the background.
// This reads from the channel, attaches streamId, and sends to server.
// A single background thread here quickly gets messy.
#if NETCOREAPP3_0
if (ReflectionHelper.IsIAsyncEnumerable(reader.GetType()))
{
_ = _sendIAsyncStreamItemsMethod
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Long term, we should store the Takss returned by _sendIAsyncStreamItemsMethod and _sendStreamItemsMethod and track them during connection teardown and log warnings if they don't complete within a timeout.

.MakeGenericMethod(reader.GetType().GetInterface("IAsyncEnumerable`1").GetGenericArguments())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mikaelm12 is this right?

.Invoke(this, new object[] { kvp.Key.ToString(), reader, cancellationToken });
continue;
}
#endif
_ = _sendStreamItemsMethod
.MakeGenericMethod(reader.GetType().GetGenericArguments())
.Invoke(this, new object[] { kvp.Key.ToString(), reader, cancellationToken });
}
}

// this is called via reflection using the `_sendStreamItems` field
private async Task SendStreamItems<T>(string streamId, ChannelReader<T> reader, CancellationToken token)
private Task SendStreamItems<T>(string streamId, ChannelReader<T> reader, CancellationToken token)
{
Log.StartingStream(_logger, streamId);

var combinedToken = CancellationTokenSource.CreateLinkedTokenSource(_uploadStreamToken, token).Token;

string responseError = null;
try
async Task ReadChannelStream(CancellationTokenSource tokenSource)
{
while (await reader.WaitToReadAsync(combinedToken))
while (await reader.WaitToReadAsync(token))
{
while (!combinedToken.IsCancellationRequested && reader.TryRead(out var item))
while (!token.IsCancellationRequested && reader.TryRead(out var item))
{
await SendWithLock(new StreamItemMessage(streamId, item));
Log.SendingStreamItem(_logger, streamId);
}
}
}

return CommonStreaming(streamId, token, ReadChannelStream);
}

#if NETCOREAPP3_0
// this is called via reflection using the `_sendIAsyncStreamItemsMethod` field
private Task SendIAsyncEnumerableStreamItems<T>(string streamId, IAsyncEnumerable<T> stream, CancellationToken token)
{
async Task ReadAsyncEnumerableStream(CancellationTokenSource tokenSource)
{
var streamValues = AsyncEnumerableAdapters.MakeCancelableTypedAsyncEnumerable(stream, tokenSource);

await foreach (var streamValue in streamValues)
{
await SendWithLock(new StreamItemMessage(streamId, streamValue));
Log.SendingStreamItem(_logger, streamId);
}
}

return CommonStreaming(streamId, token, ReadAsyncEnumerableStream);
}
#endif

private async Task CommonStreaming(string streamId, CancellationToken token, Func<CancellationTokenSource, Task> createAndConsumeStream)
{
var cts = CancellationTokenSource.CreateLinkedTokenSource(_uploadStreamToken, token);

Log.StartingStream(_logger, streamId);
string responseError = null;
try
{
await createAndConsumeStream(cts);
}
catch (OperationCanceledException)
{
Log.CancelingStream(_logger, streamId);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -661,6 +661,106 @@ public async Task CanStreamToAndFromClientInSameInvocation(string protocolName,
}
}

[Theory]
[MemberData(nameof(HubProtocolsAndTransportsAndHubPaths))]
[LogLevel(LogLevel.Trace)]
public async Task CanStreamToServerWithIAsyncEnumerable(string protocolName, HttpTransportType transportType, string path)
{
var protocol = HubProtocols[protocolName];
using (StartServer<Startup>(out var server))
{
var connection = CreateHubConnection(server.Url, path, transportType, protocol, LoggerFactory);
try
{
async IAsyncEnumerable<string> clientStreamData()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

dude, casing, 😄

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the case settled on local functions? I could see the argument for using camelCase like you would for anything else with a local scope.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes it settled it’s always pascal case because this is c# not java or JavaScript

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

😅

{
var items = new string[] { "A", "B", "C", "D" };
foreach (var item in items)
{
await Task.Delay(10);
yield return item;
}
}

await connection.StartAsync().OrTimeout();

var stream = clientStreamData();

var channel = await connection.StreamAsChannelAsync<string>("StreamEcho", stream).OrTimeout();

Assert.Equal("A", await channel.ReadAsync().AsTask().OrTimeout());
Assert.Equal("B", await channel.ReadAsync().AsTask().OrTimeout());
Assert.Equal("C", await channel.ReadAsync().AsTask().OrTimeout());
Assert.Equal("D", await channel.ReadAsync().AsTask().OrTimeout());

var results = await channel.ReadAndCollectAllAsync().OrTimeout();
Assert.Empty(results);
}
catch (Exception ex)
{
LoggerFactory.CreateLogger<HubConnectionTests>().LogError(ex, "{ExceptionType} from test", ex.GetType().FullName);
throw;
}
finally
{
await connection.DisposeAsync().OrTimeout();
}
}
}

[Theory]
[MemberData(nameof(HubProtocolsAndTransportsAndHubPaths))]
[LogLevel(LogLevel.Trace)]
public async Task CanCancelIAsyncEnumerableClientToServerUpload(string protocolName, HttpTransportType transportType, string path)
{
var protocol = HubProtocols[protocolName];
using (StartServer<Startup>(out var server))
{
var connection = CreateHubConnection(server.Url, path, transportType, protocol, LoggerFactory);
try
{
async IAsyncEnumerable<int> clientStreamData()
{
for (var i = 0; i < 1000; i++)
{
yield return i;
await Task.Delay(10);
}
}

await connection.StartAsync().OrTimeout();
var results = new List<int>();
var stream = clientStreamData();
var cts = new CancellationTokenSource();
var ex = await Assert.ThrowsAsync<OperationCanceledException>(async () =>
{
var channel = await connection.StreamAsChannelAsync<int>("StreamEchoInt", stream, cts.Token).OrTimeout();

while (await channel.WaitToReadAsync())
{
while (channel.TryRead(out var item))
{
results.Add(item);
cts.Cancel();
}
}
});

Assert.True(results.Count > 0 && results.Count < 1000);
Assert.True(cts.IsCancellationRequested);
}
catch (Exception ex)
{
LoggerFactory.CreateLogger<HubConnectionTests>().LogError(ex, "{ExceptionType} from test", ex.GetType().FullName);
throw;
}
finally
{
await connection.DisposeAsync().OrTimeout();
}
}
}

[Theory]
[MemberData(nameof(HubProtocolsAndTransportsAndHubPaths))]
[LogLevel(LogLevel.Trace)]
Expand All @@ -673,7 +773,7 @@ public async Task StreamAsyncCanBeCanceledThroughGetAsyncEnumerator(string proto
try
{
await connection.StartAsync().OrTimeout();
var stream = connection.StreamAsync<int>("Stream", 1000 );
var stream = connection.StreamAsync<int>("Stream", 1000);
var results = new List<int>();

var cts = new CancellationTokenSource();
Expand Down
30 changes: 30 additions & 0 deletions src/SignalR/clients/csharp/Client/test/FunctionalTests/Hubs.cs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ public string GetCallerConnectionId()

public ChannelReader<string> StreamEcho(ChannelReader<string> source) => TestHubMethodsImpl.StreamEcho(source);

public ChannelReader<int> StreamEchoInt(ChannelReader<int> source) => TestHubMethodsImpl.StreamEchoInt(source);

public string GetUserIdentifier()
{
return Context.UserIdentifier;
Expand Down Expand Up @@ -121,6 +123,8 @@ public string GetCallerConnectionId()
}

public ChannelReader<string> StreamEcho(ChannelReader<string> source) => TestHubMethodsImpl.StreamEcho(source);

public ChannelReader<int> StreamEchoInt(ChannelReader<int> source) => TestHubMethodsImpl.StreamEchoInt(source);
}

public class TestHubT : Hub<ITestHub>
Expand Down Expand Up @@ -151,6 +155,8 @@ public string GetCallerConnectionId()
}

public ChannelReader<string> StreamEcho(ChannelReader<string> source) => TestHubMethodsImpl.StreamEcho(source);

public ChannelReader<int> StreamEchoInt(ChannelReader<int> source) => TestHubMethodsImpl.StreamEchoInt(source);
}

internal static class TestHubMethodsImpl
Expand Down Expand Up @@ -212,6 +218,30 @@ public static ChannelReader<string> StreamEcho(ChannelReader<string> source)

return output.Reader;
}

public static ChannelReader<int> StreamEchoInt(ChannelReader<int> source)
{
var output = Channel.CreateUnbounded<int>();
_ = Task.Run(async () =>
{
try
{
while (await source.WaitToReadAsync())
{
while (source.TryRead(out var item))
{
await output.Writer.WriteAsync(item);
}
}
}
finally
{
output.Writer.TryComplete();
}
});

return output.Reader;
}
}

public interface ITestHub
Expand Down
26 changes: 26 additions & 0 deletions src/SignalR/common/Shared/ReflectionHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Channels;

namespace Microsoft.AspNetCore.SignalR
Expand All @@ -13,6 +15,13 @@ internal static class ReflectionHelper
public static bool IsStreamingType(Type type, bool mustBeDirectType = false)
{
// TODO #2594 - add Streams here, to make sending files easy

#if NETCOREAPP3_0
if (IsIAsyncEnumerable(type))
{
return true;
}
#endif
do
{
if (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(ChannelReader<>))
Expand All @@ -25,5 +34,22 @@ public static bool IsStreamingType(Type type, bool mustBeDirectType = false)

return false;
}

#if NETCOREAPP3_0
public static bool IsIAsyncEnumerable(Type type)
{
return type.GetInterfaces().Any(t =>
{
if (t.IsGenericType)
{
return t.GetGenericTypeDefinition() == typeof(IAsyncEnumerable<>);
}
else
{
return false;
}
});
}
#endif
}
}