Skip to content

Use newer method overloads in auth handlers #30715

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
8 commits merged into from
Mar 7, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ protected virtual async Task InitializeEventsAsync()
{
Events = Context.RequestServices.GetRequiredService(Options.EventsType);
}
Events = Events ?? await CreateEventsAsync();
Events ??= await CreateEventsAsync();
}

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ protected override async Task<AuthenticationTicket> CreateTicketAsync(ClaimsIden
throw new HttpRequestException($"An error occurred when retrieving Facebook user information ({response.StatusCode}). Please check if the authentication information is correct and the corresponding Facebook Graph API is enabled.");
}

using (var payload = JsonDocument.Parse(await response.Content.ReadAsStringAsync()))
using (var payload = JsonDocument.Parse(await response.Content.ReadAsStringAsync(Context.RequestAborted)))
{
var context = new OAuthCreatingTicketContext(new ClaimsPrincipal(identity), properties, Context, Scheme, Options, Backchannel, tokens, payload.RootElement);
context.RunClaimActions();
Expand Down
6 changes: 3 additions & 3 deletions src/Security/Authentication/Google/src/GoogleHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ protected override async Task<AuthenticationTicket> CreateTicketAsync(
throw new HttpRequestException($"An error occurred when retrieving Google user information ({response.StatusCode}). Please check if the authentication information is correct.");
}

using (var payload = JsonDocument.Parse(await response.Content.ReadAsStringAsync()))
using (var payload = JsonDocument.Parse(await response.Content.ReadAsStringAsync(Context.RequestAborted)))
{
var context = new OAuthCreatingTicketContext(new ClaimsPrincipal(identity), properties, Context, Scheme, Options, Backchannel, tokens, payload.RootElement);
context.RunClaimActions();
Expand Down Expand Up @@ -80,7 +80,7 @@ protected override string BuildChallengeUrl(AuthenticationProperties properties,
return authorizationEndpoint;
}

private void AddQueryString<T>(
private static void AddQueryString<T>(
IDictionary<string, string> queryStrings,
AuthenticationProperties properties,
string name,
Expand All @@ -107,7 +107,7 @@ private void AddQueryString<T>(
}
}

private void AddQueryString(
private static void AddQueryString(
IDictionary<string, string> queryStrings,
AuthenticationProperties properties,
string name,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -243,13 +243,13 @@ protected override async Task HandleChallengeAsync(AuthenticationProperties prop
{
builder.Append(" error=\"");
builder.Append(eventContext.Error);
builder.Append("\"");
builder.Append('\"');
}
if (!string.IsNullOrEmpty(eventContext.ErrorDescription))
{
if (!string.IsNullOrEmpty(eventContext.Error))
{
builder.Append(",");
builder.Append(',');
}

builder.Append(" error_description=\"");
Expand All @@ -261,7 +261,7 @@ protected override async Task HandleChallengeAsync(AuthenticationProperties prop
if (!string.IsNullOrEmpty(eventContext.Error) ||
!string.IsNullOrEmpty(eventContext.ErrorDescription))
{
builder.Append(",");
builder.Append(',');
}

builder.Append(" error_uri=\"");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ protected override async Task<AuthenticationTicket> CreateTicketAsync(ClaimsIden
throw new HttpRequestException($"An error occurred when retrieving Microsoft user information ({response.StatusCode}). Please check if the authentication information is correct and the corresponding Microsoft Account API is enabled.");
}

using (var payload = JsonDocument.Parse(await response.Content.ReadAsStringAsync()))
using (var payload = JsonDocument.Parse(await response.Content.ReadAsStringAsync(Context.RequestAborted)))
{
var context = new OAuthCreatingTicketContext(new ClaimsPrincipal(identity), properties, Context, Scheme, Options, Backchannel, tokens, payload.RootElement);
context.RunClaimActions();
Expand Down Expand Up @@ -92,7 +92,7 @@ protected override string BuildChallengeUrl(AuthenticationProperties properties,
return QueryHelpers.AddQueryString(Options.AuthorizationEndpoint, queryStrings!);
}

private void AddQueryString<T>(
private static void AddQueryString<T>(
Dictionary<string, string> queryStrings,
AuthenticationProperties properties,
string name,
Expand All @@ -119,7 +119,7 @@ private void AddQueryString<T>(
}
}

private void AddQueryString(
private static void AddQueryString(
Dictionary<string, string> queryStrings,
AuthenticationProperties properties,
string name,
Expand Down
2 changes: 1 addition & 1 deletion src/Security/Authentication/OAuth/src/OAuthHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ protected virtual async Task<OAuthTokenResponse> ExchangeCodeAsync(OAuthCodeExch
var response = await Backchannel.SendAsync(requestMessage, Context.RequestAborted);
if (response.IsSuccessStatusCode)
{
var payload = JsonDocument.Parse(await response.Content.ReadAsStringAsync());
var payload = JsonDocument.Parse(await response.Content.ReadAsStringAsync(Context.RequestAborted));
return OAuthTokenResponse.Success(payload);
}
else
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ protected virtual async Task<bool> HandleRemoteSignOutAsync()
&& Request.ContentType.StartsWith("application/x-www-form-urlencoded", StringComparison.OrdinalIgnoreCase)
&& Request.Body.CanRead)
{
var form = await Request.ReadFormAsync();
var form = await Request.ReadFormAsync(Context.RequestAborted);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need this? This happens automagically without passing in the token.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TIL - I didn't realise this was equivalent to the default behaviour with CancellationToken.None/default.

message = new OpenIdConnectMessage(form.Select(pair => new KeyValuePair<string, string[]>(pair.Key, pair.Value)));
}

Expand Down Expand Up @@ -195,7 +195,7 @@ public async virtual Task SignOutAsync(AuthenticationProperties properties)
return;
}

properties = properties ?? new AuthenticationProperties();
properties ??= new AuthenticationProperties();

Logger.EnteringOpenIdAuthenticationHandlerHandleSignOutAsync(GetType().FullName);

Expand Down Expand Up @@ -276,7 +276,7 @@ public async virtual Task SignOutAsync(AuthenticationProperties properties)
Response.Headers[HeaderNames.Pragma] = "no-cache";
Response.Headers[HeaderNames.Expires] = HeaderValueEpocDate;

await Response.Body.WriteAsync(buffer, 0, buffer.Length);
await Response.Body.WriteAsync(buffer);
}
else
{
Expand Down Expand Up @@ -479,7 +479,7 @@ private async Task HandleChallengeAsyncInternal(AuthenticationProperties propert
Response.Headers[HeaderNames.Pragma] = "no-cache";
Response.Headers[HeaderNames.Expires] = HeaderValueEpocDate;

await Response.Body.WriteAsync(buffer, 0, buffer.Length);
await Response.Body.WriteAsync(buffer);
return;
}

Expand Down Expand Up @@ -521,7 +521,7 @@ protected override async Task<HandleRequestResult> HandleRemoteAuthenticateAsync
&& Request.ContentType.StartsWith("application/x-www-form-urlencoded", StringComparison.OrdinalIgnoreCase)
&& Request.Body.CanRead)
{
var form = await Request.ReadFormAsync();
var form = await Request.ReadFormAsync(Context.RequestAborted);
authorizationResponse = new OpenIdConnectMessage(form.Select(pair => new KeyValuePair<string, string[]>(pair.Key, pair.Value)));
}

Expand Down Expand Up @@ -823,7 +823,7 @@ protected virtual async Task<OpenIdConnectMessage> RedeemAuthorizationCodeAsync(
var requestMessage = new HttpRequestMessage(HttpMethod.Post, tokenEndpointRequest.TokenEndpoint ?? _configuration.TokenEndpoint);
requestMessage.Content = new FormUrlEncodedContent(tokenEndpointRequest.Parameters);
requestMessage.Version = Backchannel.DefaultRequestVersion;
var responseMessage = await Backchannel.SendAsync(requestMessage);
var responseMessage = await Backchannel.SendAsync(requestMessage, Context.RequestAborted);

var contentMediaType = responseMessage.Content.Headers.ContentType?.MediaType;
if (string.IsNullOrEmpty(contentMediaType))
Expand All @@ -842,7 +842,7 @@ protected virtual async Task<OpenIdConnectMessage> RedeemAuthorizationCodeAsync(
OpenIdConnectMessage message;
try
{
var responseContent = await responseMessage.Content.ReadAsStringAsync();
var responseContent = await responseMessage.Content.ReadAsStringAsync(Context.RequestAborted);
message = new OpenIdConnectMessage(responseContent);
}
catch (Exception ex)
Expand Down Expand Up @@ -886,9 +886,9 @@ protected virtual async Task<HandleRequestResult> GetUserInformationAsync(
var requestMessage = new HttpRequestMessage(HttpMethod.Get, userInfoEndpoint);
requestMessage.Headers.Authorization = new AuthenticationHeaderValue("Bearer", message.AccessToken);
requestMessage.Version = Backchannel.DefaultRequestVersion;
var responseMessage = await Backchannel.SendAsync(requestMessage);
var responseMessage = await Backchannel.SendAsync(requestMessage, Context.RequestAborted);
responseMessage.EnsureSuccessStatusCode();
var userInfoResponse = await responseMessage.Content.ReadAsStringAsync();
var userInfoResponse = await responseMessage.Content.ReadAsStringAsync(Context.RequestAborted);

JsonDocument user;
var contentType = responseMessage.Content.Headers.ContentType;
Expand Down Expand Up @@ -1037,36 +1037,6 @@ private string ReadNonceCookie(string nonce)
return null;
}

private AuthenticationProperties GetPropertiesFromState(string state)
{
// assume a well formed query string: <a=b&>OpenIdConnectAuthenticationDefaults.AuthenticationPropertiesKey=kasjd;fljasldkjflksdj<&c=d>
var startIndex = 0;
if (string.IsNullOrEmpty(state) || (startIndex = state.IndexOf(OpenIdConnectDefaults.AuthenticationPropertiesKey, StringComparison.Ordinal)) == -1)
{
return null;
}

var authenticationIndex = startIndex + OpenIdConnectDefaults.AuthenticationPropertiesKey.Length;
if (authenticationIndex == -1 || authenticationIndex == state.Length || state[authenticationIndex] != '=')
{
return null;
}

// scan rest of string looking for '&'
authenticationIndex++;
var endIndex = state.Substring(authenticationIndex, state.Length - authenticationIndex).IndexOf("&", StringComparison.Ordinal);

// -1 => no other parameters are after the AuthenticationPropertiesKey
if (endIndex == -1)
{
return Options.StateDataFormat.Unprotect(Uri.UnescapeDataString(state.Substring(authenticationIndex).Replace('+', ' ')));
}
else
{
return Options.StateDataFormat.Unprotect(Uri.UnescapeDataString(state.Substring(authenticationIndex, endIndex).Replace('+', ' ')));
}
}

private async Task<MessageReceivedContext> RunMessageReceivedEventAsync(OpenIdConnectMessage message, AuthenticationProperties properties)
{
Logger.MessageReceived(message.BuildRedirectUrl());
Expand Down
18 changes: 9 additions & 9 deletions src/Security/Authentication/Twitter/src/TwitterHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ protected override async Task<HandleRequestResult> HandleRemoteAuthenticateAsync
JsonDocument user;
if (Options.RetrieveUserDetails)
{
user = await RetrieveUserDetailsAsync(accessToken, identity);
user = await RetrieveUserDetailsAsync(accessToken);
}
else
{
Expand Down Expand Up @@ -223,9 +223,9 @@ private async Task<HttpResponseMessage> ExecuteRequestAsync(string url, HttpMeth

var canonicalizedRequestBuilder = new StringBuilder();
canonicalizedRequestBuilder.Append(httpMethod.Method);
canonicalizedRequestBuilder.Append("&");
canonicalizedRequestBuilder.Append('&');
canonicalizedRequestBuilder.Append(Uri.EscapeDataString(url));
canonicalizedRequestBuilder.Append("&");
canonicalizedRequestBuilder.Append('&');
canonicalizedRequestBuilder.Append(Uri.EscapeDataString(parameterString));

var signature = ComputeSignature(Options.ConsumerSecret, accessToken?.TokenSecret, canonicalizedRequestBuilder.ToString());
Expand Down Expand Up @@ -271,7 +271,7 @@ private async Task<RequestToken> ObtainRequestTokenAsync(string callBackUri, Aut

var response = await ExecuteRequestAsync(TwitterDefaults.RequestTokenEndpoint, HttpMethod.Post, extraOAuthPairs: new Dictionary<string, string>() { { "oauth_callback", callBackUri } });
await EnsureTwitterRequestSuccess(response);
var responseText = await response.Content.ReadAsStringAsync();
var responseText = await response.Content.ReadAsStringAsync(Context.RequestAborted);

var responseParameters = new FormCollection(new FormReader(responseText).ReadForm());
if (!string.Equals(responseParameters["oauth_callback_confirmed"], "true", StringComparison.Ordinal))
Expand All @@ -297,7 +297,7 @@ private async Task<AccessToken> ObtainAccessTokenAsync(RequestToken token, strin
await EnsureTwitterRequestSuccess(response); // throw
}

var responseText = await response.Content.ReadAsStringAsync();
var responseText = await response.Content.ReadAsStringAsync(Context.RequestAborted);
var responseParameters = new FormCollection(new FormReader(responseText).ReadForm());

return new AccessToken
Expand All @@ -310,7 +310,7 @@ private async Task<AccessToken> ObtainAccessTokenAsync(RequestToken token, strin
}

// https://dev.twitter.com/rest/reference/get/account/verify_credentials
private async Task<JsonDocument> RetrieveUserDetailsAsync(AccessToken accessToken, ClaimsIdentity identity)
private async Task<JsonDocument> RetrieveUserDetailsAsync(AccessToken accessToken)
{
Logger.RetrieveUserDetails();

Expand All @@ -321,7 +321,7 @@ private async Task<JsonDocument> RetrieveUserDetailsAsync(AccessToken accessToke
Logger.LogError("Email request failed with a status code of " + response.StatusCode);
await EnsureTwitterRequestSuccess(response); // throw
}
var responseText = await response.Content.ReadAsStringAsync();
var responseText = await response.Content.ReadAsStringAsync(Context.RequestAborted);

var result = JsonDocument.Parse(responseText);

Expand All @@ -334,7 +334,7 @@ private string GenerateTimeStamp()
return Convert.ToInt64(secondsSinceUnixEpocStart.TotalSeconds).ToString(CultureInfo.InvariantCulture);
}

private string ComputeSignature(string consumerSecret, string tokenSecret, string signatureData)
private static string ComputeSignature(string consumerSecret, string tokenSecret, string signatureData)
{
using (var algorithm = new HMACSHA1())
{
Expand Down Expand Up @@ -363,7 +363,7 @@ private async Task EnsureTwitterRequestSuccess(HttpResponseMessage response)
try
{
// Failure, attempt to parse Twitters error message
var errorContentStream = await response.Content.ReadAsStreamAsync();
var errorContentStream = await response.Content.ReadAsStreamAsync(Context.RequestAborted);
errorResponse = await JsonSerializer.DeserializeAsync<TwitterErrorResponse>(errorContentStream, ErrorSerializerOptions);
}
catch
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ protected override async Task<HandleRequestResult> HandleRemoteAuthenticateAsync
&& Request.ContentType.StartsWith("application/x-www-form-urlencoded", StringComparison.OrdinalIgnoreCase)
&& Request.Body.CanRead)
{
var form = await Request.ReadFormAsync();
var form = await Request.ReadFormAsync(Context.RequestAborted);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This too.


wsFederationMessage = new WsFederationMessage(form.Select(pair => new KeyValuePair<string, string[]>(pair.Key, pair.Value)));
}
Expand Down
4 changes: 2 additions & 2 deletions src/Security/Authentication/test/GoogleTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -941,7 +941,7 @@ public async Task AuthenticateGoogleWhenAlreadySignedInSucceeds()
}

[Fact]
public async Task AuthenticateFacebookWhenAlreadySignedWithGoogleReturnsNull()
public async Task AuthenticateGoogleWhenAlreadySignedWithGoogleReturnsNull()
{
var stateFormat = new PropertiesDataFormat(new EphemeralDataProtectionProvider(NullLoggerFactory.Instance).CreateProtector("GoogleTest"));
using var host = await CreateHost(o =>
Expand Down Expand Up @@ -978,7 +978,7 @@ public async Task AuthenticateFacebookWhenAlreadySignedWithGoogleReturnsNull()
}

[Fact]
public async Task ChallengeFacebookWhenAlreadySignedWithGoogleSucceeds()
public async Task ChallengeGoogleWhenAlreadySignedWithGoogleSucceeds()
{
var stateFormat = new PropertiesDataFormat(new EphemeralDataProtectionProvider(NullLoggerFactory.Instance).CreateProtector("GoogleTest"));
using var host = await CreateHost(o =>
Expand Down