Skip to content

Support CAE #16852

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Feb 10, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
336 changes: 336 additions & 0 deletions src/Accounts/Accounts.Test/SilentReAuthByTenantCmdletTest.cs

Large diffs are not rendered by default.

22 changes: 22 additions & 0 deletions src/Accounts/Accounts/Account/ConnectAzureRmAccount.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
using System.Management.Automation;
using System.Runtime.InteropServices;
using System.Security;
using System.Text;
using System.Threading;
using System.Threading.Tasks;

Expand Down Expand Up @@ -511,11 +512,17 @@ public override void ExecuteCmdlet()
}
catch (AuthenticationFailedException ex)
{
string message = string.Empty;
if (IsUnableToOpenWebPageError(ex))
{
WriteWarning(Resources.InteractiveAuthNotSupported);
WriteDebug(ex.ToString());
}
else if (TryParseUnknownAuthenticationException(ex, out message))
{
WriteDebug(ex.ToString());
throw ex.WithAdditionalMessage(message);
}
else
{
if (IsUsingInteractiveAuthentication())
Expand Down Expand Up @@ -554,6 +561,21 @@ private bool IsUnableToOpenWebPageError(AuthenticationFailedException exception)
|| (exception.Message?.ToLower()?.Contains("unable to open a web page") ?? false);
}

private bool TryParseUnknownAuthenticationException(AuthenticationFailedException exception, out string message)
{

Comment on lines +564 to +566
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shall we move the method out of this class because (1) this class is already very long (2) the logic of parsing MsalServiceException is not strongly tied to Connect-AzAccount (3) it's easier to test

We can have a static helper class that handles details of authentication exceptions.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Currently no other place to use the TryParseUnknownAuthenticationException and its logic is very specific for login. We can extract it when we find other place can share the logic.

var innerException = exception?.InnerException as MsalServiceException;
bool isUnknownMsalServiceException = string.Equals(innerException?.ErrorCode, "access_denied", StringComparison.OrdinalIgnoreCase);
message = null;
if(isUnknownMsalServiceException)
{
StringBuilder messageBuilder = new StringBuilder(nameof(innerException.ErrorCode));
messageBuilder.Append(": ").Append(innerException.ErrorCode);
message = messageBuilder.ToString();
}
return isUnknownMsalServiceException;
}

private ConcurrentQueue<Task> _tasks = new ConcurrentQueue<Task>();

private void HandleActions()
Expand Down
3 changes: 3 additions & 0 deletions src/Accounts/Accounts/ChangeLog.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
-->

## Upcoming Release
* Enabled Continue Access Evaluation for MSGraph
* Improved error message when login is blocked by AAD
* Improved error message when silent reauthentication failed

## Version 2.7.2
* Removed legacy assembly System.Private.ServiceModel and System.ServiceModel.Primitives [#16063]
Expand Down
46 changes: 38 additions & 8 deletions src/Accounts/Accounts/CommonModule/ContextAdapter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,15 @@
using System.Collections.Generic;
using Microsoft.Azure.Commands.Common.Authentication.Abstractions;
using Microsoft.Azure.Commands.Common.Authentication.Abstractions.Core;
using Microsoft.Azure.Commands.Common.Utilities;
using Microsoft.Azure.Commands.Profile.Models;
using System.Globalization;
using Microsoft.Azure.Commands.Common.Authentication;
using Microsoft.Azure.Commands.ResourceManager.Common.ArgumentCompleters;
using System.Linq;
using System.Management.Automation;
using Microsoft.Azure.Commands.Profile.Properties;
using Azure.Identity;

namespace Microsoft.Azure.Commands.Common
{
Expand Down Expand Up @@ -115,8 +117,7 @@ internal void AddAuthorizeRequestHandler(
{
endpointResourceIdKey = endpointResourceIdKey ?? AzureEnvironment.Endpoint.ResourceManager;
var context = GetDefaultContext(_provider, invocationInfo);
await AuthorizeRequest(context, request, cancelToken, endpointResourceIdKey, endpointSuffixKey, tokenAudienceConverter);
return await next(request, cancelToken, cancelAction, signal);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As we disscussed in today's meeting, we would like to revisit the part of the design related to autorest -- instead of adding CAE logic to the AddAuthorizeRequestHandler method, which has little to do with CAE, it seems better to spilt that into a new handler specially for CAE.

However, redoing authentication requires information about the access token (such as environment, audience). If the logic is departed from authentication, how does it get those info?

A second concern is about calling next(). We want to retry the failed request by calling next() for the second time (line 211), however if there are more handlers after this one, those handlers will be called twice against the same http request, which may have side-effect such as duplicated HTTP header. Is there a way allowing us to resend the request without duplicately calling the handlers?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As we disscussed in today's meeting, we would like to revisit the part of the design related to autorest -- instead of adding CAE logic to the AddAuthorizeRequestHandler method, which has little to do with CAE, it seems better to spilt that into a new handler specially for CAE.

However, redoing authentication requires information about the access token (such as environment, audience). If the logic is departed from authentication, how does it get those info?

A second concern is about calling next(). We want to retry the failed request by calling next() for the second time (line 211), however if there are more handlers after this one, those handlers will be called twice against the same http request, which may have side-effect such as duplicated HTTP header. Is there a way allowing us to resend the request without duplicately calling the handlers?

Thanks for @isra-fel's conclusions. After the discussion with @dolauli, we find a way to seperate the authentication and reauthentication steps. However, the codes of these 2 parts are coupled, we decide to include them into one step currently.
@dolauli Could you help to review the related codes?

Thanks

return await AuthenticationHelper(context, endpointResourceIdKey, endpointSuffixKey, request, cancelToken, cancelAction, signal, next);
});
}

Expand Down Expand Up @@ -191,6 +192,35 @@ public object GetParameterValue(string resourceId, string moduleName, Invocation
return string.Empty;
}

internal async Task<HttpResponseMessage> AuthenticationHelper(IAzureContext context, string endpointResourceIdKey, string endpointSuffixKey, HttpRequestMessage request, CancellationToken cancelToken, Action cancelAction, SignalDelegate signal, NextDelegate next, TokenAudienceConverterDelegate tokenAudienceConverter = null)
{
IAccessToken accessToken = await AuthorizeRequest(context, request, cancelToken, endpointResourceIdKey, endpointSuffixKey, tokenAudienceConverter);
var newRequest = await request.CloneWithContentAndDispose(request.RequestUri, request.Method);
var response = await next(request, cancelToken, cancelAction, signal);

if (response.MatchClaimsChallengePattern())
{
//get token again with claims challenge
if (accessToken is IClaimsChallengeProcessor processor)
{
try
{
var claimsChallenge = ClaimsChallengeUtilities.GetClaimsChallenge(response);
if (!string.IsNullOrEmpty(claimsChallenge))
{
await processor.OnClaimsChallenageAsync(newRequest, claimsChallenge, cancelToken).ConfigureAwait(false);
response = await next(newRequest, cancelToken, cancelAction, signal);
}
}
catch (AuthenticationFailedException e)
{
throw e.WithAdditionalMessage(response?.GetWwwAuthenticateMessage());
}
}
}
return response;
}

/// <summary>
///
/// </summary>
Expand All @@ -202,8 +232,7 @@ internal Func<HttpRequestMessage, CancellationToken, Action, SignalDelegate, Nex
return async (request, cancelToken, cancelAction, signal, next) =>
{
PatchRequestUri(context, request);
await AuthorizeRequest(context, request, cancelToken, resourceId, resourceId);
return await next(request, cancelToken, cancelAction, signal);
return await AuthenticationHelper(context, resourceId, resourceId, request, cancelToken, cancelAction, signal, next);
};
}

Expand All @@ -213,17 +242,17 @@ internal Func<HttpRequestMessage, CancellationToken, Action, SignalDelegate, Nex
/// <param name="context"></param>
/// <param name="endpointResourceIdKey"></param>
/// <param name="request"></param>
/// <param name="outerToken"></param>
/// <param name="cancellationToken"></param>
/// <returns></returns>
internal async Task AuthorizeRequest(IAzureContext context, HttpRequestMessage request, CancellationToken outerToken, string endpointResourceIdKey,
internal async Task<IAccessToken> AuthorizeRequest(IAzureContext context, HttpRequestMessage request, CancellationToken cancellationToken, string endpointResourceIdKey,
string endpointSuffixKey, TokenAudienceConverterDelegate tokenAudienceConverter = null, IDictionary<string, object> extensibleParamters = null)
{
if (context == null || context.Account == null || context.Environment == null)
{
throw new InvalidOperationException(Resources.InvalidAzureContext);
}

await Task.Run(() =>
return await Task.Run(() =>
{
if (tokenAudienceConverter != null)
{
Expand All @@ -233,7 +262,8 @@ await Task.Run(() =>
}
var authToken = _authenticator.Authenticate(context.Account, context.Environment, context.Tenant.Id, null, "Never", null, endpointResourceIdKey);
authToken.AuthorizeRequest((type, token) => request.Headers.Authorization = new System.Net.Http.Headers.AuthenticationHeaderValue(type, token));
}, outerToken);
return authToken;
}, cancellationToken);
}

private (string CurEnvEndpointResourceId, string CurEnvEndpointSuffix, string BaseEnvEndpointResourceId, string BaseEnvEndpointSuffix) GetEndpointInfo(IAzureEnvironment environment, string endpointResourceIdKey, string endpointSuffixKey)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
using Microsoft.Azure.Internal.Subscriptions;
using Microsoft.Azure.Internal.Subscriptions.Models;
using Microsoft.Azure.Internal.Subscriptions.Models.Utilities;
using Microsoft.Rest;
using Microsoft.WindowsAzure.Commands.Utilities.Common;
using System.Collections.Generic;
using System.Linq;
Expand All @@ -41,7 +40,7 @@ public IList<AzureTenant> ListAccountTenants(IAccessToken accessToken, IAzureEnv
{
subscriptionClient = AzureSession.Instance.ClientFactory.CreateCustomArmClient<SubscriptionClient>(
environment.GetEndpointAsUri(AzureEnvironment.Endpoint.ResourceManager),
new TokenCredentials(accessToken.AccessToken) as ServiceClientCredentials,
new RenewingTokenCredential(accessToken),
AzureSession.Instance.ClientFactory.GetCustomHandlers());

var tenants = new GenericPageEnumerable<TenantIdDescription>(subscriptionClient.Tenants.List, subscriptionClient.Tenants.ListNext, ulong.MaxValue, 0).ToList();
Expand Down Expand Up @@ -71,7 +70,7 @@ public IList<AzureSubscription> ListAllSubscriptionsForTenant(IAccessToken acces
{
using (var subscriptionClient = AzureSession.Instance.ClientFactory.CreateCustomArmClient<SubscriptionClient>(
environment.GetEndpointAsUri(AzureEnvironment.Endpoint.ResourceManager),
new TokenCredentials(accessToken.AccessToken) as ServiceClientCredentials,
new RenewingTokenCredential(accessToken),
AzureSession.Instance.ClientFactory.GetCustomHandlers()))
{
return (subscriptionClient.ListAllSubscriptions()?
Expand All @@ -83,7 +82,7 @@ public AzureSubscription GetSubscriptionById(string subscriptionId, IAccessToken
{
using (var subscriptionClient = AzureSession.Instance.ClientFactory.CreateCustomArmClient<SubscriptionClient>(
environment.GetEndpointAsUri(AzureEnvironment.Endpoint.ResourceManager),
new TokenCredentials(accessToken.AccessToken) as ServiceClientCredentials,
new RenewingTokenCredential(accessToken),
AzureSession.Instance.ClientFactory.GetCustomHandlers()))
{
var subscription = subscriptionClient.Subscriptions.Get(subscriptionId);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ public IList<AzureTenant> ListAccountTenants(IAccessToken accessToken, IAzureEnv
{
subscriptionClient = AzureSession.Instance.ClientFactory.CreateCustomArmClient<SubscriptionClient>(
environment.GetEndpointAsUri(AzureEnvironment.Endpoint.ResourceManager),
new TokenCredentials(accessToken.AccessToken) as ServiceClientCredentials,
new RenewingTokenCredential(accessToken),
AzureSession.Instance.ClientFactory.GetCustomHandlers());

var tenants = new GenericPageEnumerable<TenantIdDescription>(subscriptionClient.Tenants.List, subscriptionClient.Tenants.ListNext, ulong.MaxValue, 0).ToList();
Expand Down Expand Up @@ -72,7 +72,7 @@ public IList<AzureSubscription> ListAllSubscriptionsForTenant(IAccessToken acces
{
using (var subscriptionClient = AzureSession.Instance.ClientFactory.CreateCustomArmClient<SubscriptionClient>(
environment.GetEndpointAsUri(AzureEnvironment.Endpoint.ResourceManager),
new TokenCredentials(accessToken.AccessToken) as ServiceClientCredentials,
new RenewingTokenCredential(accessToken),
AzureSession.Instance.ClientFactory.GetCustomHandlers()))
{
return subscriptionClient.ListAllSubscriptions()?
Expand All @@ -84,7 +84,7 @@ public AzureSubscription GetSubscriptionById(string subscriptionId, IAccessToken
{
using (var subscriptionClient = AzureSession.Instance.ClientFactory.CreateCustomArmClient<SubscriptionClient>(
environment.GetEndpointAsUri(AzureEnvironment.Endpoint.ResourceManager),
new TokenCredentials(accessToken.AccessToken) as ServiceClientCredentials,
new RenewingTokenCredential(accessToken),
AzureSession.Instance.ClientFactory.GetCustomHandlers()))
{
var subscription = subscriptionClient.Subscriptions.Get(subscriptionId);
Expand Down
20 changes: 18 additions & 2 deletions src/Accounts/Accounts/Tenant/GetAzureRMTenant.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@
using Microsoft.Azure.Commands.Common.Authentication.Models;
using Microsoft.Azure.Commands.Profile.Models;
using Microsoft.Azure.Commands.ResourceManager.Common;
using Microsoft.WindowsAzure.Commands.Common;
using System.Collections.Concurrent;
using System.Linq;
using System.Management.Automation;
using System.Threading.Tasks;

namespace Microsoft.Azure.Commands.Profile
{
Expand All @@ -36,11 +37,26 @@ public class GetAzureRMTenantCommand : AzureRMCmdlet
[ValidateNotNullOrEmpty]
public string TenantId { get; set; }


public override void ExecuteCmdlet()
{
var profileClient = new RMProfileClient(AzureRmProfileProvider.Instance.GetProfile<AzureRmProfile>());
profileClient.WarningLog = (message) => _tasks.Enqueue(new Task(() => this.WriteWarning(message)));

var tenants = profileClient.ListTenants(TenantId).Select((t) => new PSAzureTenant(t));
HandleActions();
WriteObject(tenants, enumerateCollection: true);
}

WriteObject(profileClient.ListTenants(TenantId).Select((t) => new PSAzureTenant(t)), enumerateCollection: true);
private ConcurrentQueue<Task> _tasks = new ConcurrentQueue<Task>();

private void HandleActions()
{
Task task;
while (_tasks.TryDequeue(out task))
{
task.RunSynchronously();
}
}
}
}
1 change: 0 additions & 1 deletion src/Accounts/Accounts/Token/GetAzureRmAccessToken.cs
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,6 @@ public override void ExecuteCmdlet()
{
var tokenParts = accessToken.AccessToken.Split('.');
var decodedToken = Base64UrlHelper.DecodeToString(tokenParts[1]);

var tokenDocument = JsonDocument.Parse(decodedToken);
int expSeconds = tokenDocument.RootElement.EnumerateObject()
.Where(p => p.Name == "exp")
Expand Down
98 changes: 98 additions & 0 deletions src/Accounts/Accounts/Utilities/HttpRequestMessageExtension.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
// ----------------------------------------------------------------------------------
//
// Copyright Microsoft Corporation
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// ----------------------------------------------------------------------------------

using System.Collections.Generic;
using System.Net.Http;
using System.Threading.Tasks;

namespace Microsoft.Azure.Commands.Common.Utilities
{
internal static class HttpRequestMessageExtension
{
internal static HttpRequestMessage CloneAndDispose(this HttpRequestMessage original, System.Uri requestUri = null, System.Net.Http.HttpMethod method = null)
{
using (original)
{
return original.Clone(requestUri, method);
}
}

internal static Task<HttpRequestMessage> CloneWithContentAndDispose(this HttpRequestMessage original, System.Uri requestUri = null, System.Net.Http.HttpMethod method = null)
{
using (original)
{
return original.CloneWithContent(requestUri, method);
}
}

/// <summary>
/// Clones an HttpRequestMessage (without the content)
/// </summary>
/// <param name="original">Original HttpRequestMessage (Will be diposed before returning)</param>
/// <returns>A clone of the HttpRequestMessage</returns>
internal static HttpRequestMessage Clone(this HttpRequestMessage original, System.Uri requestUri = null, System.Net.Http.HttpMethod method = null)
{
var clone = new HttpRequestMessage
{
Method = method ?? original.Method,
RequestUri = requestUri ?? original.RequestUri,
Version = original.Version,
};

foreach (KeyValuePair<string, object> prop in original.Properties)
{
clone.Properties.Add(prop);
}

foreach (KeyValuePair<string, IEnumerable<string>> header in original.Headers)
{
/*
**temporarily skip cloning telemetry related headers**
clone.Headers.TryAddWithoutValidation(header.Key, header.Value);
*/
if (!"x-ms-unique-id".Equals(header.Key) && !"x-ms-client-request-id".Equals(header.Key) && !"CommandName".Equals(header.Key) && !"FullCommandName".Equals(header.Key) && !"ParameterSetName".Equals(header.Key) && !"User-Agent".Equals(header.Key))
{
clone.Headers.TryAddWithoutValidation(header.Key, header.Value);
}
}

return clone;
}

/// <summary>
/// Clones an HttpRequestMessage (including the content stream and content headers)
/// </summary>
/// <param name="original">Original HttpRequestMessage (Will be diposed before returning)</param>
/// <returns>A clone of the HttpRequestMessage</returns>
internal static async Task<HttpRequestMessage> CloneWithContent(this HttpRequestMessage original, System.Uri requestUri = null, System.Net.Http.HttpMethod method = null)
{
var clone = original.Clone(requestUri, method);
var stream = new System.IO.MemoryStream();
if (original.Content != null)
{
await original.Content.CopyToAsync(stream).ConfigureAwait(false);
stream.Position = 0;
clone.Content = new StreamContent(stream);
if (original.Content.Headers != null)
{
foreach (var h in original.Content.Headers)
{
clone.Content.Headers.Add(h.Key, h.Value);
}
}
}
return clone;
}
}
}
Loading