Skip to content

Pass CancellationToken to WaitAsync in client #20210

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 31, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 22 additions & 20 deletions src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ public async Task StartAsync(CancellationToken cancellationToken = default)

private async Task StartAsyncInner(CancellationToken cancellationToken = default)
{
await _state.WaitConnectionLockAsync();
await _state.WaitConnectionLockAsync(token: cancellationToken);
try
{
if (!_state.TryChangeState(HubConnectionState.Disconnected, HubConnectionState.Connecting))
Expand Down Expand Up @@ -465,7 +465,7 @@ private async Task StopAsyncCore(bool disposing)

// Potentially wait for StartAsync to finish, and block a new StartAsync from
// starting until we've finished stopping.
await _state.WaitConnectionLockAsync();
await _state.WaitConnectionLockAsync(token: default);

// Ensure that ReconnectingState.ReconnectTask is not accessed outside of the lock.
var reconnectTask = _state.ReconnectTask;
Expand All @@ -478,7 +478,7 @@ private async Task StopAsyncCore(bool disposing)
// The StopCts should prevent the HubConnection from restarting until it is reset.
_state.ReleaseConnectionLock();
await reconnectTask;
await _state.WaitConnectionLockAsync();
await _state.WaitConnectionLockAsync(token: default);
}

ConnectionState connectionState;
Expand Down Expand Up @@ -574,7 +574,7 @@ private async Task<ChannelReader<object>> StreamAsChannelCoreAsyncCore(string me
async Task OnStreamCanceled(InvocationRequest irq)
{
// We need to take the connection lock in order to ensure we a) have a connection and b) are the only one accessing the write end of the pipe.
await _state.WaitConnectionLockAsync();
await _state.WaitConnectionLockAsync(token: default);
try
{
if (_state.CurrentConnectionStateUnsynchronized != null)
Expand All @@ -601,7 +601,7 @@ async Task OnStreamCanceled(InvocationRequest irq)
var readers = default(Dictionary<string, object>);

CheckDisposed();
var connectionState = await _state.WaitForActiveConnectionAsync(nameof(StreamAsChannelCoreAsync));
var connectionState = await _state.WaitForActiveConnectionAsync(nameof(StreamAsChannelCoreAsync), token: cancellationToken);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we remove token = default from all of our internal methods that wait for the semaphore like WaitForActiveConnectionAsync, SendWithLock, WaitConnectionLockAsync, etc...?

I know WaitConnectionLockAsync in particular has a lot of usages where we don't want to use a cancellation token for close/reconnect logic, but I like being explicit.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed


ChannelReader<object> channel;
try
Expand Down Expand Up @@ -704,7 +704,7 @@ async Task ReadChannelStream(CancellationTokenSource tokenSource)
{
while (!tokenSource.Token.IsCancellationRequested && reader.TryRead(out var item))
{
await SendWithLock(connectionState, new StreamItemMessage(streamId, item));
await SendWithLock(connectionState, new StreamItemMessage(streamId, item), tokenSource.Token);
Log.SendingStreamItem(_logger, streamId);
}
}
Expand All @@ -722,7 +722,7 @@ async Task ReadAsyncEnumerableStream(CancellationTokenSource tokenSource)

await foreach (var streamValue in streamValues)
{
await SendWithLock(connectionState, new StreamItemMessage(streamId, streamValue));
await SendWithLock(connectionState, new StreamItemMessage(streamId, streamValue), tokenSource.Token);
Log.SendingStreamItem(_logger, streamId);
}
}
Expand Down Expand Up @@ -750,15 +750,17 @@ private async Task CommonStreaming(ConnectionState connectionState, string strea

Log.CompletingStream(_logger, streamId);

await SendWithLock(connectionState, CompletionMessage.WithError(streamId, responseError), cts.Token);
// Don't use cancellation token here
// this is triggered by a cancellation token to tell the server that the client is done streaming
await SendWithLock(connectionState, CompletionMessage.WithError(streamId, responseError), cancellationToken: default);
}

private async Task<object> InvokeCoreAsyncCore(string methodName, Type returnType, object[] args, CancellationToken cancellationToken)
{
var readers = default(Dictionary<string, object>);

CheckDisposed();
var connectionState = await _state.WaitForActiveConnectionAsync(nameof(InvokeCoreAsync));
var connectionState = await _state.WaitForActiveConnectionAsync(nameof(InvokeCoreAsync), token: cancellationToken);

Task<object> invocationTask;
try
Expand Down Expand Up @@ -853,7 +855,7 @@ private async Task SendCoreAsyncCore(string methodName, object[] args, Cancellat
var readers = default(Dictionary<string, object>);

CheckDisposed();
var connectionState = await _state.WaitForActiveConnectionAsync(nameof(SendCoreAsync));
var connectionState = await _state.WaitForActiveConnectionAsync(nameof(SendCoreAsync), token: cancellationToken);
try
{
CheckDisposed();
Expand All @@ -872,10 +874,10 @@ private async Task SendCoreAsyncCore(string methodName, object[] args, Cancellat
}
}

private async Task SendWithLock(ConnectionState expectedConnectionState, HubMessage message, CancellationToken cancellationToken = default, [CallerMemberName] string callerName = "")
private async Task SendWithLock(ConnectionState expectedConnectionState, HubMessage message, CancellationToken cancellationToken, [CallerMemberName] string callerName = "")
{
CheckDisposed();
var connectionState = await _state.WaitForActiveConnectionAsync(callerName);
var connectionState = await _state.WaitForActiveConnectionAsync(callerName, token: cancellationToken);
try
{
CheckDisposed();
Expand Down Expand Up @@ -1246,7 +1248,7 @@ internal void OnServerTimeout()
private async Task HandleConnectionClose(ConnectionState connectionState)
{
// Clear the connectionState field
await _state.WaitConnectionLockAsync();
await _state.WaitConnectionLockAsync(token: default);
try
{
SafeAssert(ReferenceEquals(_state.CurrentConnectionStateUnsynchronized, connectionState),
Expand Down Expand Up @@ -1364,7 +1366,7 @@ private async Task ReconnectAsync(Exception closeException)
{
Log.ReconnectingStoppedDuringRetryDelay(_logger);

await _state.WaitConnectionLockAsync();
await _state.WaitConnectionLockAsync(token: default);
try
{
_state.ChangeState(HubConnectionState.Reconnecting, HubConnectionState.Disconnected);
Expand All @@ -1379,7 +1381,7 @@ private async Task ReconnectAsync(Exception closeException)
return;
}

await _state.WaitConnectionLockAsync();
await _state.WaitConnectionLockAsync(token: default);
try
{
SafeAssert(ReferenceEquals(_state.CurrentConnectionStateUnsynchronized, null),
Expand Down Expand Up @@ -1418,7 +1420,7 @@ private async Task ReconnectAsync(Exception closeException)
nextRetryDelay = GetNextRetryDelay(previousReconnectAttempts++, DateTime.UtcNow - reconnectStartTime, retryReason);
}

await _state.WaitConnectionLockAsync();
await _state.WaitConnectionLockAsync(token: default);
try
{
SafeAssert(ReferenceEquals(_state.CurrentConnectionStateUnsynchronized, null),
Expand Down Expand Up @@ -1954,10 +1956,10 @@ public void AssertConnectionValid([CallerMemberName] string memberName = null, [
SafeAssert(CurrentConnectionStateUnsynchronized != null, "We don't have a connection!", memberName, fileName, lineNumber);
}

public Task WaitConnectionLockAsync([CallerMemberName] string memberName = null, [CallerFilePath] string filePath = null, [CallerLineNumber] int lineNumber = 0)
public Task WaitConnectionLockAsync(CancellationToken token, [CallerMemberName] string memberName = null, [CallerFilePath] string filePath = null, [CallerLineNumber] int lineNumber = 0)
{
Log.WaitingOnConnectionLock(_logger, memberName, filePath, lineNumber);
return _connectionLock.WaitAsync();
return _connectionLock.WaitAsync(token);
}

public bool TryAcquireConnectionLock()
Expand All @@ -1966,9 +1968,9 @@ public bool TryAcquireConnectionLock()
}

// Don't call this method in a try/finally that releases the lock since we're also potentially releasing the connection lock here.
public async Task<ConnectionState> WaitForActiveConnectionAsync(string methodName, [CallerMemberName] string memberName = null, [CallerFilePath] string filePath = null, [CallerLineNumber] int lineNumber = 0)
public async Task<ConnectionState> WaitForActiveConnectionAsync(string methodName, CancellationToken token, [CallerMemberName] string memberName = null, [CallerFilePath] string filePath = null, [CallerLineNumber] int lineNumber = 0)
{
await WaitConnectionLockAsync(methodName);
await WaitConnectionLockAsync(token, methodName);

if (CurrentConnectionStateUnsynchronized == null || CurrentConnectionStateUnsynchronized.Stopping)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,7 @@ await WithConnectionAsync(
})),
async (connection) =>
{
// We aggregate failures that happen when we start the transport. The operation cancelled exception will
// We aggregate failures that happen when we start the transport. The operation canceled exception will
// be an inner exception.
var ex = await Assert.ThrowsAsync<AggregateException>(async () => await connection.StartAsync(cts.Token)).OrTimeout();
Assert.Equal(3, ex.InnerExceptions.Count);
Expand All @@ -454,6 +454,29 @@ await WithConnectionAsync(
}
}

[Fact]
public async Task CanceledCancellationTokenPassedToStartThrows()
{
using (StartVerifiableLog())
{
bool transportStartCalled = false;
var httpHandler = new TestHttpMessageHandler();

await WithConnectionAsync(
CreateConnection(httpHandler,
transport: new TestTransport(onTransportStart: () => {
transportStartCalled = true;
return Task.CompletedTask;
})),
async (connection) =>
{
await Assert.ThrowsAsync<TaskCanceledException>(async () => await connection.StartAsync(new CancellationToken(canceled: true))).OrTimeout();
});

Assert.False(transportStartCalled);
}
}

[Fact]
public async Task SSECanBeCanceled()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -541,7 +541,7 @@ public async Task StartAsyncWithTriggeredCancellationTokenIsCanceled()
var hubConnection = CreateHubConnection(connection, loggerFactory: LoggerFactory);
try
{
await Assert.ThrowsAsync<OperationCanceledException>(() => hubConnection.StartAsync(new CancellationToken(canceled: true))).OrTimeout();
await Assert.ThrowsAsync<TaskCanceledException>(() => hubConnection.StartAsync(new CancellationToken(canceled: true))).OrTimeout();
Assert.False(onStartCalled);
}
finally
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,98 @@ bool ExpectedErrors(WriteContext writeContext)
}
}

[Fact]
public async Task PendingInvocationsAreCanceledWhenTokenTriggered()
{
using (StartVerifiableLog())
{
var hubConnection = CreateHubConnection(new TestConnection(), loggerFactory: LoggerFactory);

await hubConnection.StartAsync().OrTimeout();
var cts = new CancellationTokenSource();
var invokeTask = hubConnection.InvokeAsync<int>("testMethod", cancellationToken: cts.Token).OrTimeout();
cts.Cancel();

await Assert.ThrowsAsync<TaskCanceledException>(async () => await invokeTask);
}
}

[Fact]
public async Task InvokeAsyncCanceledWhenPassedCanceledToken()
{
using (StartVerifiableLog())
{
var connection = new TestConnection();
var hubConnection = CreateHubConnection(connection, loggerFactory: LoggerFactory);

await hubConnection.StartAsync().OrTimeout();
await Assert.ThrowsAsync<TaskCanceledException>(() =>
hubConnection.InvokeAsync<int>("testMethod", cancellationToken: new CancellationToken(canceled: true)).OrTimeout());

await hubConnection.StopAsync().OrTimeout();

// Assert that InvokeAsync didn't send a message
Assert.Null(await connection.ReadSentTextMessageAsync().OrTimeout());
}
}

[Fact]
public async Task SendAsyncCanceledWhenPassedCanceledToken()
{
using (StartVerifiableLog())
{
var connection = new TestConnection();
var hubConnection = CreateHubConnection(connection, loggerFactory: LoggerFactory);

await hubConnection.StartAsync().OrTimeout();
await Assert.ThrowsAsync<TaskCanceledException>(() =>
hubConnection.SendAsync("testMethod", cancellationToken: new CancellationToken(canceled: true)).OrTimeout());

await hubConnection.StopAsync().OrTimeout();

// Assert that SendAsync didn't send a message
Assert.Null(await connection.ReadSentTextMessageAsync().OrTimeout());
}
}

[Fact]
public async Task StreamAsChannelAsyncCanceledWhenPassedCanceledToken()
{
using (StartVerifiableLog())
{
var connection = new TestConnection();
var hubConnection = CreateHubConnection(connection, loggerFactory: LoggerFactory);

await hubConnection.StartAsync().OrTimeout();
await Assert.ThrowsAsync<TaskCanceledException>(() =>
hubConnection.StreamAsChannelAsync<int>("testMethod", cancellationToken: new CancellationToken(canceled: true)).OrTimeout());

await hubConnection.StopAsync().OrTimeout();

// Assert that StreamAsChannelAsync didn't send a message
Assert.Null(await connection.ReadSentTextMessageAsync().OrTimeout());
}
}

[Fact]
public async Task StreamAsyncCanceledWhenPassedCanceledToken()
{
using (StartVerifiableLog())
{
var connection = new TestConnection();
var hubConnection = CreateHubConnection(connection, loggerFactory: LoggerFactory);

await hubConnection.StartAsync().OrTimeout();
var result = hubConnection.StreamAsync<int>("testMethod", cancellationToken: new CancellationToken(canceled: true));
await Assert.ThrowsAsync<TaskCanceledException>(() => result.GetAsyncEnumerator().MoveNextAsync().OrTimeout());

await hubConnection.StopAsync().OrTimeout();

// Assert that StreamAsync didn't send a message
Assert.Null(await connection.ReadSentTextMessageAsync().OrTimeout());
}
}

[Fact]
public async Task ConnectionTerminatedIfServerTimeoutIntervalElapsesWithNoMessages()
{
Expand Down Expand Up @@ -318,7 +410,7 @@ await connection.ReceiveJsonMessage(

[Fact]
[LogLevel(LogLevel.Trace)]
public async Task UploadStreamCancelationSendsStreamComplete()
public async Task UploadStreamCancellationSendsStreamComplete()
{
using (StartVerifiableLog())
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ private async Task StartAsyncCore(TransferFormat transferFormat, CancellationTok
return;
}

await _connectionLock.WaitAsync();
await _connectionLock.WaitAsync(cancellationToken);
try
{
CheckDisposed();
Expand Down