|
6 | 6 | using System.Linq;
|
7 | 7 | using System.Net;
|
8 | 8 | using System.Net.Http;
|
| 9 | +using System.Net.WebSockets; |
| 10 | +using System.Text; |
| 11 | +using System.Threading; |
9 | 12 | using System.Threading.Tasks;
|
10 | 13 | using Microsoft.AspNetCore.Builder;
|
11 | 14 | using Microsoft.AspNetCore.Hosting;
|
12 | 15 | using Microsoft.AspNetCore.Hosting.Server;
|
13 | 16 | using Microsoft.AspNetCore.Hosting.Server.Features;
|
14 | 17 | using Microsoft.AspNetCore.Http;
|
| 18 | +using Microsoft.AspNetCore.Http.Features; |
15 | 19 | using Microsoft.AspNetCore.Routing;
|
16 | 20 | using Microsoft.AspNetCore.Testing;
|
17 | 21 | using Microsoft.Extensions.DependencyInjection;
|
@@ -109,6 +113,34 @@ public async Task DefautCredentials_Success(Version version)
|
109 | 113 | Assert.Equal(Http11Version, result.Version); // HTTP/2 downgrades.
|
110 | 114 | }
|
111 | 115 |
|
| 116 | + [ConditionalFact] |
| 117 | + public async Task DefautCredentials_WebSocket_Success() |
| 118 | + { |
| 119 | + using var host = await CreateHostAsync(); |
| 120 | + |
| 121 | + var address = host.Services.GetRequiredService<IServer>().Features.Get<IServerAddressesFeature>().Addresses.First().Replace("https://", "wss://"); |
| 122 | + |
| 123 | + using var webSocket = new ClientWebSocket |
| 124 | + { |
| 125 | + Options = |
| 126 | + { |
| 127 | + RemoteCertificateValidationCallback = (sender, certificate, chain, sslPolicyErrors) => true, |
| 128 | + UseDefaultCredentials = true, |
| 129 | + } |
| 130 | + }; |
| 131 | + |
| 132 | + using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(30)); |
| 133 | + |
| 134 | + await webSocket.ConnectAsync(new Uri($"{address}/AuthenticateWebSocket"), cts.Token); |
| 135 | + |
| 136 | + var receiveBuffer = new byte[13]; |
| 137 | + var receiveResult = await webSocket.ReceiveAsync(receiveBuffer, cts.Token); |
| 138 | + |
| 139 | + Assert.True(receiveResult.EndOfMessage); |
| 140 | + Assert.Equal(WebSocketMessageType.Text, receiveResult.MessageType); |
| 141 | + Assert.Equal("Hello World!", Encoding.UTF8.GetString(receiveBuffer, 0, receiveResult.Count)); |
| 142 | + } |
| 143 | + |
112 | 144 | public static IEnumerable<object[]> HttpOrders =>
|
113 | 145 | new List<object[]>
|
114 | 146 | {
|
@@ -252,6 +284,7 @@ private static Task<IHost> CreateHostAsync(Action<NegotiateOptions> configureOpt
|
252 | 284 | {
|
253 | 285 | app.UseRouting();
|
254 | 286 | app.UseAuthentication();
|
| 287 | + app.UseWebSockets(); |
255 | 288 | app.UseEndpoints(ConfigureEndpoints);
|
256 | 289 | });
|
257 | 290 | });
|
@@ -289,6 +322,27 @@ private static void ConfigureEndpoints(IEndpointRouteBuilder builder)
|
289 | 322 | await context.Response.WriteAsync(name);
|
290 | 323 | });
|
291 | 324 |
|
| 325 | + builder.Map("/AuthenticateWebSocket", async context => |
| 326 | + { |
| 327 | + if (!context.User.Identity.IsAuthenticated) |
| 328 | + { |
| 329 | + await context.ChallengeAsync(); |
| 330 | + return; |
| 331 | + } |
| 332 | + |
| 333 | + if (!context.WebSockets.IsWebSocketRequest) |
| 334 | + { |
| 335 | + context.Response.StatusCode = 400; |
| 336 | + return; |
| 337 | + } |
| 338 | + |
| 339 | + Assert.False(string.IsNullOrEmpty(context.User.Identity.Name), "name"); |
| 340 | + |
| 341 | + WebSocket webSocket = await context.WebSockets.AcceptWebSocketAsync(); |
| 342 | + |
| 343 | + await webSocket.SendAsync(Encoding.UTF8.GetBytes("Hello World!"), WebSocketMessageType.Text, endOfMessage: true, context.RequestAborted); |
| 344 | + }); |
| 345 | + |
292 | 346 | builder.Map("/AlreadyAuthenticated", async context =>
|
293 | 347 | {
|
294 | 348 | Assert.Equal("HTTP/1.1", context.Request.Protocol); // Not HTTP/2
|
|
0 commit comments