Skip to content

Added ability to use Issuer to receive Token Endpoint for the OAuth2ClientBuilder #1656

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions projects/RabbitMQ.Client.OAuth2/IOAuth2Client.cs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,19 @@ namespace RabbitMQ.Client.OAuth2
{
public interface IOAuth2Client
{
/// <summary>
/// Request a new AccessToken from the Token Endpoint.
/// </summary>
/// <param name="cancellationToken">Cancellation token for this request</param>
/// <returns>Token with Access and Refresh Token</returns>
Task<IToken> RequestTokenAsync(CancellationToken cancellationToken = default);

/// <summary>
/// Request a new AccessToken using the Refresh Token from the Token Endpoint.
/// </summary>
/// <param name="token">Token with the Refresh Token</param>
/// <param name="cancellationToken">Cancellation token for this request</param>
/// <returns>Token with Access and Refresh Token</returns>
Task<IToken> RefreshTokenAsync(IToken token, CancellationToken cancellationToken = default);
}
}
185 changes: 143 additions & 42 deletions projects/RabbitMQ.Client.OAuth2/OAuth2Client.cs
Original file line number Diff line number Diff line change
Expand Up @@ -42,58 +42,147 @@ namespace RabbitMQ.Client.OAuth2
{
public class OAuth2ClientBuilder
{
/// <summary>
/// Discovery endpoint subpath for all OpenID Connect issuers.
/// </summary>
const string DISCOVERY_ENDPOINT = ".well-known/openid-configuration";

private readonly string _clientId;
private readonly string _clientSecret;
private readonly Uri _tokenEndpoint;

// At least one of the following Uris is not null
private readonly Uri? _tokenEndpoint;
private readonly Uri? _issuer;

private string? _scope;
private IDictionary<string, string>? _additionalRequestParameters;
private HttpClientHandler? _httpClientHandler;

public OAuth2ClientBuilder(string clientId, string clientSecret, Uri tokenEndpoint)
/// <summary>
/// Create a new builder for creating <see cref="OAuth2Client"/>s.
/// </summary>
/// <param name="clientId">Id of the client</param>
/// <param name="clientSecret">Secret of the client</param>
/// <param name="tokenEndpoint">Endpoint to receive the Access Token</param>
/// <param name="issuer">Issuer of the Access Token. Used to automaticly receive the Token Endpoint while building</param>
/// <remarks>
/// Either <paramref name="tokenEndpoint"/> or <paramref name="issuer"/> must be provided.
/// </remarks>
public OAuth2ClientBuilder(string clientId, string clientSecret, Uri? tokenEndpoint = null, Uri? issuer = null)
{
_clientId = clientId ?? throw new ArgumentNullException(nameof(clientId));
_clientSecret = clientSecret ?? throw new ArgumentNullException(nameof(clientSecret));
_tokenEndpoint = tokenEndpoint ?? throw new ArgumentNullException(nameof(tokenEndpoint));

if (tokenEndpoint is null && issuer is null)
{
throw new ArgumentException("Either tokenEndpoint or issuer is required");
}

_tokenEndpoint = tokenEndpoint;
_issuer = issuer;
}

/// <summary>
/// Set the requested scopes for the client.
/// </summary>
/// <param name="scope">OAuth scopes to request from the Issuer</param>
public OAuth2ClientBuilder SetScope(string scope)
{
_scope = scope ?? throw new ArgumentNullException(nameof(scope));
return this;
}

/// <summary>
/// Set custom HTTP Client handler for requests of the OAuth2 client.
/// </summary>
/// <param name="handler">Custom handler for HTTP requests</param>
public OAuth2ClientBuilder SetHttpClientHandler(HttpClientHandler handler)
{
_httpClientHandler = handler ?? throw new ArgumentNullException(nameof(handler));
return this;
}

/// <summary>
/// Add a additional request parameter to each HTTP request.
/// </summary>
/// <param name="param">Name of the parameter</param>
/// <param name="paramValue">Value of the parameter</param>
public OAuth2ClientBuilder AddRequestParameter(string param, string paramValue)
{
if (param == null)
if (param is null)
{
throw new ArgumentNullException("param is null");
throw new ArgumentNullException(nameof(param));
}

if (paramValue == null)
if (paramValue is null)
{
throw new ArgumentNullException("paramValue is null");
throw new ArgumentNullException(nameof(paramValue));
}

if (_additionalRequestParameters == null)
{
_additionalRequestParameters = new Dictionary<string, string>();
}
_additionalRequestParameters ??= new Dictionary<string, string>();
_additionalRequestParameters[param] = paramValue;

return this;
}

public IOAuth2Client Build()
/// <summary>
/// Build the <see cref="OAuth2Client"/> with the provided properties of the builder.
/// </summary>
/// <param name="cancellationToken">Cancellation token for this method</param>
/// <returns>Configured OAuth2Client</returns>
public async ValueTask<IOAuth2Client> BuildAsync(CancellationToken cancellationToken = default)
{
// Check if Token Endpoint is missing -> Use Issuer to receive Token Endpoint
if (_tokenEndpoint is null)
{
Uri tokenEndpoint = await GetTokenEndpointFromIssuerAsync(cancellationToken).ConfigureAwait(false);
return new OAuth2Client(_clientId, _clientSecret, tokenEndpoint,
_scope, _additionalRequestParameters, _httpClientHandler);
}

return new OAuth2Client(_clientId, _clientSecret, _tokenEndpoint,
_scope, _additionalRequestParameters, _httpClientHandler);
}

/// <summary>
/// Receive Token Endpoint from discovery page of the Issuer.
/// </summary>
/// <param name="cancellationToken">Cancellation token for this request</param>
/// <returns>Uri of the Token Endpoint</returns>
private async Task<Uri> GetTokenEndpointFromIssuerAsync(CancellationToken cancellationToken = default)
{
if (_issuer is null)
{
throw new InvalidOperationException("The issuer is required");
}

using HttpClient httpClient = _httpClientHandler is null
? new HttpClient()
: new HttpClient(_httpClientHandler, false);

httpClient.DefaultRequestHeaders.Accept.Clear();
httpClient.DefaultRequestHeaders.Accept.Add(new MediaTypeWithQualityHeaderValue("application/json"));

// Build endpoint from Issuer and dicovery endpoint, we can't use the Uri overload because the Issuer Uri may not have a trailing '/'
string tempIssuer = _issuer.AbsoluteUri.EndsWith("/") ? _issuer.AbsoluteUri : _issuer.AbsoluteUri + "/";
Uri discoveryEndpoint = new Uri(tempIssuer + DISCOVERY_ENDPOINT);

using HttpRequestMessage req = new HttpRequestMessage(HttpMethod.Get, discoveryEndpoint);
using HttpResponseMessage response = await httpClient.SendAsync(req, cancellationToken)
.ConfigureAwait(false);

response.EnsureSuccessStatusCode();

OpenIDConnectDiscovery? discovery = await response.Content.ReadFromJsonAsync<OpenIDConnectDiscovery>(cancellationToken: cancellationToken)
.ConfigureAwait(false);

if (discovery is null || string.IsNullOrEmpty(discovery.TokenEndpoint))
{
throw new InvalidOperationException("No token endpoint was found");
}

return new Uri(discovery.TokenEndpoint);
}
}

/**
Expand All @@ -119,7 +208,7 @@ internal class OAuth2Client : IOAuth2Client, IDisposable

public static readonly IDictionary<string, string> EMPTY = new Dictionary<string, string>();

private HttpClient _httpClient;
private readonly HttpClient _httpClient;

public OAuth2Client(string clientId, string clientSecret, Uri tokenEndpoint,
string? scope,
Expand All @@ -132,73 +221,64 @@ public OAuth2Client(string clientId, string clientSecret, Uri tokenEndpoint,
_additionalRequestParameters = additionalRequestParameters ?? EMPTY;
_tokenEndpoint = tokenEndpoint;

if (httpClientHandler is null)
{
_httpClient = new HttpClient();
}
else
{
_httpClient = new HttpClient(httpClientHandler, false);
}
_httpClient = httpClientHandler is null
? new HttpClient()
: new HttpClient(httpClientHandler, false);

_httpClient.DefaultRequestHeaders.Accept.Clear();
_httpClient.DefaultRequestHeaders.Accept.Add(new MediaTypeWithQualityHeaderValue("application/json"));
}

/// <inheritdoc />
public async Task<IToken> RequestTokenAsync(CancellationToken cancellationToken = default)
{
using HttpRequestMessage req = new HttpRequestMessage(HttpMethod.Post, _tokenEndpoint);
req.Content = new FormUrlEncodedContent(BuildRequestParameters());

using HttpResponseMessage response = await _httpClient.SendAsync(req)
using HttpResponseMessage response = await _httpClient.SendAsync(req, cancellationToken)
.ConfigureAwait(false);

response.EnsureSuccessStatusCode();

JsonToken? token = await response.Content.ReadFromJsonAsync<JsonToken>()
JsonToken? token = await response.Content.ReadFromJsonAsync<JsonToken>(cancellationToken: cancellationToken)
.ConfigureAwait(false);

if (token is null)
{
// TODO specific exception?
throw new InvalidOperationException("token is null");
}
else
{
return new Token(token);
}

return new Token(token);
}

/// <inheritdoc />
public async Task<IToken> RefreshTokenAsync(IToken token,
CancellationToken cancellationToken = default)
{
if (token.RefreshToken == null)
if (token.RefreshToken is null)
{
throw new InvalidOperationException("Token has no Refresh Token");
}

using HttpRequestMessage req = new HttpRequestMessage(HttpMethod.Post, _tokenEndpoint)
{
Content = new FormUrlEncodedContent(BuildRefreshParameters(token))
};
using HttpRequestMessage req = new HttpRequestMessage(HttpMethod.Post, _tokenEndpoint);
req.Content = new FormUrlEncodedContent(BuildRefreshParameters(token));

using HttpResponseMessage response = await _httpClient.SendAsync(req)
using HttpResponseMessage response = await _httpClient.SendAsync(req, cancellationToken)
.ConfigureAwait(false);

response.EnsureSuccessStatusCode();

JsonToken? refreshedToken = await response.Content.ReadFromJsonAsync<JsonToken>()
JsonToken? refreshedToken = await response.Content.ReadFromJsonAsync<JsonToken>(cancellationToken: cancellationToken)
.ConfigureAwait(false);

if (refreshedToken is null)
{
// TODO specific exception?
throw new InvalidOperationException("refreshed token is null");
}
else
{
return new Token(refreshedToken);
}

return new Token(refreshedToken);
}

public void Dispose()
Expand All @@ -214,9 +294,9 @@ private Dictionary<string, string> BuildRequestParameters()
{ CLIENT_SECRET, _clientSecret }
};

if (_scope != null && _scope.Length > 0)
if (!string.IsNullOrEmpty(_scope))
{
dict.Add(SCOPE, _scope);
dict.Add(SCOPE, _scope!);
}

dict.Add(GRANT_TYPE, GRANT_TYPE_CLIENT_CREDENTIALS);
Expand All @@ -227,8 +307,7 @@ private Dictionary<string, string> BuildRequestParameters()
private Dictionary<string, string> BuildRefreshParameters(IToken token)
{
Dictionary<string, string> dict = BuildRequestParameters();
dict.Remove(GRANT_TYPE);
dict.Add(GRANT_TYPE, REFRESH_TOKEN);
dict[GRANT_TYPE] = REFRESH_TOKEN;

if (_scope != null)
{
Expand Down Expand Up @@ -284,4 +363,26 @@ public long ExpiresIn
get; set;
}
}

/// <summary>
/// Minimal version of the properties of the discovery endpoint.
/// </summary>
internal class OpenIDConnectDiscovery
{
public OpenIDConnectDiscovery()
{
TokenEndpoint = string.Empty;
}

public OpenIDConnectDiscovery(string tokenEndpoint)
{
TokenEndpoint = tokenEndpoint;
}

[JsonPropertyName("token_endpoint")]
public string TokenEndpoint
{
get; set;
}
}
}
3 changes: 1 addition & 2 deletions projects/RabbitMQ.Client.OAuth2/PublicAPI.Shipped.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@ RabbitMQ.Client.OAuth2.IToken.HasExpired.get -> bool
RabbitMQ.Client.OAuth2.IToken.RefreshToken.get -> string
RabbitMQ.Client.OAuth2.OAuth2ClientBuilder
RabbitMQ.Client.OAuth2.OAuth2ClientBuilder.AddRequestParameter(string param, string paramValue) -> RabbitMQ.Client.OAuth2.OAuth2ClientBuilder
RabbitMQ.Client.OAuth2.OAuth2ClientBuilder.Build() -> RabbitMQ.Client.OAuth2.IOAuth2Client
RabbitMQ.Client.OAuth2.OAuth2ClientBuilder.OAuth2ClientBuilder(string clientId, string clientSecret, System.Uri tokenEndpoint) -> void
RabbitMQ.Client.OAuth2.OAuth2ClientBuilder.OAuth2ClientBuilder(string! clientId, string! clientSecret, System.Uri? tokenEndpoint = null, System.Uri? issuer = null) -> void
RabbitMQ.Client.OAuth2.OAuth2ClientBuilder.SetScope(string scope) -> RabbitMQ.Client.OAuth2.OAuth2ClientBuilder
RabbitMQ.Client.OAuth2.OAuth2ClientCredentialsProvider
RabbitMQ.Client.OAuth2.OAuth2ClientCredentialsProvider.Name.get -> string
Expand Down
1 change: 1 addition & 0 deletions projects/RabbitMQ.Client.OAuth2/PublicAPI.Unshipped.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ RabbitMQ.Client.OAuth2.CredentialsRefresherEventSource.Stopped(string! name) ->
RabbitMQ.Client.OAuth2.IOAuth2Client.RefreshTokenAsync(RabbitMQ.Client.OAuth2.IToken! token, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task<RabbitMQ.Client.OAuth2.IToken!>!
RabbitMQ.Client.OAuth2.IOAuth2Client.RequestTokenAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task<RabbitMQ.Client.OAuth2.IToken!>!
RabbitMQ.Client.OAuth2.NotifyCredentialsRefreshedAsync
RabbitMQ.Client.OAuth2.OAuth2ClientBuilder.BuildAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.ValueTask<RabbitMQ.Client.OAuth2.IOAuth2Client!>
RabbitMQ.Client.OAuth2.OAuth2ClientBuilder.SetHttpClientHandler(System.Net.Http.HttpClientHandler! handler) -> RabbitMQ.Client.OAuth2.OAuth2ClientBuilder!
RabbitMQ.Client.OAuth2.OAuth2ClientCredentialsProvider.Dispose() -> void
RabbitMQ.Client.OAuth2.OAuth2ClientCredentialsProvider.GetCredentialsAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task<RabbitMQ.Client.Credentials!>!
Expand Down
Loading
Loading