Skip to content

Commit 96c082f

Browse files
authored
Fix WebSockets Negotiate Auth in Kestrel (#26480)
* Don't close connections after upgrade requests without a 101 response * Add test * Add DefautCredentials_WebSocket_Success
1 parent 8eb9603 commit 96c082f

File tree

9 files changed

+112
-26
lines changed

9 files changed

+112
-26
lines changed

src/Security/Authentication/Negotiate/test/Negotiate.FunctionalTest/Microsoft.AspNetCore.Authentication.Negotiate.FunctionalTest.csproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
<Reference Include="Microsoft.AspNetCore.Authentication.Negotiate" />
1010
<Reference Include="Microsoft.AspNetCore.Routing" />
1111
<Reference Include="Microsoft.AspNetCore.Server.Kestrel" />
12+
<Reference Include="Microsoft.AspNetCore.WebSockets" />
1213
<Reference Include="Microsoft.Extensions.Hosting" />
1314
<Reference Include="System.Net.Http.WinHttpHandler" />
1415
</ItemGroup>

src/Security/Authentication/Negotiate/test/Negotiate.FunctionalTest/NegotiateHandlerFunctionalTests.cs

Lines changed: 56 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@
66
using System.Linq;
77
using System.Net;
88
using System.Net.Http;
9+
using System.Net.WebSockets;
10+
using System.Text;
11+
using System.Threading;
912
using System.Threading.Tasks;
1013
using Microsoft.AspNetCore.Builder;
1114
using Microsoft.AspNetCore.Hosting;
@@ -23,7 +26,7 @@ namespace Microsoft.AspNetCore.Authentication.Negotiate
2326
{
2427
// In theory this would work on Linux and Mac, but the client would require explicit credentials.
2528
[OSSkipCondition(OperatingSystems.Linux | OperatingSystems.MacOSX)]
26-
public class NegotiateHandlerFunctionalTests
29+
public class NegotiateHandlerFunctionalTests : LoggedTest
2730
{
2831
private static readonly Version Http11Version = new Version(1, 1);
2932
private static readonly Version Http2Version = new Version(2, 0);
@@ -109,6 +112,34 @@ public async Task DefautCredentials_Success(Version version)
109112
Assert.Equal(Http11Version, result.Version); // HTTP/2 downgrades.
110113
}
111114

115+
[ConditionalFact]
116+
public async Task DefautCredentials_WebSocket_Success()
117+
{
118+
using var host = await CreateHostAsync();
119+
120+
var address = host.Services.GetRequiredService<IServer>().Features.Get<IServerAddressesFeature>().Addresses.First().Replace("https://", "wss://");
121+
122+
using var webSocket = new ClientWebSocket
123+
{
124+
Options =
125+
{
126+
RemoteCertificateValidationCallback = (sender, certificate, chain, sslPolicyErrors) => true,
127+
UseDefaultCredentials = true,
128+
}
129+
};
130+
131+
using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(30));
132+
133+
await webSocket.ConnectAsync(new Uri($"{address}/AuthenticateWebSocket"), cts.Token);
134+
135+
var receiveBuffer = new byte[13];
136+
var receiveResult = await webSocket.ReceiveAsync(receiveBuffer, cts.Token);
137+
138+
Assert.True(receiveResult.EndOfMessage);
139+
Assert.Equal(WebSocketMessageType.Text, receiveResult.MessageType);
140+
Assert.Equal("Hello World!", Encoding.UTF8.GetString(receiveBuffer, 0, receiveResult.Count));
141+
}
142+
112143
public static IEnumerable<object[]> HttpOrders =>
113144
new List<object[]>
114145
{
@@ -232,9 +263,10 @@ public async Task UnauthorizedAfterAuthenticated_Success(Version version)
232263
Assert.Equal(Http11Version, result.Version); // HTTP/2 downgrades.
233264
}
234265

235-
private static Task<IHost> CreateHostAsync(Action<NegotiateOptions> configureOptions = null)
266+
private Task<IHost> CreateHostAsync(Action<NegotiateOptions> configureOptions = null)
236267
{
237268
var builder = new HostBuilder()
269+
.ConfigureServices(AddTestLogging)
238270
.ConfigureServices(services => services
239271
.AddRouting()
240272
.AddAuthentication(NegotiateDefaults.AuthenticationScheme)
@@ -252,6 +284,7 @@ private static Task<IHost> CreateHostAsync(Action<NegotiateOptions> configureOpt
252284
{
253285
app.UseRouting();
254286
app.UseAuthentication();
287+
app.UseWebSockets();
255288
app.UseEndpoints(ConfigureEndpoints);
256289
});
257290
});
@@ -289,6 +322,27 @@ private static void ConfigureEndpoints(IEndpointRouteBuilder builder)
289322
await context.Response.WriteAsync(name);
290323
});
291324

325+
builder.Map("/AuthenticateWebSocket", async context =>
326+
{
327+
if (!context.User.Identity.IsAuthenticated)
328+
{
329+
await context.ChallengeAsync();
330+
return;
331+
}
332+
333+
if (!context.WebSockets.IsWebSocketRequest)
334+
{
335+
context.Response.StatusCode = 400;
336+
return;
337+
}
338+
339+
Assert.False(string.IsNullOrEmpty(context.User.Identity.Name), "name");
340+
341+
WebSocket webSocket = await context.WebSockets.AcceptWebSocketAsync();
342+
343+
await webSocket.SendAsync(Encoding.UTF8.GetBytes("Hello World!"), WebSocketMessageType.Text, endOfMessage: true, context.RequestAborted);
344+
});
345+
292346
builder.Map("/AlreadyAuthenticated", async context =>
293347
{
294348
Assert.Equal("HTTP/1.1", context.Request.Protocol); // Not HTTP/2

src/Servers/Kestrel/Core/src/Internal/Http/Http1ChunkedEncodingMessageBody.cs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,9 @@ internal sealed class Http1ChunkedEncodingMessageBody : Http1MessageBody
3030
private readonly Pipe _requestBodyPipe;
3131
private ReadResult _readResult;
3232

33-
public Http1ChunkedEncodingMessageBody(bool keepAlive, Http1Connection context)
34-
: base(context)
33+
public Http1ChunkedEncodingMessageBody(Http1Connection context, bool keepAlive)
34+
: base(context, keepAlive)
3535
{
36-
RequestKeepAlive = keepAlive;
3736
_requestBodyPipe = CreateRequestBodyPipe(context);
3837
}
3938

src/Servers/Kestrel/Core/src/Internal/Http/Http1ContentLengthMessageBody.cs

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,6 @@
1010

1111
namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
1212
{
13-
using BadHttpRequestException = Microsoft.AspNetCore.Http.BadHttpRequestException;
14-
1513
internal sealed class Http1ContentLengthMessageBody : Http1MessageBody
1614
{
1715
private ReadResult _readResult;
@@ -23,12 +21,11 @@ internal sealed class Http1ContentLengthMessageBody : Http1MessageBody
2321
private bool _finalAdvanceCalled;
2422
private bool _cannotResetInputPipe;
2523

26-
public Http1ContentLengthMessageBody(bool keepAlive, long contentLength, Http1Connection context)
27-
: base(context)
24+
public Http1ContentLengthMessageBody(Http1Connection context, long contentLength, bool keepAlive)
25+
: base(context, keepAlive)
2826
{
29-
RequestKeepAlive = keepAlive;
3027
_contentLength = contentLength;
31-
_unexaminedInputLength = _contentLength;
28+
_unexaminedInputLength = contentLength;
3229
}
3330

3431
public override ValueTask<ReadResult> ReadAsync(CancellationToken cancellationToken = default)

src/Servers/Kestrel/Core/src/Internal/Http/Http1MessageBody.cs

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,10 @@ internal abstract class Http1MessageBody : MessageBody
1818
protected readonly Http1Connection _context;
1919
protected bool _completed;
2020

21-
protected Http1MessageBody(Http1Connection context) : base(context)
21+
protected Http1MessageBody(Http1Connection context, bool keepAlive) : base(context)
2222
{
2323
_context = context;
24+
RequestKeepAlive = keepAlive;
2425
}
2526

2627
[StackTraceHidden]
@@ -118,14 +119,15 @@ public static MessageBody For(
118119
{
119120
// see also http://tools.ietf.org/html/rfc2616#section-4.4
120121
var keepAlive = httpVersion != HttpVersion.Http10;
121-
122122
var upgrade = false;
123+
123124
if (headers.HasConnection)
124125
{
125126
var connectionOptions = HttpHeaders.ParseConnection(headers.HeaderConnection);
126127

127128
upgrade = (connectionOptions & ConnectionOptions.Upgrade) != 0;
128-
keepAlive = (connectionOptions & ConnectionOptions.KeepAlive) != 0;
129+
keepAlive = keepAlive || (connectionOptions & ConnectionOptions.KeepAlive) != 0;
130+
keepAlive = keepAlive && (connectionOptions & ConnectionOptions.Close) == 0;
129131
}
130132

131133
if (upgrade)
@@ -136,7 +138,7 @@ public static MessageBody For(
136138
}
137139

138140
context.OnTrailersComplete(); // No trailers for these.
139-
return new Http1UpgradeMessageBody(context);
141+
return new Http1UpgradeMessageBody(context, keepAlive);
140142
}
141143

142144
if (headers.HasTransferEncoding)
@@ -157,7 +159,7 @@ public static MessageBody For(
157159

158160
// TODO may push more into the wrapper rather than just calling into the message body
159161
// NBD for now.
160-
return new Http1ChunkedEncodingMessageBody(keepAlive, context);
162+
return new Http1ChunkedEncodingMessageBody(context, keepAlive);
161163
}
162164

163165
if (headers.ContentLength.HasValue)
@@ -169,7 +171,7 @@ public static MessageBody For(
169171
return keepAlive ? MessageBody.ZeroContentLengthKeepAlive : MessageBody.ZeroContentLengthClose;
170172
}
171173

172-
return new Http1ContentLengthMessageBody(keepAlive, contentLength, context);
174+
return new Http1ContentLengthMessageBody(context, contentLength, keepAlive);
173175
}
174176

175177
// If we got here, request contains no Content-Length or Transfer-Encoding header.

src/Servers/Kestrel/Core/src/Internal/Http/Http1UpgradeMessageBody.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
1414
/// </summary>
1515
internal sealed class Http1UpgradeMessageBody : Http1MessageBody
1616
{
17-
public Http1UpgradeMessageBody(Http1Connection context)
18-
: base(context)
17+
public Http1UpgradeMessageBody(Http1Connection context, bool keepAlive)
18+
: base(context, keepAlive)
1919
{
2020
RequestUpgrade = true;
2121
}

src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.cs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1113,13 +1113,13 @@ private HttpResponseHeaders CreateResponseHeaders(bool appCompleted)
11131113
{
11141114
RejectNonBodyTransferEncodingResponse(appCompleted);
11151115
}
1116+
else if (StatusCode == StatusCodes.Status101SwitchingProtocols)
1117+
{
1118+
_keepAlive = false;
1119+
}
11161120
else if (!hasTransferEncoding && !responseHeaders.ContentLength.HasValue)
11171121
{
1118-
if (StatusCode == StatusCodes.Status101SwitchingProtocols)
1119-
{
1120-
_keepAlive = false;
1121-
}
1122-
else if ((appCompleted || !_canWriteResponseBody) && !_hasAdvanced) // Avoid setting contentLength of 0 if we wrote data before calling CreateResponseHeaders
1122+
if ((appCompleted || !_canWriteResponseBody) && !_hasAdvanced) // Avoid setting contentLength of 0 if we wrote data before calling CreateResponseHeaders
11231123
{
11241124
// Don't set the Content-Length header automatically for HEAD requests, 204 responses, or 304 responses.
11251125
if (CanAutoSetContentLengthZeroResponseHeader())

src/Servers/Kestrel/perf/Kestrel.Performance/Http1ReadingBenchmark.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
using Microsoft.AspNetCore.Server.Kestrel.Core;
1313
using Microsoft.AspNetCore.Server.Kestrel.Core.Internal;
1414
using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http;
15+
using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2;
1516
using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure;
1617
using Microsoft.AspNetCore.Testing;
1718

@@ -112,7 +113,7 @@ private TestHttp1Connection MakeHttp1Connection()
112113
});
113114

114115
http1Connection.Reset();
115-
http1Connection.InitializeBodyControl(new Http1ContentLengthMessageBody(keepAlive: true, 100, http1Connection));
116+
http1Connection.InitializeBodyControl(new Http1ContentLengthMessageBody(http1Connection, contentLength: 100, keepAlive: true));
116117
serviceContext.DateHeaderValueManager.OnHeartbeat(DateTimeOffset.UtcNow);
117118

118119
return http1Connection;

src/Servers/Kestrel/test/InMemory.FunctionalTests/UpgradeTests.cs

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
using Microsoft.AspNetCore.Server.Kestrel.InMemory.FunctionalTests.TestTransport;
1111
using Microsoft.AspNetCore.Server.Kestrel.Tests;
1212
using Microsoft.AspNetCore.Testing;
13-
using Microsoft.Extensions.Logging.Testing;
1413
using Xunit;
1514

1615
namespace Microsoft.AspNetCore.Server.Kestrel.InMemory.FunctionalTests
@@ -343,5 +342,38 @@ await connection.Receive("HTTP/1.1 101 Switching Protocols",
343342
await appCompletedTcs.Task.DefaultTimeout();
344343
}
345344
}
345+
346+
[Fact]
347+
public async Task DoesNotCloseConnectionWithout101Response()
348+
{
349+
var requestCount = 0;
350+
351+
await using (var server = new TestServer(async context =>
352+
{
353+
if (requestCount++ > 0)
354+
{
355+
await context.Features.Get<IHttpUpgradeFeature>().UpgradeAsync();
356+
}
357+
}, new TestServiceContext(LoggerFactory)))
358+
{
359+
using (var connection = server.CreateConnection())
360+
{
361+
await connection.SendEmptyGetWithUpgrade();
362+
await connection.Receive(
363+
"HTTP/1.1 200 OK",
364+
$"Date: {server.Context.DateHeaderValue}",
365+
"Content-Length: 0",
366+
"",
367+
"");
368+
369+
await connection.SendEmptyGetWithUpgrade();
370+
await connection.Receive("HTTP/1.1 101 Switching Protocols",
371+
"Connection: Upgrade",
372+
$"Date: {server.Context.DateHeaderValue}",
373+
"",
374+
"");
375+
}
376+
}
377+
}
346378
}
347379
}

0 commit comments

Comments
 (0)