Skip to content

SignalR ConnectionToken/ConnectionAddress feature #13773

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 16 commits into from
Sep 16, 2019
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -630,6 +630,9 @@ private static bool IsWebSocketsSupported()
private async Task<NegotiationResponse> GetNegotiationResponseAsync(Uri uri, CancellationToken cancellationToken)
{
var negotiationResponse = await NegotiateAsync(uri, _httpClient, _logger, cancellationToken);
// If the negotiationVersion is greater than zero then we know that the negotiation response contains a
// connectionToken that will be required to conenct. Otherwise we just set the connectionId and the
// connectionToken on the client to the same value.
if (negotiationResponse.Version > 0)
{
_connectionId = negotiationResponse.ConnectionId;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ public static void WriteResponse(NegotiationResponse response, IBufferWriter<byt
writer.WriteString(ConnectionIdPropertyNameBytes, response.ConnectionId);
}

if (!string.IsNullOrEmpty(response.ConnectionToken))
if (response.Version > 0 & !string.IsNullOrEmpty(response.ConnectionToken))
{
writer.WriteString(ConnectionTokenPropertyNameBytes, response.ConnectionToken);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ public HttpConnectionContext(string connectionId, string connectionToken, ILogge
Features.Set<IConnectionInherentKeepAliveFeature>(this);
}

public HttpConnectionContext(string id, IDuplexPipe transport, IDuplexPipe application, ILogger logger = null)
internal HttpConnectionContext(string id, IDuplexPipe transport, IDuplexPipe application, ILogger logger = null)
: this(id, null, logger)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this only used for tests? If so, can we add a comment or something? Maybe make it internal.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we plan to bring this into 3.1 I don't think we can make it internal, but I agree that it should have been internal to start

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Internal class, make all the changes you want!

{
Transport = transport;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ public async Task ExecuteAsync(HttpContext context, HttpConnectionDispatcherOpti
// Create the log scope and attempt to pass the Connection ID to it so as many logs as possible contain
// the Connection ID metadata. If this is the negotiate request then the Connection ID for the scope will
// be set a little later.
var logScope = new ConnectionLogScope(GetConnectionId(context));
var logScope = new ConnectionLogScope(GetConnectionToken(context));
using (_logger.BeginScope(logScope))
{
if (HttpMethods.IsPost(context.Request.Method))
Expand Down Expand Up @@ -279,21 +279,33 @@ private async Task DoPersistentConnection(ConnectionDelegate connectionDelegate,
private async Task ProcessNegotiate(HttpContext context, HttpConnectionDispatcherOptions options, ConnectionLogScope logScope)
{
context.Response.ContentType = "application/json";
string error = null;
int clientProtocolVersion = 0;
if (context.Request.Query.TryGetValue("NegotiateVersion", out var queryStringVersion))
{
// Set the negotiate response to the protocol we use.
var queryStringVersionValue = queryStringVersion.ToString();
if (!int.TryParse(queryStringVersionValue, out clientProtocolVersion))
{
error = $"The client requested an invalid protocol version '{queryStringVersionValue}'";
Log.InvalidNegotiateProtocolVersion(_logger, queryStringVersionValue);
}
}

// Establish the connection
var connection = CreateConnection(options, context);
var connection = CreateConnection(options, context, clientProtocolVersion, error);

// Set the Connection ID on the logging scope so that logs from now on will have the
// Connection ID metadata set.
logScope.ConnectionId = connection.ConnectionId;
logScope.ConnectionId = connection?.ConnectionId;

// Don't use thread static instance here because writer is used with async
var writer = new MemoryBufferWriter();

try
{
// Get the bytes for the connection id
WriteNegotiatePayload(writer, connection.ConnectionId, connection.ConnectionToken, context, options);
WriteNegotiatePayload(writer, connection?.ConnectionId, connection?.ConnectionToken, context, options, clientProtocolVersion, error);

Log.NegotiationRequest(_logger);

Expand All @@ -307,38 +319,34 @@ private async Task ProcessNegotiate(HttpContext context, HttpConnectionDispatche
}
}

private void WriteNegotiatePayload(IBufferWriter<byte> writer, string connectionId, string privateId, HttpContext context, HttpConnectionDispatcherOptions options)
private void WriteNegotiatePayload(IBufferWriter<byte> writer, string connectionId, string connectionToken, HttpContext context, HttpConnectionDispatcherOptions options,
int clientProtocolVersion, string error)
{
var response = new NegotiationResponse();

if (context.Request.Query.TryGetValue("NegotiateVersion", out var queryStringVersion))
if (!string.IsNullOrEmpty(error))
{
// Set the negotiate response to the protocol we use.
var queryStringVersionValue = queryStringVersion.ToString();
if (int.TryParse(queryStringVersionValue, out var clientProtocolVersion))
response.Error = error;
NegotiateProtocol.WriteResponse(response, writer);
return;
}

if (clientProtocolVersion > 0)
{
if (clientProtocolVersion < options.MinimumProtocolVersion)
{
if (clientProtocolVersion < options.MinimumProtocolVersion)
{
response.Error = $"The client requested version '{clientProtocolVersion}', but the server does not support this version.";
Log.NegotiateProtocolVersionMismatch(_logger, clientProtocolVersion);
NegotiateProtocol.WriteResponse(response, writer);
return;
}
else if (clientProtocolVersion > _protocolVersion)
{
response.Version = _protocolVersion;
}
else
{
response.Version = clientProtocolVersion;
}
response.Error = $"The client requested version '{clientProtocolVersion}', but the server does not support this version.";
Log.NegotiateProtocolVersionMismatch(_logger, clientProtocolVersion);
NegotiateProtocol.WriteResponse(response, writer);
return;
}
else if (clientProtocolVersion > _protocolVersion)
{
response.Version = _protocolVersion;
}
else
{
response.Error = $"The client requested an invalid protocol version '{queryStringVersionValue}'";
Log.InvalidNegotiateProtocolVersion(_logger, queryStringVersionValue);
NegotiateProtocol.WriteResponse(response, writer);
return;
response.Version = clientProtocolVersion;
}
}
else if (options.MinimumProtocolVersion > 0)
Expand All @@ -350,7 +358,7 @@ private void WriteNegotiatePayload(IBufferWriter<byte> writer, string connection
}

response.ConnectionId = connectionId;
response.ConnectionToken = privateId;
response.ConnectionToken = connectionToken;
response.AvailableTransports = new List<AvailableTransport>();

if ((options.Transports & HttpTransportType.WebSockets) != 0 && ServerHasWebSockets(context.Features))
Expand All @@ -376,7 +384,7 @@ private static bool ServerHasWebSockets(IFeatureCollection features)
return features.Get<IHttpWebSocketFeature>() != null;
}

private static string GetConnectionId(HttpContext context) => context.Request.Query["id"];
private static string GetConnectionToken(HttpContext context) => context.Request.Query["id"];

private async Task ProcessSend(HttpContext context, HttpConnectionDispatcherOptions options)
{
Expand Down Expand Up @@ -649,9 +657,9 @@ private static HttpContext CloneHttpContext(HttpContext context)

private async Task<HttpConnectionContext> GetConnectionAsync(HttpContext context)
{
var connectionId = GetConnectionId(context);
var connectionToken = GetConnectionToken(context);

if (StringValues.IsNullOrEmpty(connectionId))
if (StringValues.IsNullOrEmpty(connectionToken))
{
// There's no connection ID: bad request
context.Response.StatusCode = StatusCodes.Status400BadRequest;
Expand All @@ -660,7 +668,7 @@ private async Task<HttpConnectionContext> GetConnectionAsync(HttpContext context
return null;
}

if (!_manager.TryGetConnection(connectionId, out var connection))
if (!_manager.TryGetConnection(connectionToken, out var connection))
{
// No connection with that ID: Not Found
context.Response.StatusCode = StatusCodes.Status404NotFound;
Expand All @@ -675,15 +683,15 @@ private async Task<HttpConnectionContext> GetConnectionAsync(HttpContext context
// This is only used for WebSockets connections, which can connect directly without negotiating
private async Task<HttpConnectionContext> GetOrCreateConnectionAsync(HttpContext context, HttpConnectionDispatcherOptions options)
{
var connectionId = GetConnectionId(context);
var connectionToken = GetConnectionToken(context);
HttpConnectionContext connection;

// There's no connection id so this is a brand new connection
if (StringValues.IsNullOrEmpty(connectionId))
if (StringValues.IsNullOrEmpty(connectionToken))
{
connection = CreateConnection(options, context);
}
else if (!_manager.TryGetConnection(connectionId, out connection))
else if (!_manager.TryGetConnection(connectionToken, out connection))
{
// No connection with that ID: Not Found
context.Response.StatusCode = StatusCodes.Status404NotFound;
Expand All @@ -694,21 +702,15 @@ private async Task<HttpConnectionContext> GetOrCreateConnectionAsync(HttpContext
return connection;
}

private HttpConnectionContext CreateConnection(HttpConnectionDispatcherOptions options, HttpContext context)
private HttpConnectionContext CreateConnection(HttpConnectionDispatcherOptions options, HttpContext context, int clientProtocolVersion = 0, string error = null)
{
var transportPipeOptions = new PipeOptions(pauseWriterThreshold: options.TransportMaxBufferSize, resumeWriterThreshold: options.TransportMaxBufferSize / 2, readerScheduler: PipeScheduler.ThreadPool, useSynchronizationContext: false);
var appPipeOptions = new PipeOptions(pauseWriterThreshold: options.ApplicationMaxBufferSize, resumeWriterThreshold: options.ApplicationMaxBufferSize / 2, readerScheduler: PipeScheduler.ThreadPool, useSynchronizationContext: false);

if (context.Request.Query.TryGetValue("NegotiateVersion", out var qsVersion))
if (error != null)
{
// Set the negotiate response to the protocol we use.
var queryStringVersionValue = qsVersion.ToString();
int.TryParse(queryStringVersionValue, out var clientProtocolVersion);
return _manager.CreateConnection(transportPipeOptions, appPipeOptions, clientProtocolVersion);

return null;
}

return _manager.CreateConnection(transportPipeOptions, appPipeOptions, negotiateVersion: 0);
var transportPipeOptions = new PipeOptions(pauseWriterThreshold: options.TransportMaxBufferSize, resumeWriterThreshold: options.TransportMaxBufferSize / 2, readerScheduler: PipeScheduler.ThreadPool, useSynchronizationContext: false);
var appPipeOptions = new PipeOptions(pauseWriterThreshold: options.ApplicationMaxBufferSize, resumeWriterThreshold: options.ApplicationMaxBufferSize / 2, readerScheduler: PipeScheduler.ThreadPool, useSynchronizationContext: false);
return _manager.CreateConnection(transportPipeOptions, appPipeOptions, clientProtocolVersion);
}

private class EmptyServiceProvider : IServiceProvider
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,25 +80,25 @@ internal HttpConnectionContext CreateConnection()
/// <returns></returns>
internal HttpConnectionContext CreateConnection(PipeOptions transportPipeOptions, PipeOptions appPipeOptions, int negotiateVersion = 0)
{
string connectionKey;
string connectionToken;
var id = MakeNewConnectionId();
if (negotiateVersion > 0)
{
connectionKey = MakeNewConnectionId();
connectionToken = MakeNewConnectionId();
}
else
{
connectionKey = id;
connectionToken = id;
}

Log.CreatedNewConnection(_logger, id);
var connectionTimer = HttpConnectionsEventSource.Log.ConnectionStart(id);
var connection = new HttpConnectionContext(id, connectionKey, _connectionLogger);
var connection = new HttpConnectionContext(id, connectionToken, _connectionLogger);
var pair = DuplexPipe.CreateConnectionPair(transportPipeOptions, appPipeOptions);
connection.Transport = pair.Application;
connection.Application = pair.Transport;

_connections.TryAdd(connectionKey, (connection, connectionTimer));
_connections.TryAdd(connectionToken, (connection, connectionTimer));

return connection;
}
Expand Down
Loading