Skip to content

Commit 102efca

Browse files
committed
Address feedback, added tests
1 parent d9d63cf commit 102efca

File tree

7 files changed

+302
-43
lines changed

7 files changed

+302
-43
lines changed

src/Components/Server/src/Builder/ComponentEndpointConventionBuilder.cs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,13 @@ namespace Microsoft.AspNetCore.Builder
1010
/// </summary>
1111
public sealed class ComponentEndpointConventionBuilder : IHubEndpointConventionBuilder
1212
{
13-
private readonly IEndpointConventionBuilder [] _endpointConventionBuilders;
13+
private readonly IEndpointConventionBuilder _hubEndpoint;
14+
private readonly IEndpointConventionBuilder _disconnectEndpoint;
1415

15-
internal ComponentEndpointConventionBuilder(params IEndpointConventionBuilder [] endpointConventionBuilder)
16+
internal ComponentEndpointConventionBuilder(IEndpointConventionBuilder hubEndpoint, IEndpointConventionBuilder disconnectEndpoint)
1617
{
17-
_endpointConventionBuilders = endpointConventionBuilder;
18+
_hubEndpoint = hubEndpoint;
19+
_disconnectEndpoint = disconnectEndpoint;
1820
}
1921

2022
/// <summary>
@@ -23,10 +25,8 @@ internal ComponentEndpointConventionBuilder(params IEndpointConventionBuilder []
2325
/// <param name="convention">The convention to add to the builder.</param>
2426
public void Add(Action<EndpointBuilder> convention)
2527
{
26-
foreach (var endpoint in _endpointConventionBuilders)
27-
{
28-
endpoint.Add(convention);
29-
}
28+
_hubEndpoint.Add(convention);
29+
_disconnectEndpoint.Add(convention);
3030
}
3131
}
3232
}

src/Components/Server/src/Builder/ComponentEndpointRouteBuilderExtensions.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -300,9 +300,9 @@ public static ComponentEndpointConventionBuilder MapBlazorHub(
300300
.WithDisplayName("Blazor disconnect");
301301

302302
return new ComponentEndpointConventionBuilder(
303-
disconnectEndpoint,
304-
hubEndpoint)
305-
.AddComponent(componentType, selector);
303+
hubEndpoint,
304+
disconnectEndpoint)
305+
.AddComponent(componentType, selector);
306306
}
307307
}
308308
}

src/Components/Server/src/CircuitDisconnectMiddleware.cs

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
namespace Microsoft.AspNetCore.Components.Server
1111
{
12+
// We use a middlware so that we can use DI.
1213
internal class CircuitDisconnectMiddleware
1314
{
1415
private const string CircuitIdKey = "circuitId";
@@ -38,25 +39,47 @@ public async Task Invoke(HttpContext context)
3839
return;
3940
}
4041

41-
var form = await context.Request.ReadFormAsync();
42-
if(!form.TryGetValue(CircuitIdKey, out var circuitId) || !CircuitIdFactory.ValidateCircuitId(circuitId))
42+
var (hasCircuitId, circuitId) = await TryGetCircuitIdAsync(context);
43+
if (!hasCircuitId)
4344
{
44-
context.Response.StatusCode = StatusCodes.Status404NotFound;
45+
context.Response.StatusCode = StatusCodes.Status400BadRequest;
4546
return;
4647
}
4748

4849
await TerminateCircuitGracefully(circuitId);
4950

5051
context.Response.StatusCode = StatusCodes.Status200OK;
51-
return;
52+
}
53+
54+
private async Task<(bool, string)> TryGetCircuitIdAsync(HttpContext context)
55+
{
56+
try
57+
{
58+
if (!string.Equals(context.Request.ContentType, "application/x-www-form-urlencoded"))
59+
{
60+
return (false, null);
61+
}
62+
63+
var form = await context.Request.ReadFormAsync();
64+
if (!form.TryGetValue(CircuitIdKey, out var circuitId) || !CircuitIdFactory.ValidateCircuitId(circuitId))
65+
{
66+
return (false, null);
67+
}
68+
69+
return (true, circuitId);
70+
}
71+
catch
72+
{
73+
return (false, null);
74+
}
5275
}
5376

5477
private async Task TerminateCircuitGracefully(string circuitId)
5578
{
5679
try
5780
{
58-
Log.CircuitTerminatedGracefully(Logger, circuitId);
5981
await Registry.Terminate(circuitId);
82+
Log.CircuitTerminatedGracefully(Logger, circuitId);
6083
}
6184
catch (Exception e)
6285
{

src/Components/Server/src/Circuits/CircuitRegistry.cs

Lines changed: 8 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -81,14 +81,6 @@ public void Register(CircuitHost circuitHost)
8181
}
8282
}
8383

84-
public void PermanentDisconnect(CircuitHost circuitHost)
85-
{
86-
DisconnectedCircuits.Remove(circuitHost.CircuitId);
87-
ConnectedCircuits.TryRemove(circuitHost.CircuitId, out _);
88-
Log.CircuitDisconnectedPermanently(_logger, circuitHost.CircuitId);
89-
circuitHost.Client.SetDisconnected();
90-
}
91-
9284
public virtual Task DisconnectAsync(CircuitHost circuitHost, string connectionId)
9385
{
9486
Log.CircuitDisconnectStarted(_logger, circuitHost.CircuitId, connectionId);
@@ -296,29 +288,27 @@ private void DisposeTokenSource(DisconnectedCircuitEntry entry)
296288
}
297289
}
298290

299-
internal ValueTask Terminate(string circuitId)
291+
public ValueTask Terminate(string circuitId)
300292
{
301293
CircuitHost circuitHost;
302294
DisconnectedCircuitEntry entry = default;
303295
lock (CircuitRegistryLock)
304296
{
305297
if (ConnectedCircuits.TryGetValue(circuitId, out circuitHost) || DisconnectedCircuits.TryGetValue(circuitId, out entry))
306298
{
307-
PermanentDisconnect(circuitHost ?? entry.CircuitHost);
299+
circuitHost ??= entry.CircuitHost;
300+
DisconnectedCircuits.Remove(circuitHost.CircuitId);
301+
ConnectedCircuits.TryRemove(circuitHost.CircuitId, out _);
302+
Log.CircuitDisconnectedPermanently(_logger, circuitHost.CircuitId);
303+
circuitHost.Client.SetDisconnected();
308304
}
309305
else
310306
{
311307
return default;
312308
}
313309
}
314-
if (circuitHost != null)
315-
{
316-
return circuitHost.DisposeAsync();
317-
}
318-
else
319-
{
320-
return default;
321-
}
310+
311+
return circuitHost?.DisposeAsync() ?? default;
322312
}
323313

324314
private readonly struct DisconnectedCircuitEntry

0 commit comments

Comments
 (0)