Skip to content

Fix WebSockets Negotiate Auth in Kestrel #26480

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Oct 2, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
<Reference Include="Microsoft.AspNetCore.Authentication.Negotiate" />
<Reference Include="Microsoft.AspNetCore.Routing" />
<Reference Include="Microsoft.AspNetCore.Server.Kestrel" />
<Reference Include="Microsoft.AspNetCore.WebSockets" />
<Reference Include="Microsoft.Extensions.Hosting" />
<Reference Include="System.Net.Http.WinHttpHandler" />
</ItemGroup>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
using System.Linq;
using System.Net;
using System.Net.Http;
using System.Net.WebSockets;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Hosting;
Expand All @@ -23,7 +26,7 @@ namespace Microsoft.AspNetCore.Authentication.Negotiate
{
// In theory this would work on Linux and Mac, but the client would require explicit credentials.
[OSSkipCondition(OperatingSystems.Linux | OperatingSystems.MacOSX)]
public class NegotiateHandlerFunctionalTests
public class NegotiateHandlerFunctionalTests : LoggedTest
{
private static readonly Version Http11Version = new Version(1, 1);
private static readonly Version Http2Version = new Version(2, 0);
Expand Down Expand Up @@ -109,6 +112,34 @@ public async Task DefautCredentials_Success(Version version)
Assert.Equal(Http11Version, result.Version); // HTTP/2 downgrades.
}

[ConditionalFact]
public async Task DefautCredentials_WebSocket_Success()
{
using var host = await CreateHostAsync();

var address = host.Services.GetRequiredService<IServer>().Features.Get<IServerAddressesFeature>().Addresses.First().Replace("https://", "wss://");

using var webSocket = new ClientWebSocket
{
Options =
{
RemoteCertificateValidationCallback = (sender, certificate, chain, sslPolicyErrors) => true,
UseDefaultCredentials = true,
}
};

using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(30));

await webSocket.ConnectAsync(new Uri($"{address}/AuthenticateWebSocket"), cts.Token);

var receiveBuffer = new byte[13];
var receiveResult = await webSocket.ReceiveAsync(receiveBuffer, cts.Token);

Assert.True(receiveResult.EndOfMessage);
Assert.Equal(WebSocketMessageType.Text, receiveResult.MessageType);
Assert.Equal("Hello World!", Encoding.UTF8.GetString(receiveBuffer, 0, receiveResult.Count));
}

public static IEnumerable<object[]> HttpOrders =>
new List<object[]>
{
Expand Down Expand Up @@ -232,9 +263,10 @@ public async Task UnauthorizedAfterAuthenticated_Success(Version version)
Assert.Equal(Http11Version, result.Version); // HTTP/2 downgrades.
}

private static Task<IHost> CreateHostAsync(Action<NegotiateOptions> configureOptions = null)
private Task<IHost> CreateHostAsync(Action<NegotiateOptions> configureOptions = null)
{
var builder = new HostBuilder()
.ConfigureServices(AddTestLogging)
.ConfigureServices(services => services
.AddRouting()
.AddAuthentication(NegotiateDefaults.AuthenticationScheme)
Expand All @@ -252,6 +284,7 @@ private static Task<IHost> CreateHostAsync(Action<NegotiateOptions> configureOpt
{
app.UseRouting();
app.UseAuthentication();
app.UseWebSockets();
app.UseEndpoints(ConfigureEndpoints);
});
});
Expand Down Expand Up @@ -289,6 +322,27 @@ private static void ConfigureEndpoints(IEndpointRouteBuilder builder)
await context.Response.WriteAsync(name);
});

builder.Map("/AuthenticateWebSocket", async context =>
{
if (!context.User.Identity.IsAuthenticated)
{
await context.ChallengeAsync();
return;
}

if (!context.WebSockets.IsWebSocketRequest)
{
context.Response.StatusCode = 400;
return;
}

Assert.False(string.IsNullOrEmpty(context.User.Identity.Name), "name");

WebSocket webSocket = await context.WebSockets.AcceptWebSocketAsync();

await webSocket.SendAsync(Encoding.UTF8.GetBytes("Hello World!"), WebSocketMessageType.Text, endOfMessage: true, context.RequestAborted);
});

builder.Map("/AlreadyAuthenticated", async context =>
{
Assert.Equal("HTTP/1.1", context.Request.Protocol); // Not HTTP/2
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,9 @@ internal sealed class Http1ChunkedEncodingMessageBody : Http1MessageBody
private readonly Pipe _requestBodyPipe;
private ReadResult _readResult;

public Http1ChunkedEncodingMessageBody(bool keepAlive, Http1Connection context)
: base(context)
public Http1ChunkedEncodingMessageBody(Http1Connection context, bool keepAlive)
: base(context, keepAlive)
{
RequestKeepAlive = keepAlive;
_requestBodyPipe = CreateRequestBodyPipe(context);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@

namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
{
using BadHttpRequestException = Microsoft.AspNetCore.Http.BadHttpRequestException;

internal sealed class Http1ContentLengthMessageBody : Http1MessageBody
{
private ReadResult _readResult;
Expand All @@ -23,12 +21,11 @@ internal sealed class Http1ContentLengthMessageBody : Http1MessageBody
private bool _finalAdvanceCalled;
private bool _cannotResetInputPipe;

public Http1ContentLengthMessageBody(bool keepAlive, long contentLength, Http1Connection context)
: base(context)
public Http1ContentLengthMessageBody(Http1Connection context, long contentLength, bool keepAlive)
: base(context, keepAlive)
{
RequestKeepAlive = keepAlive;
_contentLength = contentLength;
_unexaminedInputLength = _contentLength;
_unexaminedInputLength = contentLength;
}

public override ValueTask<ReadResult> ReadAsync(CancellationToken cancellationToken = default)
Expand Down
14 changes: 8 additions & 6 deletions src/Servers/Kestrel/Core/src/Internal/Http/Http1MessageBody.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,10 @@ internal abstract class Http1MessageBody : MessageBody
protected readonly Http1Connection _context;
protected bool _completed;

protected Http1MessageBody(Http1Connection context) : base(context)
protected Http1MessageBody(Http1Connection context, bool keepAlive) : base(context)
{
_context = context;
RequestKeepAlive = keepAlive;
}

[StackTraceHidden]
Expand Down Expand Up @@ -118,14 +119,15 @@ public static MessageBody For(
{
// see also http://tools.ietf.org/html/rfc2616#section-4.4
var keepAlive = httpVersion != HttpVersion.Http10;

var upgrade = false;

if (headers.HasConnection)
{
var connectionOptions = HttpHeaders.ParseConnection(headers.HeaderConnection);

upgrade = (connectionOptions & ConnectionOptions.Upgrade) != 0;
keepAlive = (connectionOptions & ConnectionOptions.KeepAlive) != 0;
keepAlive = keepAlive || (connectionOptions & ConnectionOptions.KeepAlive) != 0;
keepAlive = keepAlive && (connectionOptions & ConnectionOptions.Close) == 0;
}

if (upgrade)
Expand All @@ -136,7 +138,7 @@ public static MessageBody For(
}

context.OnTrailersComplete(); // No trailers for these.
return new Http1UpgradeMessageBody(context);
return new Http1UpgradeMessageBody(context, keepAlive);
}

if (headers.HasTransferEncoding)
Expand All @@ -157,7 +159,7 @@ public static MessageBody For(

// TODO may push more into the wrapper rather than just calling into the message body
// NBD for now.
return new Http1ChunkedEncodingMessageBody(keepAlive, context);
return new Http1ChunkedEncodingMessageBody(context, keepAlive);
}

if (headers.ContentLength.HasValue)
Expand All @@ -169,7 +171,7 @@ public static MessageBody For(
return keepAlive ? MessageBody.ZeroContentLengthKeepAlive : MessageBody.ZeroContentLengthClose;
}

return new Http1ContentLengthMessageBody(keepAlive, contentLength, context);
return new Http1ContentLengthMessageBody(context, contentLength, keepAlive);
}

// If we got here, request contains no Content-Length or Transfer-Encoding header.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
/// </summary>
internal sealed class Http1UpgradeMessageBody : Http1MessageBody
{
public Http1UpgradeMessageBody(Http1Connection context)
: base(context)
public Http1UpgradeMessageBody(Http1Connection context, bool keepAlive)
: base(context, keepAlive)
{
RequestUpgrade = true;
}
Expand Down
10 changes: 5 additions & 5 deletions src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1113,13 +1113,13 @@ private HttpResponseHeaders CreateResponseHeaders(bool appCompleted)
{
RejectNonBodyTransferEncodingResponse(appCompleted);
}
else if (StatusCode == StatusCodes.Status101SwitchingProtocols)
{
_keepAlive = false;
}
else if (!hasTransferEncoding && !responseHeaders.ContentLength.HasValue)
{
if (StatusCode == StatusCodes.Status101SwitchingProtocols)
{
_keepAlive = false;
}
else if ((appCompleted || !_canWriteResponseBody) && !_hasAdvanced) // Avoid setting contentLength of 0 if we wrote data before calling CreateResponseHeaders
if ((appCompleted || !_canWriteResponseBody) && !_hasAdvanced) // Avoid setting contentLength of 0 if we wrote data before calling CreateResponseHeaders
{
// Don't set the Content-Length header automatically for HEAD requests, 204 responses, or 304 responses.
if (CanAutoSetContentLengthZeroResponseHeader())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
using Microsoft.AspNetCore.Server.Kestrel.Core;
using Microsoft.AspNetCore.Server.Kestrel.Core.Internal;
using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http;
using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2;
using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure;
using Microsoft.AspNetCore.Testing;

Expand Down Expand Up @@ -112,7 +113,7 @@ private TestHttp1Connection MakeHttp1Connection()
});

http1Connection.Reset();
http1Connection.InitializeBodyControl(new Http1ContentLengthMessageBody(keepAlive: true, 100, http1Connection));
http1Connection.InitializeBodyControl(new Http1ContentLengthMessageBody(http1Connection, contentLength: 100, keepAlive: true));
serviceContext.DateHeaderValueManager.OnHeartbeat(DateTimeOffset.UtcNow);

return http1Connection;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
using Microsoft.AspNetCore.Server.Kestrel.InMemory.FunctionalTests.TestTransport;
using Microsoft.AspNetCore.Server.Kestrel.Tests;
using Microsoft.AspNetCore.Testing;
using Microsoft.Extensions.Logging.Testing;
using Xunit;

namespace Microsoft.AspNetCore.Server.Kestrel.InMemory.FunctionalTests
Expand Down Expand Up @@ -343,5 +342,38 @@ await connection.Receive("HTTP/1.1 101 Switching Protocols",
await appCompletedTcs.Task.DefaultTimeout();
}
}

[Fact]
public async Task DoesNotCloseConnectionWithout101Response()
{
var requestCount = 0;

await using (var server = new TestServer(async context =>
{
if (requestCount++ > 0)
{
await context.Features.Get<IHttpUpgradeFeature>().UpgradeAsync();
}
}, new TestServiceContext(LoggerFactory)))
{
using (var connection = server.CreateConnection())
{
await connection.SendEmptyGetWithUpgrade();
await connection.Receive(
"HTTP/1.1 200 OK",
$"Date: {server.Context.DateHeaderValue}",
"Content-Length: 0",
"",
"");

await connection.SendEmptyGetWithUpgrade();
await connection.Receive("HTTP/1.1 101 Switching Protocols",
"Connection: Upgrade",
$"Date: {server.Context.DateHeaderValue}",
"",
"");
}
}
}
}
}