Skip to content

Commit 524b35f

Browse files
Add Keyed DI support to SignalR (#49729)
1 parent dacb224 commit 524b35f

File tree

4 files changed

+216
-9
lines changed

4 files changed

+216
-9
lines changed

src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -676,7 +676,7 @@ private void ReplaceArguments(HubMethodDescriptor descriptor, HubMethodInvocatio
676676
}
677677
else if (descriptor.IsServiceArgument(parameterPointer))
678678
{
679-
arguments[parameterPointer] = scope.ServiceProvider.GetRequiredService(descriptor.OriginalParameterTypes[parameterPointer]);
679+
arguments[parameterPointer] = descriptor.GetService(scope.ServiceProvider, parameterPointer, descriptor.OriginalParameterTypes[parameterPointer]);
680680
}
681681
else if (isStreamCall && ReflectionHelper.IsStreamingType(descriptor.OriginalParameterTypes[parameterPointer], mustBeDirectType: true))
682682
{

src/SignalR/server/Core/src/Internal/HubMethodDescriptor.cs

Lines changed: 55 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -79,18 +79,31 @@ public HubMethodDescriptor(ObjectMethodExecutor methodExecutor, IServiceProvider
7979
HasSyntheticArguments = true;
8080
return false;
8181
}
82-
else if (p.CustomAttributes.Any(a => typeof(IFromServiceMetadata).IsAssignableFrom(a.AttributeType)) ||
83-
serviceProviderIsService?.IsService(GetServiceType(p.ParameterType)) == true)
82+
else if (p.CustomAttributes.Any())
8483
{
85-
if (index >= 64)
84+
foreach (var attribute in p.GetCustomAttributes(true))
8685
{
87-
throw new InvalidOperationException(
88-
"Hub methods can't use services from DI in the parameters after the 64th parameter.");
86+
if (attribute is IFromServiceMetadata)
87+
{
88+
return MarkServiceParameter(index);
89+
}
90+
else if (attribute is FromKeyedServicesAttribute keyedServicesAttribute)
91+
{
92+
if (serviceProviderIsService is IServiceProviderIsKeyedService keyedServiceProvider &&
93+
keyedServiceProvider.IsKeyedService(GetServiceType(p.ParameterType), keyedServicesAttribute.Key))
94+
{
95+
KeyedServiceKeys ??= new List<(int, object)>();
96+
KeyedServiceKeys.Add((index, keyedServicesAttribute.Key));
97+
return MarkServiceParameter(index);
98+
}
99+
}
89100
}
90-
_isServiceArgument |= (1UL << index);
91-
HasSyntheticArguments = true;
92-
return false;
93101
}
102+
else if (serviceProviderIsService?.IsService(GetServiceType(p.ParameterType)) == true)
103+
{
104+
return MarkServiceParameter(index);
105+
}
106+
94107
return true;
95108
}).Select(p => p.ParameterType).ToArray();
96109

@@ -102,8 +115,22 @@ public HubMethodDescriptor(ObjectMethodExecutor methodExecutor, IServiceProvider
102115
Policies = policies.ToArray();
103116
}
104117

118+
private bool MarkServiceParameter(int index)
119+
{
120+
if (index >= 64)
121+
{
122+
throw new InvalidOperationException(
123+
"Hub methods can't use services from DI in the parameters after the 64th parameter.");
124+
}
125+
_isServiceArgument |= (1UL << index);
126+
HasSyntheticArguments = true;
127+
return false;
128+
}
129+
105130
public List<Type>? StreamingParameters { get; private set; }
106131

132+
public List<(int, object)>? KeyedServiceKeys { get; private set; }
133+
107134
public ObjectMethodExecutor MethodExecutor { get; }
108135

109136
public IReadOnlyList<Type> ParameterTypes { get; }
@@ -125,6 +152,26 @@ public bool IsServiceArgument(int argumentIndex)
125152
return (_isServiceArgument & (1UL << argumentIndex)) != 0;
126153
}
127154

155+
public object GetService(IServiceProvider serviceProvider, int index, Type parameterType)
156+
{
157+
if (KeyedServiceKeys is not null)
158+
{
159+
foreach (var (paramIndex, key) in KeyedServiceKeys)
160+
{
161+
if (paramIndex == index)
162+
{
163+
return serviceProvider.GetRequiredKeyedService(parameterType, key);
164+
}
165+
else if (paramIndex > index)
166+
{
167+
break;
168+
}
169+
}
170+
}
171+
172+
return serviceProvider.GetRequiredService(parameterType);
173+
}
174+
128175
public IAsyncEnumerator<object> FromReturnedStream(object stream, CancellationToken cancellationToken)
129176
{
130177
// there is the potential for compile to be called times but this has no harmful effect other than perf

src/SignalR/server/SignalR/test/HubConnectionHandlerTestUtils/Hubs.cs

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
using System.Threading.Channels;
88
using Microsoft.AspNetCore.Authorization;
99
using Microsoft.AspNetCore.Http.Metadata;
10+
using Microsoft.Extensions.DependencyInjection;
1011
using Newtonsoft.Json.Serialization;
1112

1213
namespace Microsoft.AspNetCore.SignalR.Tests;
@@ -1362,6 +1363,12 @@ public async Task<int> ServicesAndParams(int value, [FromService] Service1 servi
13621363
return total + value;
13631364
}
13641365

1366+
public int MultipleSameKeyedServices([FromKeyedServices("service1")] Service1 service, [FromKeyedServices("service1")] Service1 service2)
1367+
{
1368+
Assert.Same(service, service2);
1369+
return 445;
1370+
}
1371+
13651372
public int ServiceWithoutAttribute(Service1 service)
13661373
{
13671374
return 1;
@@ -1384,6 +1391,27 @@ public async Task Stream(ChannelReader<int> channelReader)
13841391
await channelReader.ReadAsync();
13851392
}
13861393
}
1394+
1395+
public int KeyedService([FromKeyedServices("service1")] Service1 service)
1396+
{
1397+
return 43;
1398+
}
1399+
1400+
public int KeyedServiceWithParam(int input, [FromKeyedServices("service1")] Service1 service)
1401+
{
1402+
return 13 * input;
1403+
}
1404+
1405+
public int KeyedServiceNonKeyedService(Service2 service2, [FromKeyedServices("service1")] Service1 service)
1406+
{
1407+
return 11;
1408+
}
1409+
1410+
public int MultipleKeyedServices([FromKeyedServices("service1")] Service1 service, [FromKeyedServices("service2")] Service1 service2)
1411+
{
1412+
Assert.NotEqual(service, service2);
1413+
return 45;
1414+
}
13871415
}
13881416

13891417
public class TooManyParamsHub : Hub

src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4915,6 +4915,138 @@ public async Task ServiceNotResolvedForIEnumerableParameterIfNotInDI()
49154915
}
49164916
}
49174917

4918+
[Fact]
4919+
public async Task KeyedServiceResolvedIfInDI()
4920+
{
4921+
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(provider =>
4922+
{
4923+
provider.AddSignalR(options =>
4924+
{
4925+
options.EnableDetailedErrors = true;
4926+
});
4927+
4928+
provider.AddKeyedScoped<Service1>("service1");
4929+
});
4930+
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<ServicesHub>>();
4931+
4932+
using (var client = new TestClient())
4933+
{
4934+
var connectionHandlerTask = await client.ConnectAsync(connectionHandler).DefaultTimeout();
4935+
var res = await client.InvokeAsync(nameof(ServicesHub.KeyedService)).DefaultTimeout();
4936+
Assert.Equal(43L, res.Result);
4937+
}
4938+
}
4939+
4940+
[Fact]
4941+
public async Task HubMethodCanInjectKeyedServiceWithOtherParameters()
4942+
{
4943+
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(provider =>
4944+
{
4945+
provider.AddSignalR(options =>
4946+
{
4947+
options.EnableDetailedErrors = true;
4948+
});
4949+
4950+
provider.AddKeyedScoped<Service1>("service1");
4951+
});
4952+
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<ServicesHub>>();
4953+
4954+
using (var client = new TestClient())
4955+
{
4956+
var connectionHandlerTask = await client.ConnectAsync(connectionHandler).DefaultTimeout();
4957+
var res = await client.InvokeAsync(nameof(ServicesHub.KeyedServiceWithParam), 91).DefaultTimeout();
4958+
Assert.Equal(1183L, res.Result);
4959+
}
4960+
}
4961+
4962+
[Fact]
4963+
public async Task HubMethodCanInjectKeyedServiceWithNonKeyedService()
4964+
{
4965+
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(provider =>
4966+
{
4967+
provider.AddSignalR(options =>
4968+
{
4969+
options.EnableDetailedErrors = true;
4970+
});
4971+
4972+
provider.AddKeyedScoped<Service1>("service1");
4973+
provider.AddScoped<Service2>();
4974+
});
4975+
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<ServicesHub>>();
4976+
4977+
using (var client = new TestClient())
4978+
{
4979+
var connectionHandlerTask = await client.ConnectAsync(connectionHandler).DefaultTimeout();
4980+
var res = await client.InvokeAsync(nameof(ServicesHub.KeyedServiceNonKeyedService)).DefaultTimeout();
4981+
Assert.Equal(11L, res.Result);
4982+
}
4983+
}
4984+
4985+
[Fact]
4986+
public async Task MultipleKeyedServicesResolved()
4987+
{
4988+
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(provider =>
4989+
{
4990+
provider.AddSignalR(options =>
4991+
{
4992+
options.EnableDetailedErrors = true;
4993+
});
4994+
4995+
provider.AddKeyedScoped<Service1>("service1");
4996+
provider.AddKeyedScoped<Service1>("service2");
4997+
});
4998+
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<ServicesHub>>();
4999+
5000+
using (var client = new TestClient())
5001+
{
5002+
var connectionHandlerTask = await client.ConnectAsync(connectionHandler).DefaultTimeout();
5003+
var res = await client.InvokeAsync(nameof(ServicesHub.MultipleKeyedServices)).DefaultTimeout();
5004+
Assert.Equal(45L, res.Result);
5005+
}
5006+
}
5007+
5008+
[Fact]
5009+
public async Task MultipleKeyedServicesWithSameNameResolved()
5010+
{
5011+
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(provider =>
5012+
{
5013+
provider.AddSignalR(options =>
5014+
{
5015+
options.EnableDetailedErrors = true;
5016+
});
5017+
5018+
provider.AddKeyedScoped<Service1>("service1");
5019+
});
5020+
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<ServicesHub>>();
5021+
5022+
using (var client = new TestClient())
5023+
{
5024+
var connectionHandlerTask = await client.ConnectAsync(connectionHandler).DefaultTimeout();
5025+
var res = await client.InvokeAsync(nameof(ServicesHub.MultipleSameKeyedServices)).DefaultTimeout();
5026+
Assert.Equal(445L, res.Result);
5027+
}
5028+
}
5029+
5030+
[Fact]
5031+
public async Task KeyedServiceNotResolvedIfNotInDI()
5032+
{
5033+
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(provider =>
5034+
{
5035+
provider.AddSignalR(options =>
5036+
{
5037+
options.EnableDetailedErrors = true;
5038+
});
5039+
});
5040+
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<ServicesHub>>();
5041+
5042+
using (var client = new TestClient())
5043+
{
5044+
var connectionHandlerTask = await client.ConnectAsync(connectionHandler).DefaultTimeout();
5045+
var res = await client.InvokeAsync(nameof(ServicesHub.KeyedService)).DefaultTimeout();
5046+
Assert.Equal("Failed to invoke 'KeyedService' due to an error on the server. InvalidDataException: Invocation provides 0 argument(s) but target expects 1.", res.Error);
5047+
}
5048+
}
5049+
49185050
[Fact]
49195051
public void TooManyParametersWithServiceThrows()
49205052
{

0 commit comments

Comments
 (0)