Skip to content

Commit 3d93e09

Browse files
Add support for hub specific IHubProtocols that don't affect other hubs (#15177)
1 parent 99e79a0 commit 3d93e09

13 files changed

+367
-31
lines changed

src/Components/Server/src/BlazorPack/BlazorPackHubProtocol.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,15 @@
1212
using Microsoft.AspNetCore.Connections;
1313
using Microsoft.AspNetCore.Internal;
1414
using Microsoft.AspNetCore.SignalR;
15+
using Microsoft.AspNetCore.SignalR.Internal;
1516
using Microsoft.AspNetCore.SignalR.Protocol;
1617

1718
namespace Microsoft.AspNetCore.Components.Server.BlazorPack
1819
{
1920
/// <summary>
2021
/// Implements the SignalR Hub Protocol using MessagePack with limited type support.
2122
/// </summary>
23+
[NonDefaultHubProtocol]
2224
internal sealed class BlazorPackHubProtocol : IHubProtocol
2325
{
2426
internal const string ProtocolName = "blazorpack";
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
// Copyright (c) .NET Foundation. All rights reserved.
2+
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
3+
4+
using System;
5+
6+
namespace Microsoft.AspNetCore.SignalR.Internal
7+
{
8+
// Tells SignalR not to add the IHubProtocol with this attribute to all hubs by default
9+
[AttributeUsage(AttributeTargets.Class, AllowMultiple = false, Inherited = true)]
10+
internal class NonDefaultHubProtocolAttribute : Attribute
11+
{
12+
}
13+
}

src/SignalR/perf/Microbenchmarks/RedisProtocolBenchmark.cs

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,10 @@
66
using System.Collections.Generic;
77
using BenchmarkDotNet.Attributes;
88
using Microsoft.AspNetCore.Connections;
9+
using Microsoft.AspNetCore.SignalR.Internal;
910
using Microsoft.AspNetCore.SignalR.Protocol;
1011
using Microsoft.AspNetCore.SignalR.StackExchangeRedis.Internal;
12+
using Microsoft.Extensions.Logging.Abstractions;
1113

1214
namespace Microsoft.AspNetCore.SignalR.Microbenchmarks
1315
{
@@ -28,10 +30,10 @@ public class RedisProtocolBenchmark
2830
[GlobalSetup]
2931
public void GlobalSetup()
3032
{
31-
_protocol = new RedisProtocol(new [] {
32-
new DummyProtocol("protocol1"),
33-
new DummyProtocol("protocol2")
34-
});
33+
var resolver = new DefaultHubProtocolResolver(new List<IHubProtocol> { new DummyProtocol("protocol1"),
34+
new DummyProtocol("protocol2") }, NullLogger<DefaultHubProtocolResolver>.Instance);
35+
36+
_protocol = new RedisProtocol(new DefaultHubMessageSerializer(resolver, new List<string>() { "protocol1", "protocol2" }, hubSupportedProtocols: null));
3537

3638
_groupCommand = new RedisGroupCommand(id: 42, serverName: "Server", GroupAction.Add, groupName: "group", connectionId: "connection");
3739

@@ -119,7 +121,7 @@ private static IReadOnlyList<string> GenerateIds(int count)
119121
return ids;
120122
}
121123

122-
private class DummyProtocol: IHubProtocol
124+
private class DummyProtocol : IHubProtocol
123125
{
124126
private static readonly byte[] _fixedOutput = new byte[] { 0x68, 0x68, 0x6C, 0x6C, 0x6F };
125127

src/SignalR/server/Core/src/HubOptionsSetup.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
using System;
55
using System.Collections.Generic;
6+
using System.Linq;
67
using Microsoft.AspNetCore.SignalR.Protocol;
78
using Microsoft.Extensions.Options;
89

@@ -26,6 +27,10 @@ public HubOptionsSetup(IEnumerable<IHubProtocol> protocols)
2627
{
2728
foreach (var hubProtocol in protocols)
2829
{
30+
if (hubProtocol.GetType().CustomAttributes.Where(a => a.AttributeType.FullName == "Microsoft.AspNetCore.SignalR.Internal.NonDefaultHubProtocolAttribute").Any())
31+
{
32+
continue;
33+
}
2934
_defaultProtocols.Add(hubProtocol.Name);
3035
}
3136
}

src/SignalR/server/Core/src/HubOptionsSetup`T.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ public HubOptionsSetup(IOptions<HubOptions> options)
1616

1717
public void Configure(HubOptions<THub> options)
1818
{
19+
// Do a deep copy, otherwise users modifying the HubOptions<THub> list would be changing the global options list
1920
options.SupportedProtocols = new List<string>(_hubOptions.SupportedProtocols.Count);
2021
foreach (var protocol in _hubOptions.SupportedProtocols)
2122
{

src/SignalR/server/Core/src/SerializedHubMessage.cs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ public class SerializedHubMessage
1414
{
1515
private SerializedMessage _cachedItem1;
1616
private SerializedMessage _cachedItem2;
17-
private IList<SerializedMessage> _cachedItems;
17+
private List<SerializedMessage> _cachedItems;
1818
private readonly object _lock = new object();
1919

2020
/// <summary>
@@ -32,7 +32,7 @@ public SerializedHubMessage(IReadOnlyList<SerializedMessage> messages)
3232
for (var i = 0; i < messages.Count; i++)
3333
{
3434
var message = messages[i];
35-
SetCache(message.ProtocolName, message.Serialized);
35+
SetCacheUnsynchronized(message.ProtocolName, message.Serialized);
3636
}
3737
}
3838

@@ -54,7 +54,7 @@ public ReadOnlyMemory<byte> GetSerializedMessage(IHubProtocol protocol)
5454
{
5555
lock (_lock)
5656
{
57-
if (!TryGetCached(protocol.Name, out var serialized))
57+
if (!TryGetCachedUnsynchronized(protocol.Name, out var serialized))
5858
{
5959
if (Message == null)
6060
{
@@ -63,7 +63,7 @@ public ReadOnlyMemory<byte> GetSerializedMessage(IHubProtocol protocol)
6363
}
6464

6565
serialized = protocol.GetMessageBytes(Message);
66-
SetCache(protocol.Name, serialized);
66+
SetCacheUnsynchronized(protocol.Name, serialized);
6767
}
6868

6969
return serialized;
@@ -98,7 +98,7 @@ internal IReadOnlyList<SerializedMessage> GetAllSerializations()
9898
}
9999
}
100100

101-
private void SetCache(string protocolName, ReadOnlyMemory<byte> serialized)
101+
private void SetCacheUnsynchronized(string protocolName, ReadOnlyMemory<byte> serialized)
102102
{
103103
// We set the fields before moving on to the list, if we need it to hold more than 2 items.
104104
// We have to read/write these fields under the lock because the structs might tear and another
@@ -132,7 +132,7 @@ private void SetCache(string protocolName, ReadOnlyMemory<byte> serialized)
132132
}
133133
}
134134

135-
private bool TryGetCached(string protocolName, out ReadOnlyMemory<byte> result)
135+
private bool TryGetCachedUnsynchronized(string protocolName, out ReadOnlyMemory<byte> result)
136136
{
137137
if (string.Equals(_cachedItem1.ProtocolName, protocolName, StringComparison.Ordinal))
138138
{

src/SignalR/server/SignalR/test/AddSignalRTests.cs

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,15 @@
22
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
33

44
using System;
5+
using System.Buffers;
56
using System.Collections.Generic;
67
using System.Threading;
78
using System.Threading.Tasks;
9+
using Microsoft.AspNetCore.Connections;
810
using Microsoft.AspNetCore.SignalR.Internal;
911
using Microsoft.AspNetCore.SignalR.Protocol;
1012
using Microsoft.Extensions.DependencyInjection;
13+
using Microsoft.Extensions.DependencyInjection.Extensions;
1114
using Microsoft.Extensions.Options;
1215
using Xunit;
1316

@@ -148,6 +151,30 @@ public void UserSpecifiedOptionsRunAfterDefaultOptions()
148151
Assert.Null(globalOptions.SupportedProtocols);
149152
Assert.Equal(TimeSpan.FromSeconds(1), globalOptions.ClientTimeoutInterval);
150153
}
154+
155+
[Fact]
156+
public void HubProtocolsWithNonDefaultAttributeNotAddedToSupportedProtocols()
157+
{
158+
var serviceCollection = new ServiceCollection();
159+
160+
serviceCollection.AddSignalR().AddHubOptions<CustomHub>(options =>
161+
{
162+
});
163+
164+
serviceCollection.TryAddEnumerable(ServiceDescriptor.Singleton<IHubProtocol, CustomHubProtocol>());
165+
serviceCollection.TryAddEnumerable(ServiceDescriptor.Singleton<IHubProtocol, MessagePackHubProtocol>());
166+
167+
var serviceProvider = serviceCollection.BuildServiceProvider();
168+
Assert.Collection(serviceProvider.GetRequiredService<IOptions<HubOptions<CustomHub>>>().Value.SupportedProtocols,
169+
p =>
170+
{
171+
Assert.Equal("json", p);
172+
},
173+
p =>
174+
{
175+
Assert.Equal("messagepack", p);
176+
});
177+
}
151178
}
152179

153180
public class CustomHub : Hub
@@ -276,4 +303,42 @@ public override Task SendUsersAsync(IReadOnlyList<string> userIds, string method
276303
throw new System.NotImplementedException();
277304
}
278305
}
306+
307+
[NonDefaultHubProtocol]
308+
internal class CustomHubProtocol : IHubProtocol
309+
{
310+
public string Name => "custom";
311+
312+
public int Version => throw new NotImplementedException();
313+
314+
public TransferFormat TransferFormat => throw new NotImplementedException();
315+
316+
public ReadOnlyMemory<byte> GetMessageBytes(HubMessage message)
317+
{
318+
throw new NotImplementedException();
319+
}
320+
321+
public bool IsVersionSupported(int version)
322+
{
323+
throw new NotImplementedException();
324+
}
325+
326+
public bool TryParseMessage(ref ReadOnlySequence<byte> input, IInvocationBinder binder, out HubMessage message)
327+
{
328+
throw new NotImplementedException();
329+
}
330+
331+
public void WriteMessage(HubMessage message, IBufferWriter<byte> output)
332+
{
333+
throw new NotImplementedException();
334+
}
335+
}
336+
}
337+
338+
namespace Microsoft.AspNetCore.SignalR.Internal
339+
{
340+
[AttributeUsage(AttributeTargets.Class, AllowMultiple = false, Inherited = true)]
341+
internal class NonDefaultHubProtocolAttribute : Attribute
342+
{
343+
}
279344
}

src/SignalR/server/StackExchangeRedis/ref/Microsoft.AspNetCore.SignalR.StackExchangeRedis.netcoreapp.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ namespace Microsoft.AspNetCore.SignalR.StackExchangeRedis
66
public partial class RedisHubLifetimeManager<THub> : Microsoft.AspNetCore.SignalR.HubLifetimeManager<THub>, System.IDisposable where THub : Microsoft.AspNetCore.SignalR.Hub
77
{
88
public RedisHubLifetimeManager(Microsoft.Extensions.Logging.ILogger<Microsoft.AspNetCore.SignalR.StackExchangeRedis.RedisHubLifetimeManager<THub>> logger, Microsoft.Extensions.Options.IOptions<Microsoft.AspNetCore.SignalR.StackExchangeRedis.RedisOptions> options, Microsoft.AspNetCore.SignalR.IHubProtocolResolver hubProtocolResolver) { }
9+
public RedisHubLifetimeManager(Microsoft.Extensions.Logging.ILogger<Microsoft.AspNetCore.SignalR.StackExchangeRedis.RedisHubLifetimeManager<THub>> logger, Microsoft.Extensions.Options.IOptions<Microsoft.AspNetCore.SignalR.StackExchangeRedis.RedisOptions> options, Microsoft.AspNetCore.SignalR.IHubProtocolResolver hubProtocolResolver, Microsoft.Extensions.Options.IOptions<Microsoft.AspNetCore.SignalR.HubOptions> globalHubOptions, Microsoft.Extensions.Options.IOptions<Microsoft.AspNetCore.SignalR.HubOptions<THub>> hubOptions) { }
910
public override System.Threading.Tasks.Task AddToGroupAsync(string connectionId, string groupName, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
1011
public void Dispose() { }
1112
[System.Diagnostics.DebuggerStepThroughAttribute]
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
// Copyright (c) .NET Foundation. All rights reserved.
2+
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
3+
4+
using System;
5+
using System.Collections.Generic;
6+
using System.Linq;
7+
using Microsoft.AspNetCore.SignalR.Protocol;
8+
9+
namespace Microsoft.AspNetCore.SignalR.Internal
10+
{
11+
internal class DefaultHubMessageSerializer
12+
{
13+
private readonly List<IHubProtocol> _hubProtocols = new List<IHubProtocol>();
14+
15+
public DefaultHubMessageSerializer(IHubProtocolResolver hubProtocolResolver, IList<string> globalSupportedProtocols, IList<string> hubSupportedProtocols)
16+
{
17+
var supportedProtocols = hubSupportedProtocols ?? globalSupportedProtocols ?? Array.Empty<string>();
18+
foreach (var protocolName in supportedProtocols)
19+
{
20+
var protocol = hubProtocolResolver.GetProtocol(protocolName, (supportedProtocols as IReadOnlyList<string>) ?? supportedProtocols.ToList());
21+
if (protocol != null)
22+
{
23+
_hubProtocols.Add(protocol);
24+
}
25+
}
26+
}
27+
28+
public IReadOnlyList<SerializedMessage> SerializeMessage(HubMessage message)
29+
{
30+
var list = new List<SerializedMessage>(_hubProtocols.Count);
31+
foreach (var protocol in _hubProtocols)
32+
{
33+
list.Add(new SerializedMessage(protocol.Name, protocol.GetMessageBytes(message)));
34+
}
35+
36+
return list;
37+
}
38+
}
39+
}

src/SignalR/server/StackExchangeRedis/src/Internal/RedisProtocol.cs

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,17 +8,18 @@
88
using System.Runtime.InteropServices;
99
using MessagePack;
1010
using Microsoft.AspNetCore.Internal;
11+
using Microsoft.AspNetCore.SignalR.Internal;
1112
using Microsoft.AspNetCore.SignalR.Protocol;
1213

1314
namespace Microsoft.AspNetCore.SignalR.StackExchangeRedis.Internal
1415
{
1516
internal class RedisProtocol
1617
{
17-
private readonly IReadOnlyList<IHubProtocol> _protocols;
18+
private readonly DefaultHubMessageSerializer _messageSerializer;
1819

19-
public RedisProtocol(IReadOnlyList<IHubProtocol> protocols)
20+
public RedisProtocol(DefaultHubMessageSerializer messageSerializer)
2021
{
21-
_protocols = protocols;
22+
_messageSerializer = messageSerializer;
2223
}
2324

2425
// The Redis Protocol:
@@ -60,8 +61,7 @@ public byte[] WriteInvocation(string methodName, object[] args, IReadOnlyList<st
6061
MessagePackBinary.WriteArrayHeader(writer, 0);
6162
}
6263

63-
WriteSerializedHubMessage(writer,
64-
new SerializedHubMessage(new InvocationMessage(methodName, args)));
64+
WriteHubMessage(writer, new InvocationMessage(methodName, args));
6565
return writer.ToArray();
6666
}
6767
finally
@@ -163,19 +163,20 @@ public int ReadAck(ReadOnlyMemory<byte> data)
163163
return MessagePackUtil.ReadInt32(ref data);
164164
}
165165

166-
private void WriteSerializedHubMessage(Stream stream, SerializedHubMessage message)
166+
private void WriteHubMessage(Stream stream, HubMessage message)
167167
{
168168
// Written as a MessagePack 'map' where the keys are the name of the protocol (as a MessagePack 'str')
169169
// and the values are the serialized blob (as a MessagePack 'bin').
170170

171-
MessagePackBinary.WriteMapHeader(stream, _protocols.Count);
171+
var serializedHubMessages = _messageSerializer.SerializeMessage(message);
172172

173-
foreach (var protocol in _protocols)
173+
MessagePackBinary.WriteMapHeader(stream, serializedHubMessages.Count);
174+
175+
foreach (var serializedMessage in serializedHubMessages)
174176
{
175-
MessagePackBinary.WriteString(stream, protocol.Name);
177+
MessagePackBinary.WriteString(stream, serializedMessage.ProtocolName);
176178

177-
var serialized = message.GetSerializedMessage(protocol);
178-
var isArray = MemoryMarshal.TryGetArray(serialized, out var array);
179+
var isArray = MemoryMarshal.TryGetArray(serializedMessage.Serialized, out var array);
179180
Debug.Assert(isArray);
180181
MessagePackBinary.WriteBytes(stream, array.Array, array.Offset, array.Count);
181182
}

src/SignalR/server/StackExchangeRedis/src/RedisHubLifetimeManager.cs

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
using System.Text;
99
using System.Threading;
1010
using System.Threading.Tasks;
11+
using Microsoft.AspNetCore.SignalR.Internal;
1112
using Microsoft.AspNetCore.SignalR.Protocol;
1213
using Microsoft.AspNetCore.SignalR.StackExchangeRedis.Internal;
1314
using Microsoft.Extensions.Logging;
@@ -36,12 +37,29 @@ public class RedisHubLifetimeManager<THub> : HubLifetimeManager<THub>, IDisposab
3637
public RedisHubLifetimeManager(ILogger<RedisHubLifetimeManager<THub>> logger,
3738
IOptions<RedisOptions> options,
3839
IHubProtocolResolver hubProtocolResolver)
40+
: this(logger, options, hubProtocolResolver, globalHubOptions: null, hubOptions: null)
41+
{
42+
}
43+
44+
public RedisHubLifetimeManager(ILogger<RedisHubLifetimeManager<THub>> logger,
45+
IOptions<RedisOptions> options,
46+
IHubProtocolResolver hubProtocolResolver,
47+
IOptions<HubOptions> globalHubOptions,
48+
IOptions<HubOptions<THub>> hubOptions)
3949
{
4050
_logger = logger;
4151
_options = options.Value;
4252
_ackHandler = new AckHandler();
4353
_channels = new RedisChannels(typeof(THub).FullName);
44-
_protocol = new RedisProtocol(hubProtocolResolver.AllProtocols);
54+
if (globalHubOptions != null && hubOptions != null)
55+
{
56+
_protocol = new RedisProtocol(new DefaultHubMessageSerializer(hubProtocolResolver, globalHubOptions.Value.SupportedProtocols, hubOptions.Value.SupportedProtocols));
57+
}
58+
else
59+
{
60+
var supportedProtocols = hubProtocolResolver.AllProtocols.Select(p => p.Name).ToList();
61+
_protocol = new RedisProtocol(new DefaultHubMessageSerializer(hubProtocolResolver, supportedProtocols, null));
62+
}
4563

4664
RedisLog.ConnectingToEndpoints(_logger, options.Value.Configuration.EndPoints, _serverName);
4765
_ = EnsureRedisServerConnection();

0 commit comments

Comments
 (0)