Skip to content

Add Keyed DI support to SignalR #49729

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Aug 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -676,7 +676,7 @@ private void ReplaceArguments(HubMethodDescriptor descriptor, HubMethodInvocatio
}
else if (descriptor.IsServiceArgument(parameterPointer))
{
arguments[parameterPointer] = scope.ServiceProvider.GetRequiredService(descriptor.OriginalParameterTypes[parameterPointer]);
arguments[parameterPointer] = descriptor.GetService(scope.ServiceProvider, parameterPointer, descriptor.OriginalParameterTypes[parameterPointer]);
}
else if (isStreamCall && ReflectionHelper.IsStreamingType(descriptor.OriginalParameterTypes[parameterPointer], mustBeDirectType: true))
{
Expand Down
63 changes: 55 additions & 8 deletions src/SignalR/server/Core/src/Internal/HubMethodDescriptor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -79,18 +79,31 @@ public HubMethodDescriptor(ObjectMethodExecutor methodExecutor, IServiceProvider
HasSyntheticArguments = true;
return false;
}
else if (p.CustomAttributes.Any(a => typeof(IFromServiceMetadata).IsAssignableFrom(a.AttributeType)) ||
serviceProviderIsService?.IsService(GetServiceType(p.ParameterType)) == true)
else if (p.CustomAttributes.Any())
{
if (index >= 64)
foreach (var attribute in p.GetCustomAttributes(true))
{
throw new InvalidOperationException(
"Hub methods can't use services from DI in the parameters after the 64th parameter.");
if (attribute is IFromServiceMetadata)
{
return MarkServiceParameter(index);
}
else if (attribute is FromKeyedServicesAttribute keyedServicesAttribute)
{
if (serviceProviderIsService is IServiceProviderIsKeyedService keyedServiceProvider &&
keyedServiceProvider.IsKeyedService(GetServiceType(p.ParameterType), keyedServicesAttribute.Key))
{
KeyedServiceKeys ??= new List<(int, object)>();
KeyedServiceKeys.Add((index, keyedServicesAttribute.Key));
return MarkServiceParameter(index);
}
}
}
_isServiceArgument |= (1UL << index);
HasSyntheticArguments = true;
return false;
}
else if (serviceProviderIsService?.IsService(GetServiceType(p.ParameterType)) == true)
{
return MarkServiceParameter(index);
}

return true;
}).Select(p => p.ParameterType).ToArray();

Expand All @@ -102,8 +115,22 @@ public HubMethodDescriptor(ObjectMethodExecutor methodExecutor, IServiceProvider
Policies = policies.ToArray();
}

private bool MarkServiceParameter(int index)
{
if (index >= 64)
{
throw new InvalidOperationException(
"Hub methods can't use services from DI in the parameters after the 64th parameter.");
}
_isServiceArgument |= (1UL << index);
HasSyntheticArguments = true;
return false;
}

public List<Type>? StreamingParameters { get; private set; }

public List<(int, object)>? KeyedServiceKeys { get; private set; }

public ObjectMethodExecutor MethodExecutor { get; }

public IReadOnlyList<Type> ParameterTypes { get; }
Expand All @@ -125,6 +152,26 @@ public bool IsServiceArgument(int argumentIndex)
return (_isServiceArgument & (1UL << argumentIndex)) != 0;
}

public object GetService(IServiceProvider serviceProvider, int index, Type parameterType)
{
if (KeyedServiceKeys is not null)
{
foreach (var (paramIndex, key) in KeyedServiceKeys)
{
if (paramIndex == index)
{
return serviceProvider.GetRequiredKeyedService(parameterType, key);
}
else if (paramIndex > index)
{
break;
}
}
}

return serviceProvider.GetRequiredService(parameterType);
}

public IAsyncEnumerator<object> FromReturnedStream(object stream, CancellationToken cancellationToken)
{
// there is the potential for compile to be called times but this has no harmful effect other than perf
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
using System.Threading.Channels;
using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.Http.Metadata;
using Microsoft.Extensions.DependencyInjection;
using Newtonsoft.Json.Serialization;

namespace Microsoft.AspNetCore.SignalR.Tests;
Expand Down Expand Up @@ -1362,6 +1363,12 @@ public async Task<int> ServicesAndParams(int value, [FromService] Service1 servi
return total + value;
}

public int MultipleSameKeyedServices([FromKeyedServices("service1")] Service1 service, [FromKeyedServices("service1")] Service1 service2)
{
Assert.Same(service, service2);
return 445;
}

public int ServiceWithoutAttribute(Service1 service)
{
return 1;
Expand All @@ -1384,6 +1391,27 @@ public async Task Stream(ChannelReader<int> channelReader)
await channelReader.ReadAsync();
}
}

public int KeyedService([FromKeyedServices("service1")] Service1 service)
{
return 43;
}

public int KeyedServiceWithParam(int input, [FromKeyedServices("service1")] Service1 service)
{
return 13 * input;
}

public int KeyedServiceNonKeyedService(Service2 service2, [FromKeyedServices("service1")] Service1 service)
{
return 11;
}

public int MultipleKeyedServices([FromKeyedServices("service1")] Service1 service, [FromKeyedServices("service2")] Service1 service2)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Test case for two services with the same key?

{
Assert.NotEqual(service, service2);
return 45;
}
}

public class TooManyParamsHub : Hub
Expand Down
132 changes: 132 additions & 0 deletions src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4915,6 +4915,138 @@ public async Task ServiceNotResolvedForIEnumerableParameterIfNotInDI()
}
}

[Fact]
public async Task KeyedServiceResolvedIfInDI()
{
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(provider =>
{
provider.AddSignalR(options =>
{
options.EnableDetailedErrors = true;
});

provider.AddKeyedScoped<Service1>("service1");
});
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<ServicesHub>>();

using (var client = new TestClient())
{
var connectionHandlerTask = await client.ConnectAsync(connectionHandler).DefaultTimeout();
var res = await client.InvokeAsync(nameof(ServicesHub.KeyedService)).DefaultTimeout();
Assert.Equal(43L, res.Result);
}
}

[Fact]
public async Task HubMethodCanInjectKeyedServiceWithOtherParameters()
{
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(provider =>
{
provider.AddSignalR(options =>
{
options.EnableDetailedErrors = true;
});

provider.AddKeyedScoped<Service1>("service1");
});
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<ServicesHub>>();

using (var client = new TestClient())
{
var connectionHandlerTask = await client.ConnectAsync(connectionHandler).DefaultTimeout();
var res = await client.InvokeAsync(nameof(ServicesHub.KeyedServiceWithParam), 91).DefaultTimeout();
Assert.Equal(1183L, res.Result);
}
}

[Fact]
public async Task HubMethodCanInjectKeyedServiceWithNonKeyedService()
{
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(provider =>
{
provider.AddSignalR(options =>
{
options.EnableDetailedErrors = true;
});

provider.AddKeyedScoped<Service1>("service1");
provider.AddScoped<Service2>();
});
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<ServicesHub>>();

using (var client = new TestClient())
{
var connectionHandlerTask = await client.ConnectAsync(connectionHandler).DefaultTimeout();
var res = await client.InvokeAsync(nameof(ServicesHub.KeyedServiceNonKeyedService)).DefaultTimeout();
Assert.Equal(11L, res.Result);
}
}

[Fact]
public async Task MultipleKeyedServicesResolved()
{
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(provider =>
{
provider.AddSignalR(options =>
{
options.EnableDetailedErrors = true;
});

provider.AddKeyedScoped<Service1>("service1");
provider.AddKeyedScoped<Service1>("service2");
});
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<ServicesHub>>();

using (var client = new TestClient())
{
var connectionHandlerTask = await client.ConnectAsync(connectionHandler).DefaultTimeout();
var res = await client.InvokeAsync(nameof(ServicesHub.MultipleKeyedServices)).DefaultTimeout();
Assert.Equal(45L, res.Result);
}
}

[Fact]
public async Task MultipleKeyedServicesWithSameNameResolved()
{
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(provider =>
{
provider.AddSignalR(options =>
{
options.EnableDetailedErrors = true;
});

provider.AddKeyedScoped<Service1>("service1");
});
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<ServicesHub>>();

using (var client = new TestClient())
{
var connectionHandlerTask = await client.ConnectAsync(connectionHandler).DefaultTimeout();
var res = await client.InvokeAsync(nameof(ServicesHub.MultipleSameKeyedServices)).DefaultTimeout();
Assert.Equal(445L, res.Result);
}
}

[Fact]
public async Task KeyedServiceNotResolvedIfNotInDI()
{
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(provider =>
{
provider.AddSignalR(options =>
{
options.EnableDetailedErrors = true;
});
});
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<ServicesHub>>();

using (var client = new TestClient())
{
var connectionHandlerTask = await client.ConnectAsync(connectionHandler).DefaultTimeout();
var res = await client.InvokeAsync(nameof(ServicesHub.KeyedService)).DefaultTimeout();
Assert.Equal("Failed to invoke 'KeyedService' due to an error on the server. InvalidDataException: Invocation provides 0 argument(s) but target expects 1.", res.Error);
}
}

[Fact]
public void TooManyParametersWithServiceThrows()
{
Expand Down