Skip to content

Add warning if user changes during Stateful Reconnect #50059

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2666,6 +2666,78 @@ public async Task CanReconnectAndSendMessageOnceConnected()
}
}

[Fact]
public async Task ChangingUserNameDuringReconnectLogsWarning()
{
var protocol = HubProtocols["json"];
await using (var server = await StartServer<Startup>())
{
var websocket = new ClientWebSocket();
var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);

var userName = "test1";
var connectionBuilder = new HubConnectionBuilder()
.WithLoggerFactory(LoggerFactory)
.WithUrl(server.Url + "/default", HttpTransportType.WebSockets, o =>
{
o.WebSocketFactory = async (context, token) =>
{
var httpResponse = await new HttpClient().GetAsync(server.Url + $"/generateJwtToken/{userName}");
httpResponse.EnsureSuccessStatusCode();
var authHeader = await httpResponse.Content.ReadAsStringAsync();
websocket.Options.SetRequestHeader("Authorization", $"Bearer {authHeader}");

await websocket.ConnectAsync(context.Uri, token);
tcs.SetResult();
return websocket;
};
o.UseAcks = true;
})
.WithAutomaticReconnect();
connectionBuilder.Services.AddSingleton(protocol);
var connection = connectionBuilder.Build();

var reconnectCalled = false;
connection.Reconnecting += ex =>
{
reconnectCalled = true;
return Task.CompletedTask;
};

try
{
await connection.StartAsync().DefaultTimeout();
userName = "test2";
await tcs.Task;
tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);

var originalConnectionId = connection.ConnectionId;

var originalWebsocket = websocket;
websocket = new ClientWebSocket();

originalWebsocket.Dispose();

await tcs.Task.DefaultTimeout();

Assert.Equal(originalConnectionId, connection.ConnectionId);
Assert.False(reconnectCalled);

var changeLog = Assert.Single(TestSink.Writes.Where(w => w.EventId.Name == "UserNameChanged"));
Assert.EndsWith("The name of the user changed from 'test1' to 'test2'.", changeLog.Message);
}
catch (Exception ex)
{
LoggerFactory.CreateLogger<HubConnectionTests>().LogError(ex, "{ExceptionType} from test", ex.GetType().FullName);
throw;
}
finally
{
await connection.DisposeAsync().DefaultTimeout();
}
}
}

[Fact]
public async Task ServerAbortsConnectionWithAckingEnabledNoReconnectAttempted()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ public void ConfigureServices(IServiceCollection services)
});
});

services.AddAuthentication(NegotiateDefaults.AuthenticationScheme).AddNegotiate();
services.AddAuthentication(JwtBearerDefaults.AuthenticationScheme)
.AddJwtBearer(options =>
{
Expand All @@ -60,7 +61,6 @@ public void ConfigureServices(IServiceCollection services)
IssuerSigningKey = SecurityKey
};
});
services.AddAuthentication(NegotiateDefaults.AuthenticationScheme).AddNegotiate();

// Since tests run in parallel, it's possible multiple servers will startup,
// we use an ephemeral key provider and repository to avoid filesystem contention issues
Expand Down Expand Up @@ -114,9 +114,9 @@ public void Configure(IApplicationBuilder app)
options.MinimumProtocolVersion = -1;
});

endpoints.MapGet("/generateJwtToken", context =>
endpoints.MapGet("/generateJwtToken/{name?}", (HttpContext context, string name) =>
{
return context.Response.WriteAsync(GenerateJwtToken());
return context.Response.WriteAsync(GenerateJwtToken(name ?? "testuser"));
});

endpoints.Map("/redirect/{*anything}", context =>
Expand All @@ -130,9 +130,9 @@ public void Configure(IApplicationBuilder app)
});
}

private string GenerateJwtToken()
private string GenerateJwtToken(string name = "testuser")
{
var claims = new[] { new Claim(ClaimTypes.NameIdentifier, "testuser") };
var claims = new[] { new Claim(ClaimTypes.NameIdentifier, name) };
var credentials = new SigningCredentials(SecurityKey, SecurityAlgorithms.HmacSha256);
var token = new JwtSecurityToken("SignalRTestServer", "SignalRTests", claims, expires: DateTime.Now.AddSeconds(5), signingCredentials: credentials);
return JwtTokenHandler.WriteToken(token);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,5 +56,13 @@ internal static partial class Log

[LoggerMessage(16, LogLevel.Debug, "The client requested an invalid protocol version '{queryStringVersionValue}'", EventName = "InvalidNegotiateProtocolVersion")]
public static partial void InvalidNegotiateProtocolVersion(ILogger logger, string queryStringVersionValue);

[LoggerMessage(17, LogLevel.Warning, "The name of the user changed from '{PreviousUserName}' to '{CurrentUserName}'.", EventName = "UserNameChanged")]
private static partial void UserNameChangedInternal(ILogger logger, string previousUserName, string currentUserName);

public static void UserNameChanged(ILogger logger, string? previousUserName, string? currentUserName)
{
UserNameChangedInternal(logger, previousUserName ?? "(null)", currentUserName ?? "(null)");
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -606,6 +606,17 @@ private async Task<bool> EnsureConnectionStateAsync(HttpConnectionContext connec
connection.HttpContext = context;
}

if (connection.User is not null)
{
var originalName = connection.User.FindFirst(ClaimTypes.NameIdentifier)?.Value;
var newName = connection.HttpContext?.User.FindFirst(ClaimTypes.NameIdentifier)?.Value;
if (originalName != newName)
{
// Log warning, different user
Log.UserNameChanged(_logger, originalName, newName);
}
}

// Setup the connection state from the http context
connection.User = connection.HttpContext?.User;

Expand Down