Skip to content

Add support for hub specific IHubProtocols that don't affect other hubs #15177

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Oct 22, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,15 @@
using Microsoft.AspNetCore.Connections;
using Microsoft.AspNetCore.Internal;
using Microsoft.AspNetCore.SignalR;
using Microsoft.AspNetCore.SignalR.Internal;
using Microsoft.AspNetCore.SignalR.Protocol;

namespace Microsoft.AspNetCore.Components.Server.BlazorPack
{
/// <summary>
/// Implements the SignalR Hub Protocol using MessagePack with limited type support.
/// </summary>
[NonDefaultHubProtocol]
internal sealed class BlazorPackHubProtocol : IHubProtocol
{
internal const string ProtocolName = "blazorpack";
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;

namespace Microsoft.AspNetCore.SignalR.Internal
{
// Tells SignalR not to add the IHubProtocol with this attribute to all hubs by default
[AttributeUsage(AttributeTargets.Class, AllowMultiple = false, Inherited = true)]
internal class NonDefaultHubProtocolAttribute : Attribute
{
}
}
12 changes: 7 additions & 5 deletions src/SignalR/perf/Microbenchmarks/RedisProtocolBenchmark.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@
using System.Collections.Generic;
using BenchmarkDotNet.Attributes;
using Microsoft.AspNetCore.Connections;
using Microsoft.AspNetCore.SignalR.Internal;
using Microsoft.AspNetCore.SignalR.Protocol;
using Microsoft.AspNetCore.SignalR.StackExchangeRedis.Internal;
using Microsoft.Extensions.Logging.Abstractions;

namespace Microsoft.AspNetCore.SignalR.Microbenchmarks
{
Expand All @@ -28,10 +30,10 @@ public class RedisProtocolBenchmark
[GlobalSetup]
public void GlobalSetup()
{
_protocol = new RedisProtocol(new [] {
new DummyProtocol("protocol1"),
new DummyProtocol("protocol2")
});
var resolver = new DefaultHubProtocolResolver(new List<IHubProtocol> { new DummyProtocol("protocol1"),
new DummyProtocol("protocol2") }, NullLogger<DefaultHubProtocolResolver>.Instance);

_protocol = new RedisProtocol(new DefaultHubMessageSerializer(resolver, new List<string>() { "protocol1", "protocol2" }, hubSupportedProtocols: null));

_groupCommand = new RedisGroupCommand(id: 42, serverName: "Server", GroupAction.Add, groupName: "group", connectionId: "connection");

Expand Down Expand Up @@ -119,7 +121,7 @@ private static IReadOnlyList<string> GenerateIds(int count)
return ids;
}

private class DummyProtocol: IHubProtocol
private class DummyProtocol : IHubProtocol
{
private static readonly byte[] _fixedOutput = new byte[] { 0x68, 0x68, 0x6C, 0x6C, 0x6F };

Expand Down
5 changes: 5 additions & 0 deletions src/SignalR/server/Core/src/HubOptionsSetup.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.AspNetCore.SignalR.Protocol;
using Microsoft.Extensions.Options;

Expand All @@ -26,6 +27,10 @@ public HubOptionsSetup(IEnumerable<IHubProtocol> protocols)
{
foreach (var hubProtocol in protocols)
{
if (hubProtocol.GetType().CustomAttributes.Where(a => a.AttributeType.FullName == "Microsoft.AspNetCore.SignalR.Internal.NonDefaultHubProtocolAttribute").Any())
{
continue;
}
_defaultProtocols.Add(hubProtocol.Name);
}
}
Expand Down
1 change: 1 addition & 0 deletions src/SignalR/server/Core/src/HubOptionsSetup`T.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ public HubOptionsSetup(IOptions<HubOptions> options)

public void Configure(HubOptions<THub> options)
{
// Do a deep copy, otherwise users modifying the HubOptions<THub> list would be changing the global options list
options.SupportedProtocols = new List<string>(_hubOptions.SupportedProtocols.Count);
foreach (var protocol in _hubOptions.SupportedProtocols)
{
Expand Down
12 changes: 6 additions & 6 deletions src/SignalR/server/Core/src/SerializedHubMessage.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ public class SerializedHubMessage
{
private SerializedMessage _cachedItem1;
private SerializedMessage _cachedItem2;
private IList<SerializedMessage> _cachedItems;
private List<SerializedMessage> _cachedItems;
private readonly object _lock = new object();

/// <summary>
Expand All @@ -32,7 +32,7 @@ public SerializedHubMessage(IReadOnlyList<SerializedMessage> messages)
for (var i = 0; i < messages.Count; i++)
{
var message = messages[i];
SetCache(message.ProtocolName, message.Serialized);
SetCacheUnsynchronized(message.ProtocolName, message.Serialized);
}
}

Expand All @@ -54,7 +54,7 @@ public ReadOnlyMemory<byte> GetSerializedMessage(IHubProtocol protocol)
{
lock (_lock)
{
if (!TryGetCached(protocol.Name, out var serialized))
if (!TryGetCachedUnsynchronized(protocol.Name, out var serialized))
{
if (Message == null)
{
Expand All @@ -63,7 +63,7 @@ public ReadOnlyMemory<byte> GetSerializedMessage(IHubProtocol protocol)
}

serialized = protocol.GetMessageBytes(Message);
SetCache(protocol.Name, serialized);
SetCacheUnsynchronized(protocol.Name, serialized);
}

return serialized;
Expand Down Expand Up @@ -98,7 +98,7 @@ internal IReadOnlyList<SerializedMessage> GetAllSerializations()
}
}

private void SetCache(string protocolName, ReadOnlyMemory<byte> serialized)
private void SetCacheUnsynchronized(string protocolName, ReadOnlyMemory<byte> serialized)
{
// We set the fields before moving on to the list, if we need it to hold more than 2 items.
// We have to read/write these fields under the lock because the structs might tear and another
Expand Down Expand Up @@ -132,7 +132,7 @@ private void SetCache(string protocolName, ReadOnlyMemory<byte> serialized)
}
}

private bool TryGetCached(string protocolName, out ReadOnlyMemory<byte> result)
private bool TryGetCachedUnsynchronized(string protocolName, out ReadOnlyMemory<byte> result)
{
if (string.Equals(_cachedItem1.ProtocolName, protocolName, StringComparison.Ordinal))
{
Expand Down
65 changes: 65 additions & 0 deletions src/SignalR/server/SignalR/test/AddSignalRTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,15 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Buffers;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Connections;
using Microsoft.AspNetCore.SignalR.Internal;
using Microsoft.AspNetCore.SignalR.Protocol;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.DependencyInjection.Extensions;
using Microsoft.Extensions.Options;
using Xunit;

Expand Down Expand Up @@ -148,6 +151,30 @@ public void UserSpecifiedOptionsRunAfterDefaultOptions()
Assert.Null(globalOptions.SupportedProtocols);
Assert.Equal(TimeSpan.FromSeconds(1), globalOptions.ClientTimeoutInterval);
}

[Fact]
public void HubProtocolsWithNonDefaultAttributeNotAddedToSupportedProtocols()
{
var serviceCollection = new ServiceCollection();

serviceCollection.AddSignalR().AddHubOptions<CustomHub>(options =>
{
});

serviceCollection.TryAddEnumerable(ServiceDescriptor.Singleton<IHubProtocol, CustomHubProtocol>());
serviceCollection.TryAddEnumerable(ServiceDescriptor.Singleton<IHubProtocol, MessagePackHubProtocol>());

var serviceProvider = serviceCollection.BuildServiceProvider();
Assert.Collection(serviceProvider.GetRequiredService<IOptions<HubOptions<CustomHub>>>().Value.SupportedProtocols,
p =>
{
Assert.Equal("json", p);
},
p =>
{
Assert.Equal("messagepack", p);
});
}
}

public class CustomHub : Hub
Expand Down Expand Up @@ -276,4 +303,42 @@ public override Task SendUsersAsync(IReadOnlyList<string> userIds, string method
throw new System.NotImplementedException();
}
}

[NonDefaultHubProtocol]
internal class CustomHubProtocol : IHubProtocol
{
public string Name => "custom";

public int Version => throw new NotImplementedException();

public TransferFormat TransferFormat => throw new NotImplementedException();

public ReadOnlyMemory<byte> GetMessageBytes(HubMessage message)
{
throw new NotImplementedException();
}

public bool IsVersionSupported(int version)
{
throw new NotImplementedException();
}

public bool TryParseMessage(ref ReadOnlySequence<byte> input, IInvocationBinder binder, out HubMessage message)
{
throw new NotImplementedException();
}

public void WriteMessage(HubMessage message, IBufferWriter<byte> output)
{
throw new NotImplementedException();
}
}
}

namespace Microsoft.AspNetCore.SignalR.Internal
{
[AttributeUsage(AttributeTargets.Class, AllowMultiple = false, Inherited = true)]
internal class NonDefaultHubProtocolAttribute : Attribute
{
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ namespace Microsoft.AspNetCore.SignalR.StackExchangeRedis
public partial class RedisHubLifetimeManager<THub> : Microsoft.AspNetCore.SignalR.HubLifetimeManager<THub>, System.IDisposable where THub : Microsoft.AspNetCore.SignalR.Hub
{
public RedisHubLifetimeManager(Microsoft.Extensions.Logging.ILogger<Microsoft.AspNetCore.SignalR.StackExchangeRedis.RedisHubLifetimeManager<THub>> logger, Microsoft.Extensions.Options.IOptions<Microsoft.AspNetCore.SignalR.StackExchangeRedis.RedisOptions> options, Microsoft.AspNetCore.SignalR.IHubProtocolResolver hubProtocolResolver) { }
public RedisHubLifetimeManager(Microsoft.Extensions.Logging.ILogger<Microsoft.AspNetCore.SignalR.StackExchangeRedis.RedisHubLifetimeManager<THub>> logger, Microsoft.Extensions.Options.IOptions<Microsoft.AspNetCore.SignalR.StackExchangeRedis.RedisOptions> options, Microsoft.AspNetCore.SignalR.IHubProtocolResolver hubProtocolResolver, Microsoft.Extensions.Options.IOptions<Microsoft.AspNetCore.SignalR.HubOptions> globalHubOptions, Microsoft.Extensions.Options.IOptions<Microsoft.AspNetCore.SignalR.HubOptions<THub>> hubOptions) { }
public override System.Threading.Tasks.Task AddToGroupAsync(string connectionId, string groupName, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
public void Dispose() { }
[System.Diagnostics.DebuggerStepThroughAttribute]
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.AspNetCore.SignalR.Protocol;

namespace Microsoft.AspNetCore.SignalR.Internal
{
internal class DefaultHubMessageSerializer
{
private readonly List<IHubProtocol> _hubProtocols = new List<IHubProtocol>();

public DefaultHubMessageSerializer(IHubProtocolResolver hubProtocolResolver, IList<string> globalSupportedProtocols, IList<string> hubSupportedProtocols)
{
var supportedProtocols = hubSupportedProtocols ?? globalSupportedProtocols ?? Array.Empty<string>();
foreach (var protocolName in supportedProtocols)
{
var protocol = hubProtocolResolver.GetProtocol(protocolName, (supportedProtocols as IReadOnlyList<string>) ?? supportedProtocols.ToList());
if (protocol != null)
{
_hubProtocols.Add(protocol);
}
}
}

public IReadOnlyList<SerializedMessage> SerializeMessage(HubMessage message)
{
var list = new List<SerializedMessage>(_hubProtocols.Count);
foreach (var protocol in _hubProtocols)
{
list.Add(new SerializedMessage(protocol.Name, protocol.GetMessageBytes(message)));
}

return list;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,18 @@
using System.Runtime.InteropServices;
using MessagePack;
using Microsoft.AspNetCore.Internal;
using Microsoft.AspNetCore.SignalR.Internal;
using Microsoft.AspNetCore.SignalR.Protocol;

namespace Microsoft.AspNetCore.SignalR.StackExchangeRedis.Internal
{
internal class RedisProtocol
{
private readonly IReadOnlyList<IHubProtocol> _protocols;
private readonly DefaultHubMessageSerializer _messageSerializer;

public RedisProtocol(IReadOnlyList<IHubProtocol> protocols)
public RedisProtocol(DefaultHubMessageSerializer messageSerializer)
{
_protocols = protocols;
_messageSerializer = messageSerializer;
}

// The Redis Protocol:
Expand Down Expand Up @@ -60,8 +61,7 @@ public byte[] WriteInvocation(string methodName, object[] args, IReadOnlyList<st
MessagePackBinary.WriteArrayHeader(writer, 0);
}

WriteSerializedHubMessage(writer,
new SerializedHubMessage(new InvocationMessage(methodName, args)));
WriteHubMessage(writer, new InvocationMessage(methodName, args));
return writer.ToArray();
}
finally
Expand Down Expand Up @@ -163,19 +163,20 @@ public int ReadAck(ReadOnlyMemory<byte> data)
return MessagePackUtil.ReadInt32(ref data);
}

private void WriteSerializedHubMessage(Stream stream, SerializedHubMessage message)
private void WriteHubMessage(Stream stream, HubMessage message)
{
// Written as a MessagePack 'map' where the keys are the name of the protocol (as a MessagePack 'str')
// and the values are the serialized blob (as a MessagePack 'bin').

MessagePackBinary.WriteMapHeader(stream, _protocols.Count);
var serializedHubMessages = _messageSerializer.SerializeMessage(message);

foreach (var protocol in _protocols)
MessagePackBinary.WriteMapHeader(stream, serializedHubMessages.Count);

foreach (var serializedMessage in serializedHubMessages)
{
MessagePackBinary.WriteString(stream, protocol.Name);
MessagePackBinary.WriteString(stream, serializedMessage.ProtocolName);

var serialized = message.GetSerializedMessage(protocol);
var isArray = MemoryMarshal.TryGetArray(serialized, out var array);
var isArray = MemoryMarshal.TryGetArray(serializedMessage.Serialized, out var array);
Debug.Assert(isArray);
MessagePackBinary.WriteBytes(stream, array.Array, array.Offset, array.Count);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.SignalR.Internal;
using Microsoft.AspNetCore.SignalR.Protocol;
using Microsoft.AspNetCore.SignalR.StackExchangeRedis.Internal;
using Microsoft.Extensions.Logging;
Expand Down Expand Up @@ -36,12 +37,29 @@ public class RedisHubLifetimeManager<THub> : HubLifetimeManager<THub>, IDisposab
public RedisHubLifetimeManager(ILogger<RedisHubLifetimeManager<THub>> logger,
IOptions<RedisOptions> options,
IHubProtocolResolver hubProtocolResolver)
: this(logger, options, hubProtocolResolver, globalHubOptions: null, hubOptions: null)
{
}

public RedisHubLifetimeManager(ILogger<RedisHubLifetimeManager<THub>> logger,
IOptions<RedisOptions> options,
IHubProtocolResolver hubProtocolResolver,
IOptions<HubOptions> globalHubOptions,
IOptions<HubOptions<THub>> hubOptions)
{
_logger = logger;
_options = options.Value;
_ackHandler = new AckHandler();
_channels = new RedisChannels(typeof(THub).FullName);
_protocol = new RedisProtocol(hubProtocolResolver.AllProtocols);
if (globalHubOptions != null && hubOptions != null)
{
_protocol = new RedisProtocol(new DefaultHubMessageSerializer(hubProtocolResolver, globalHubOptions.Value.SupportedProtocols, hubOptions.Value.SupportedProtocols));
}
else
{
var supportedProtocols = hubProtocolResolver.AllProtocols.Select(p => p.Name).ToList();
_protocol = new RedisProtocol(new DefaultHubMessageSerializer(hubProtocolResolver, supportedProtocols, null));
}

RedisLog.ConnectingToEndpoints(_logger, options.Value.Configuration.EndPoints, _serverName);
_ = EnsureRedisServerConnection();
Expand Down
Loading