Skip to content

Commit 46fe595

Browse files
authored
Support IAsyncEnumerable returns in SignalR hubs (#6791)
1 parent 03460d8 commit 46fe595

File tree

12 files changed

+358
-184
lines changed

12 files changed

+358
-184
lines changed

src/SignalR/clients/csharp/Client/test/FunctionalTests/HubConnectionTests.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -822,7 +822,7 @@ bool ExpectedErrors(WriteContext writeContext)
822822
await connection.StartAsync().OrTimeout();
823823
var channel = await connection.StreamAsChannelAsync<int>("StreamBroken").OrTimeout();
824824
var ex = await Assert.ThrowsAsync<HubException>(() => channel.ReadAndCollectAllAsync()).OrTimeout();
825-
Assert.Equal("The value returned by the streaming method 'StreamBroken' is not a ChannelReader<>.", ex.Message);
825+
Assert.Equal("The value returned by the streaming method 'StreamBroken' is not a ChannelReader<> or IAsyncEnumerable<>.", ex.Message);
826826
}
827827
catch (Exception ex)
828828
{

src/SignalR/samples/SignalRSamples/Hubs/Streaming.cs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
33

44
using System;
5+
using System.Collections.Generic;
56
using System.Reactive.Linq;
67
using System.Threading.Channels;
78
using System.Threading.Tasks;
@@ -11,6 +12,15 @@ namespace SignalRSamples.Hubs
1112
{
1213
public class Streaming : Hub
1314
{
15+
public async IAsyncEnumerable<int> AsyncEnumerableCounter(int count, int delay)
16+
{
17+
for (var i = 0; i < count; i++)
18+
{
19+
yield return i;
20+
await Task.Delay(delay);
21+
}
22+
}
23+
1424
public ChannelReader<int> ObservableCounter(int count, int delay)
1525
{
1626
var observable = Observable.Interval(TimeSpan.FromMilliseconds(delay))

src/SignalR/samples/SignalRSamples/SignalRSamples.csproj

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1-
<Project Sdk="Microsoft.NET.Sdk.Web">
1+
<Project Sdk="Microsoft.NET.Sdk.Web">
22

33
<PropertyGroup>
44
<TargetFramework>netcoreapp3.0</TargetFramework>
5+
<LangVersion>8.0</LangVersion>
56
</PropertyGroup>
67

78
<ItemGroup>

src/SignalR/samples/SignalRSamples/wwwroot/streaming.html

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ <h2>Controls</h2>
1717
</div>
1818

1919
<div>
20+
<button id="asyncEnumerableButton" name="asyncEnumerable" type="button" disabled>From IAsyncEnumerable</button>
2021
<button id="observableButton" name="observable" type="button" disabled>From Observable</button>
2122
<button id="channelButton" name="channel" type="button" disabled>From Channel</button>
2223
</div>
@@ -32,7 +33,7 @@ <h2>Results</h2>
3233
let resultsList = document.getElementById('resultsList');
3334
let channelButton = document.getElementById('channelButton');
3435
let observableButton = document.getElementById('observableButton');
35-
let clearButton = document.getElementById('clearButton');
36+
let asyncEnumerableButton = document.getElementById('asyncEnumerableButton');
3637

3738
let connectButton = document.getElementById('connectButton');
3839
let disconnectButton = document.getElementById('disconnectButton');
@@ -61,6 +62,7 @@ <h2>Results</h2>
6162
connection.onclose(function () {
6263
channelButton.disabled = true;
6364
observableButton.disabled = true;
65+
asyncEnumerableButton.disabled = true;
6466
connectButton.disabled = false;
6567
disconnectButton.disabled = true;
6668

@@ -71,12 +73,17 @@ <h2>Results</h2>
7173
.then(function () {
7274
channelButton.disabled = false;
7375
observableButton.disabled = false;
76+
asyncEnumerableButton.disabled = false;
7477
connectButton.disabled = true;
7578
disconnectButton.disabled = false;
7679
addLine('resultsList', 'connected', 'green');
7780
});
7881
});
7982

83+
click('asyncEnumerableButton', function () {
84+
run('AsyncEnumerableCounter');
85+
})
86+
8087
click('observableButton', function () {
8188
run('ObservableCounter');
8289
});
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
// Copyright (c) .NET Foundation. All rights reserved.
2+
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
3+
4+
using System.Collections.Generic;
5+
using System.Diagnostics;
6+
using System.Threading;
7+
using System.Threading.Channels;
8+
using System.Threading.Tasks;
9+
10+
namespace Microsoft.AspNetCore.SignalR.Internal
11+
{
12+
// True-internal because this is a weird and tricky class to use :)
13+
internal static class AsyncEnumerableAdapters
14+
{
15+
public static IAsyncEnumerable<object> MakeCancelableAsyncEnumerable<T>(IAsyncEnumerable<T> asyncEnumerable, CancellationToken cancellationToken = default)
16+
{
17+
return new CancelableAsyncEnumerable<T>(asyncEnumerable, cancellationToken);
18+
}
19+
20+
public static IAsyncEnumerable<object> MakeCancelableAsyncEnumerableFromChannel<T>(ChannelReader<T> channel, CancellationToken cancellationToken = default)
21+
{
22+
return MakeCancelableAsyncEnumerable(channel.ReadAllAsync(), cancellationToken);
23+
}
24+
25+
/// <summary>Converts an IAsyncEnumerable of T to an IAsyncEnumerable of object.</summary>
26+
private class CancelableAsyncEnumerable<T> : IAsyncEnumerable<object>
27+
{
28+
private readonly IAsyncEnumerable<T> _asyncEnumerable;
29+
private readonly CancellationToken _cancellationToken;
30+
31+
public CancelableAsyncEnumerable(IAsyncEnumerable<T> asyncEnumerable, CancellationToken cancellationToken)
32+
{
33+
_asyncEnumerable = asyncEnumerable;
34+
_cancellationToken = cancellationToken;
35+
}
36+
37+
public IAsyncEnumerator<object> GetAsyncEnumerator(CancellationToken cancellationToken = default)
38+
{
39+
// Assume that this will be iterated through with await foreach which always passes a default token.
40+
// Instead use the token from the ctor.
41+
Debug.Assert(cancellationToken == default);
42+
43+
var enumeratorOfT = _asyncEnumerable.GetAsyncEnumerator(_cancellationToken);
44+
return enumeratorOfT as IAsyncEnumerator<object> ?? new BoxedAsyncEnumerator(enumeratorOfT);
45+
}
46+
47+
private class BoxedAsyncEnumerator : IAsyncEnumerator<object>
48+
{
49+
private IAsyncEnumerator<T> _asyncEnumerator;
50+
51+
public BoxedAsyncEnumerator(IAsyncEnumerator<T> asyncEnumerator)
52+
{
53+
_asyncEnumerator = asyncEnumerator;
54+
}
55+
56+
public object Current => _asyncEnumerator.Current;
57+
58+
public ValueTask<bool> MoveNextAsync()
59+
{
60+
return _asyncEnumerator.MoveNextAsync();
61+
}
62+
63+
public ValueTask DisposeAsync()
64+
{
65+
return _asyncEnumerator.DisposeAsync();
66+
}
67+
}
68+
}
69+
}
70+
}

src/SignalR/server/Core/src/Internal/AsyncEnumeratorAdapters.cs

Lines changed: 0 additions & 84 deletions
This file was deleted.

src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs

Lines changed: 13 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -293,16 +293,20 @@ await SendInvocationError(hubMethodInvocationMessage.InvocationId, connection,
293293
{
294294
var result = await ExecuteHubMethod(methodExecutor, hub, arguments);
295295

296-
if (!TryGetStreamingEnumerator(connection, hubMethodInvocationMessage.InvocationId, descriptor, result, out var enumerator, ref cts))
296+
if (result == null)
297297
{
298298
Log.InvalidReturnValueFromStreamingMethod(_logger, methodExecutor.MethodInfo.Name);
299299
await SendInvocationError(hubMethodInvocationMessage.InvocationId, connection,
300-
$"The value returned by the streaming method '{methodExecutor.MethodInfo.Name}' is not a ChannelReader<>.");
300+
$"The value returned by the streaming method '{methodExecutor.MethodInfo.Name}' is not a ChannelReader<> or IAsyncEnumerable<>.");
301301
return;
302302
}
303303

304+
cts = cts ?? CancellationTokenSource.CreateLinkedTokenSource(connection.ConnectionAborted);
305+
connection.ActiveRequestCancellationSources.TryAdd(hubMethodInvocationMessage.InvocationId, cts);
306+
var enumerable = descriptor.FromReturnedStream(result, cts.Token);
307+
304308
Log.StreamingResult(_logger, hubMethodInvocationMessage.InvocationId, methodExecutor);
305-
_ = StreamResultsAsync(hubMethodInvocationMessage.InvocationId, connection, enumerator, scope, hubActivator, hub, cts, hubMethodInvocationMessage);
309+
_ = StreamResultsAsync(hubMethodInvocationMessage.InvocationId, connection, enumerable, scope, hubActivator, hub, cts, hubMethodInvocationMessage);
306310
}
307311

308312
else if (string.IsNullOrEmpty(hubMethodInvocationMessage.InvocationId))
@@ -393,17 +397,17 @@ private ValueTask CleanupInvocation(HubConnectionContext connection, HubMethodIn
393397
return scope.DisposeAsync();
394398
}
395399

396-
private async Task StreamResultsAsync(string invocationId, HubConnectionContext connection, IAsyncEnumerator<object> enumerator, IServiceScope scope,
400+
private async Task StreamResultsAsync(string invocationId, HubConnectionContext connection, IAsyncEnumerable<object> enumerable, IServiceScope scope,
397401
IHubActivator<THub> hubActivator, THub hub, CancellationTokenSource streamCts, HubMethodInvocationMessage hubMethodInvocationMessage)
398402
{
399403
string error = null;
400404

401405
try
402406
{
403-
while (await enumerator.MoveNextAsync())
407+
await foreach (var streamItem in enumerable)
404408
{
405409
// Send the stream item
406-
await connection.WriteAsync(new StreamItemMessage(invocationId, enumerator.Current));
410+
await connection.WriteAsync(new StreamItemMessage(invocationId, streamItem));
407411
}
408412
}
409413
catch (ChannelClosedException ex)
@@ -422,8 +426,6 @@ private async Task StreamResultsAsync(string invocationId, HubConnectionContext
422426
}
423427
finally
424428
{
425-
(enumerator as IDisposable)?.Dispose();
426-
427429
await CleanupInvocation(connection, hubMethodInvocationMessage, hubActivator, hub, scope);
428430

429431
// Dispose the linked CTS for the stream.
@@ -502,10 +504,10 @@ private static async Task<bool> IsHubMethodAuthorizedSlow(IServiceProvider provi
502504
return authorizationResult.Succeeded;
503505
}
504506

505-
private async Task<bool> ValidateInvocationMode(HubMethodDescriptor hubMethodDescriptor, bool isStreamedInvocation,
507+
private async Task<bool> ValidateInvocationMode(HubMethodDescriptor hubMethodDescriptor, bool isStreamResponse,
506508
HubMethodInvocationMessage hubMethodInvocationMessage, HubConnectionContext connection)
507509
{
508-
if (hubMethodDescriptor.IsStreamable && !isStreamedInvocation)
510+
if (hubMethodDescriptor.IsStreamResponse && !isStreamResponse)
509511
{
510512
// Non-null/empty InvocationId? Blocking
511513
if (!string.IsNullOrEmpty(hubMethodInvocationMessage.InvocationId))
@@ -518,7 +520,7 @@ await connection.WriteAsync(CompletionMessage.WithError(hubMethodInvocationMessa
518520
return false;
519521
}
520522

521-
if (!hubMethodDescriptor.IsStreamable && isStreamedInvocation)
523+
if (!hubMethodDescriptor.IsStreamResponse && isStreamResponse)
522524
{
523525
Log.NonStreamingMethodCalledWithStream(_logger, hubMethodInvocationMessage);
524526
await connection.WriteAsync(CompletionMessage.WithError(hubMethodInvocationMessage.InvocationId,
@@ -530,26 +532,6 @@ await connection.WriteAsync(CompletionMessage.WithError(hubMethodInvocationMessa
530532
return true;
531533
}
532534

533-
private bool TryGetStreamingEnumerator(HubConnectionContext connection, string invocationId, HubMethodDescriptor hubMethodDescriptor, object result, out IAsyncEnumerator<object> enumerator, ref CancellationTokenSource streamCts)
534-
{
535-
if (result != null)
536-
{
537-
if (hubMethodDescriptor.IsChannel)
538-
{
539-
if (streamCts == null)
540-
{
541-
streamCts = CancellationTokenSource.CreateLinkedTokenSource(connection.ConnectionAborted);
542-
}
543-
connection.ActiveRequestCancellationSources.TryAdd(invocationId, streamCts);
544-
enumerator = hubMethodDescriptor.FromChannel(result, streamCts.Token);
545-
return true;
546-
}
547-
}
548-
549-
enumerator = null;
550-
return false;
551-
}
552-
553535
private void DiscoverHubMethods()
554536
{
555537
var hubType = typeof(THub);

0 commit comments

Comments
 (0)