Skip to content

Commit e3e2ef1

Browse files
authored
[release/7.0] Backport pr 48892 to 7.0 (#48911)
* Backport PR 48892. * Fix bug found in unit test on main.
1 parent 9b97894 commit e3e2ef1

File tree

3 files changed

+224
-1
lines changed

3 files changed

+224
-1
lines changed
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
4+
using System.Net.WebSockets;
5+
using Microsoft.AspNetCore.Http;
6+
7+
namespace Microsoft.AspNetCore.WebSockets;
8+
9+
/// <summary>
10+
/// Used in ASP.NET Core to wrap a WebSocket with its associated HttpContext so that when the WebSocket is aborted
11+
/// the underlying HttpContext is aborted. All other methods are delegated to the underlying WebSocket.
12+
/// </summary>
13+
internal sealed class ServerWebSocket : WebSocket
14+
{
15+
private readonly WebSocket _wrappedSocket;
16+
private readonly HttpContext _context;
17+
18+
internal ServerWebSocket(WebSocket wrappedSocket, HttpContext context)
19+
{
20+
ArgumentNullException.ThrowIfNull(wrappedSocket);
21+
ArgumentNullException.ThrowIfNull(context);
22+
23+
_wrappedSocket = wrappedSocket;
24+
_context = context;
25+
}
26+
27+
public override WebSocketCloseStatus? CloseStatus => _wrappedSocket.CloseStatus;
28+
29+
public override string? CloseStatusDescription => _wrappedSocket.CloseStatusDescription;
30+
31+
public override WebSocketState State => _wrappedSocket.State;
32+
33+
public override string? SubProtocol => _wrappedSocket.SubProtocol;
34+
35+
public override void Abort()
36+
{
37+
_wrappedSocket.Abort();
38+
_context.Abort();
39+
}
40+
41+
public override Task CloseAsync(WebSocketCloseStatus closeStatus, string? statusDescription, CancellationToken cancellationToken)
42+
{
43+
return _wrappedSocket.CloseAsync(closeStatus, statusDescription, cancellationToken);
44+
}
45+
46+
public override Task CloseOutputAsync(WebSocketCloseStatus closeStatus, string? statusDescription, CancellationToken cancellationToken)
47+
{
48+
return _wrappedSocket.CloseOutputAsync(closeStatus, statusDescription, cancellationToken);
49+
}
50+
51+
public override void Dispose()
52+
{
53+
_wrappedSocket.Dispose();
54+
}
55+
56+
public override Task<WebSocketReceiveResult> ReceiveAsync(ArraySegment<byte> buffer, CancellationToken cancellationToken)
57+
{
58+
return _wrappedSocket.ReceiveAsync(buffer, cancellationToken);
59+
}
60+
61+
public override ValueTask<ValueWebSocketReceiveResult> ReceiveAsync(Memory<byte> buffer, CancellationToken cancellationToken)
62+
{
63+
return _wrappedSocket.ReceiveAsync(buffer, cancellationToken);
64+
}
65+
66+
public override ValueTask SendAsync(ReadOnlyMemory<byte> buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancellationToken)
67+
{
68+
return _wrappedSocket.SendAsync(buffer, messageType, endOfMessage, cancellationToken);
69+
}
70+
71+
public override ValueTask SendAsync(ReadOnlyMemory<byte> buffer, WebSocketMessageType messageType, WebSocketMessageFlags messageFlags, CancellationToken cancellationToken)
72+
{
73+
return _wrappedSocket.SendAsync(buffer, messageType, messageFlags, cancellationToken);
74+
}
75+
76+
public override Task SendAsync(ArraySegment<byte> buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancellationToken)
77+
{
78+
return _wrappedSocket.SendAsync(buffer, messageType, endOfMessage, cancellationToken);
79+
}
80+
}

src/Middleware/WebSockets/src/WebSocketMiddleware.cs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,13 +207,15 @@ public async Task<WebSocket> AcceptAsync(WebSocketAcceptContext acceptContext)
207207
opaqueTransport = await _upgradeFeature!.UpgradeAsync(); // Sets status code to 101
208208
}
209209

210-
return WebSocket.CreateFromStream(opaqueTransport, new WebSocketCreationOptions()
210+
var wrappedSocket = WebSocket.CreateFromStream(opaqueTransport, new WebSocketCreationOptions()
211211
{
212212
IsServer = true,
213213
KeepAliveInterval = keepAliveInterval,
214214
SubProtocol = subProtocol,
215215
DangerousDeflateOptions = deflateOptions
216216
});
217+
218+
return new ServerWebSocket(wrappedSocket, _context);
217219
}
218220

219221
public static bool CheckSupportedWebSocketRequest(string method, IHeaderDictionary requestHeaders)

src/Middleware/WebSockets/test/UnitTests/WebSocketMiddlewareTests.cs

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
using System.Net.Http;
66
using System.Net.WebSockets;
77
using System.Text;
8+
using Microsoft.AspNetCore.Connections;
89
using Microsoft.AspNetCore.Testing;
910
using Microsoft.Net.Http.Headers;
1011

@@ -495,6 +496,146 @@ public async Task CloseFromCloseReceived_Success()
495496
}
496497
}
497498

499+
[Fact]
500+
public async Task WebSocket_Abort_Interrupts_Pending_ReceiveAsync()
501+
{
502+
WebSocket serverSocket = null;
503+
504+
// Events that we want to sequence execution across client and server.
505+
var socketWasAccepted = new ManualResetEventSlim();
506+
var socketWasAborted = new ManualResetEventSlim();
507+
var firstReceiveOccured = new ManualResetEventSlim();
508+
var secondReceiveInitiated = new ManualResetEventSlim();
509+
510+
Exception receiveException = null;
511+
512+
await using (var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context =>
513+
{
514+
Assert.True(context.WebSockets.IsWebSocketRequest);
515+
serverSocket = await context.WebSockets.AcceptWebSocketAsync();
516+
socketWasAccepted.Set();
517+
518+
var serverBuffer = new byte[1024];
519+
520+
try
521+
{
522+
while (serverSocket.State is WebSocketState.Open or WebSocketState.CloseSent)
523+
{
524+
if (firstReceiveOccured.IsSet)
525+
{
526+
var pendingResponse = serverSocket.ReceiveAsync(serverBuffer, default);
527+
secondReceiveInitiated.Set();
528+
var response = await pendingResponse;
529+
}
530+
else
531+
{
532+
var response = await serverSocket.ReceiveAsync(serverBuffer, default);
533+
firstReceiveOccured.Set();
534+
}
535+
}
536+
}
537+
catch (ConnectionAbortedException ex)
538+
{
539+
socketWasAborted.Set();
540+
receiveException = ex;
541+
}
542+
catch (Exception ex)
543+
{
544+
// Capture this exception so a test failure can give us more information.
545+
receiveException = ex;
546+
}
547+
finally
548+
{
549+
Assert.IsType<ConnectionAbortedException>(receiveException);
550+
}
551+
}))
552+
{
553+
var clientBuffer = new byte[1024];
554+
555+
using (var client = new ClientWebSocket())
556+
{
557+
await client.ConnectAsync(new Uri($"ws://127.0.0.1:{port}/"), CancellationToken.None);
558+
559+
var socketWasAcceptedDidNotTimeout = socketWasAccepted.Wait(10000);
560+
Assert.True(socketWasAcceptedDidNotTimeout, "Socket was not accepted within the allotted time.");
561+
562+
await client.SendAsync(clientBuffer, WebSocketMessageType.Binary, false, default);
563+
564+
var firstReceiveOccuredDidNotTimeout = firstReceiveOccured.Wait(10000);
565+
Assert.True(firstReceiveOccuredDidNotTimeout, "First receive did not occur within the allotted time.");
566+
567+
var secondReceiveInitiatedDidNotTimeout = secondReceiveInitiated.Wait(10000);
568+
Assert.True(secondReceiveInitiatedDidNotTimeout, "Second receive was not initiated within the allotted time.");
569+
570+
serverSocket.Abort();
571+
572+
var socketWasAbortedDidNotTimeout = socketWasAborted.Wait(1000); // Give it a second to process the abort.
573+
Assert.True(socketWasAbortedDidNotTimeout, "Abort did not occur within the allotted time.");
574+
}
575+
}
576+
}
577+
578+
[Fact]
579+
public async Task WebSocket_AllowsCancelling_Pending_ReceiveAsync_When_CancellationTokenProvided()
580+
{
581+
WebSocket serverSocket = null;
582+
CancellationTokenSource cts = new CancellationTokenSource();
583+
584+
var socketWasAccepted = new ManualResetEventSlim();
585+
var operationWasCancelled = new ManualResetEventSlim();
586+
var firstReceiveOccured = new ManualResetEventSlim();
587+
588+
await using (var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context =>
589+
{
590+
Assert.True(context.WebSockets.IsWebSocketRequest);
591+
serverSocket = await context.WebSockets.AcceptWebSocketAsync();
592+
socketWasAccepted.Set();
593+
594+
var serverBuffer = new byte[1024];
595+
596+
var finishedWithOperationCancelled = false;
597+
598+
try
599+
{
600+
while (serverSocket.State is WebSocketState.Open or WebSocketState.CloseSent)
601+
{
602+
var response = await serverSocket.ReceiveAsync(serverBuffer, cts.Token);
603+
firstReceiveOccured.Set();
604+
}
605+
}
606+
catch (OperationCanceledException)
607+
{
608+
operationWasCancelled.Set();
609+
finishedWithOperationCancelled = true;
610+
}
611+
finally
612+
{
613+
Assert.True(finishedWithOperationCancelled);
614+
}
615+
}))
616+
{
617+
var clientBuffer = new byte[1024];
618+
619+
using (var client = new ClientWebSocket())
620+
{
621+
await client.ConnectAsync(new Uri($"ws://127.0.0.1:{port}/"), CancellationToken.None);
622+
623+
var socketWasAcceptedDidNotTimeout = socketWasAccepted.Wait(10000);
624+
Assert.True(socketWasAcceptedDidNotTimeout, "Socket was not accepted within the allotted time.");
625+
626+
await client.SendAsync(clientBuffer, WebSocketMessageType.Binary, false, default);
627+
628+
var firstReceiveOccuredDidNotTimeout = firstReceiveOccured.Wait(10000);
629+
Assert.True(firstReceiveOccuredDidNotTimeout, "First receive did not occur within the allotted time.");
630+
631+
cts.Cancel();
632+
633+
var operationWasCancelledDidNotTimeout = operationWasCancelled.Wait(1000); // Give it a second to process the abort.
634+
Assert.True(operationWasCancelledDidNotTimeout, "Cancel did not occur within the allotted time.");
635+
}
636+
}
637+
}
638+
498639
[Theory]
499640
[InlineData(HttpStatusCode.OK, null)]
500641
[InlineData(HttpStatusCode.Forbidden, "")]

0 commit comments

Comments
 (0)