Skip to content

Commit 6518976

Browse files
msJinLeierich-wangwyunchi-ms
authored
Support CAE (#16852)
* CAE (#14567) * prototype for CAE * CAE issue fix * Improved error message when login is blocked by AAD * Improved error message when silent reauthentication failed * Enable CAE for Get-AzTenant and Get-AzSubcription * Add test case Co-authored-by: Erich(Renyong) Wang <[email protected]> * Migrate on-claim-chanllenge handler to the new authentication step for autorest *Enable CAE for MSGraph (#16766) * Update ChangeLog.md * Address review comments * Fix duplicate request issue * Address review comments Co-authored-by: Erich(Renyong) Wang <[email protected]> Co-authored-by: Yunchi Wang <[email protected]>
1 parent a22b1b6 commit 6518976

File tree

21 files changed

+884
-30
lines changed

21 files changed

+884
-30
lines changed

src/Accounts/Accounts.Test/SilentReAuthByTenantCmdletTest.cs

Lines changed: 336 additions & 0 deletions
Large diffs are not rendered by default.

src/Accounts/Accounts/Account/ConnectAzureRmAccount.cs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
using System.Management.Automation;
1919
using System.Runtime.InteropServices;
2020
using System.Security;
21+
using System.Text;
2122
using System.Threading;
2223
using System.Threading.Tasks;
2324

@@ -511,11 +512,17 @@ public override void ExecuteCmdlet()
511512
}
512513
catch (AuthenticationFailedException ex)
513514
{
515+
string message = string.Empty;
514516
if (IsUnableToOpenWebPageError(ex))
515517
{
516518
WriteWarning(Resources.InteractiveAuthNotSupported);
517519
WriteDebug(ex.ToString());
518520
}
521+
else if (TryParseUnknownAuthenticationException(ex, out message))
522+
{
523+
WriteDebug(ex.ToString());
524+
throw ex.WithAdditionalMessage(message);
525+
}
519526
else
520527
{
521528
if (IsUsingInteractiveAuthentication())
@@ -554,6 +561,21 @@ private bool IsUnableToOpenWebPageError(AuthenticationFailedException exception)
554561
|| (exception.Message?.ToLower()?.Contains("unable to open a web page") ?? false);
555562
}
556563

564+
private bool TryParseUnknownAuthenticationException(AuthenticationFailedException exception, out string message)
565+
{
566+
567+
var innerException = exception?.InnerException as MsalServiceException;
568+
bool isUnknownMsalServiceException = string.Equals(innerException?.ErrorCode, "access_denied", StringComparison.OrdinalIgnoreCase);
569+
message = null;
570+
if(isUnknownMsalServiceException)
571+
{
572+
StringBuilder messageBuilder = new StringBuilder(nameof(innerException.ErrorCode));
573+
messageBuilder.Append(": ").Append(innerException.ErrorCode);
574+
message = messageBuilder.ToString();
575+
}
576+
return isUnknownMsalServiceException;
577+
}
578+
557579
private ConcurrentQueue<Task> _tasks = new ConcurrentQueue<Task>();
558580

559581
private void HandleActions()

src/Accounts/Accounts/ChangeLog.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@
1919
-->
2020

2121
## Upcoming Release
22+
* Enabled Continue Access Evaluation for MSGraph
23+
* Improved error message when login is blocked by AAD
24+
* Improved error message when silent reauthentication failed
2225

2326
## Version 2.7.2
2427
* Removed legacy assembly System.Private.ServiceModel and System.ServiceModel.Primitives [#16063]

src/Accounts/Accounts/CommonModule/ContextAdapter.cs

Lines changed: 38 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,15 @@
1919
using System.Collections.Generic;
2020
using Microsoft.Azure.Commands.Common.Authentication.Abstractions;
2121
using Microsoft.Azure.Commands.Common.Authentication.Abstractions.Core;
22+
using Microsoft.Azure.Commands.Common.Utilities;
2223
using Microsoft.Azure.Commands.Profile.Models;
2324
using System.Globalization;
2425
using Microsoft.Azure.Commands.Common.Authentication;
2526
using Microsoft.Azure.Commands.ResourceManager.Common.ArgumentCompleters;
2627
using System.Linq;
2728
using System.Management.Automation;
2829
using Microsoft.Azure.Commands.Profile.Properties;
30+
using Azure.Identity;
2931

3032
namespace Microsoft.Azure.Commands.Common
3133
{
@@ -115,8 +117,7 @@ internal void AddAuthorizeRequestHandler(
115117
{
116118
endpointResourceIdKey = endpointResourceIdKey ?? AzureEnvironment.Endpoint.ResourceManager;
117119
var context = GetDefaultContext(_provider, invocationInfo);
118-
await AuthorizeRequest(context, request, cancelToken, endpointResourceIdKey, endpointSuffixKey, tokenAudienceConverter);
119-
return await next(request, cancelToken, cancelAction, signal);
120+
return await AuthenticationHelper(context, endpointResourceIdKey, endpointSuffixKey, request, cancelToken, cancelAction, signal, next);
120121
});
121122
}
122123

@@ -191,6 +192,35 @@ public object GetParameterValue(string resourceId, string moduleName, Invocation
191192
return string.Empty;
192193
}
193194

195+
internal async Task<HttpResponseMessage> AuthenticationHelper(IAzureContext context, string endpointResourceIdKey, string endpointSuffixKey, HttpRequestMessage request, CancellationToken cancelToken, Action cancelAction, SignalDelegate signal, NextDelegate next, TokenAudienceConverterDelegate tokenAudienceConverter = null)
196+
{
197+
IAccessToken accessToken = await AuthorizeRequest(context, request, cancelToken, endpointResourceIdKey, endpointSuffixKey, tokenAudienceConverter);
198+
var newRequest = await request.CloneWithContentAndDispose(request.RequestUri, request.Method);
199+
var response = await next(request, cancelToken, cancelAction, signal);
200+
201+
if (response.MatchClaimsChallengePattern())
202+
{
203+
//get token again with claims challenge
204+
if (accessToken is IClaimsChallengeProcessor processor)
205+
{
206+
try
207+
{
208+
var claimsChallenge = ClaimsChallengeUtilities.GetClaimsChallenge(response);
209+
if (!string.IsNullOrEmpty(claimsChallenge))
210+
{
211+
await processor.OnClaimsChallenageAsync(newRequest, claimsChallenge, cancelToken).ConfigureAwait(false);
212+
response = await next(newRequest, cancelToken, cancelAction, signal);
213+
}
214+
}
215+
catch (AuthenticationFailedException e)
216+
{
217+
throw e.WithAdditionalMessage(response?.GetWwwAuthenticateMessage());
218+
}
219+
}
220+
}
221+
return response;
222+
}
223+
194224
/// <summary>
195225
///
196226
/// </summary>
@@ -202,8 +232,7 @@ internal Func<HttpRequestMessage, CancellationToken, Action, SignalDelegate, Nex
202232
return async (request, cancelToken, cancelAction, signal, next) =>
203233
{
204234
PatchRequestUri(context, request);
205-
await AuthorizeRequest(context, request, cancelToken, resourceId, resourceId);
206-
return await next(request, cancelToken, cancelAction, signal);
235+
return await AuthenticationHelper(context, resourceId, resourceId, request, cancelToken, cancelAction, signal, next);
207236
};
208237
}
209238

@@ -213,17 +242,17 @@ internal Func<HttpRequestMessage, CancellationToken, Action, SignalDelegate, Nex
213242
/// <param name="context"></param>
214243
/// <param name="endpointResourceIdKey"></param>
215244
/// <param name="request"></param>
216-
/// <param name="outerToken"></param>
245+
/// <param name="cancellationToken"></param>
217246
/// <returns></returns>
218-
internal async Task AuthorizeRequest(IAzureContext context, HttpRequestMessage request, CancellationToken outerToken, string endpointResourceIdKey,
247+
internal async Task<IAccessToken> AuthorizeRequest(IAzureContext context, HttpRequestMessage request, CancellationToken cancellationToken, string endpointResourceIdKey,
219248
string endpointSuffixKey, TokenAudienceConverterDelegate tokenAudienceConverter = null, IDictionary<string, object> extensibleParamters = null)
220249
{
221250
if (context == null || context.Account == null || context.Environment == null)
222251
{
223252
throw new InvalidOperationException(Resources.InvalidAzureContext);
224253
}
225254

226-
await Task.Run(() =>
255+
return await Task.Run(() =>
227256
{
228257
if (tokenAudienceConverter != null)
229258
{
@@ -233,7 +262,8 @@ await Task.Run(() =>
233262
}
234263
var authToken = _authenticator.Authenticate(context.Account, context.Environment, context.Tenant.Id, null, "Never", null, endpointResourceIdKey);
235264
authToken.AuthorizeRequest((type, token) => request.Headers.Authorization = new System.Net.Http.Headers.AuthenticationHeaderValue(type, token));
236-
}, outerToken);
265+
return authToken;
266+
}, cancellationToken);
237267
}
238268

239269
private (string CurEnvEndpointResourceId, string CurEnvEndpointSuffix, string BaseEnvEndpointResourceId, string BaseEnvEndpointSuffix) GetEndpointInfo(IAzureEnvironment environment, string endpointResourceIdKey, string endpointSuffixKey)

src/Accounts/Accounts/Models/Version2016_06_01/SubscriptionClientWrapper.cs

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
using Microsoft.Azure.Internal.Subscriptions;
1919
using Microsoft.Azure.Internal.Subscriptions.Models;
2020
using Microsoft.Azure.Internal.Subscriptions.Models.Utilities;
21-
using Microsoft.Rest;
2221
using Microsoft.WindowsAzure.Commands.Utilities.Common;
2322
using System.Collections.Generic;
2423
using System.Linq;
@@ -41,7 +40,7 @@ public IList<AzureTenant> ListAccountTenants(IAccessToken accessToken, IAzureEnv
4140
{
4241
subscriptionClient = AzureSession.Instance.ClientFactory.CreateCustomArmClient<SubscriptionClient>(
4342
environment.GetEndpointAsUri(AzureEnvironment.Endpoint.ResourceManager),
44-
new TokenCredentials(accessToken.AccessToken) as ServiceClientCredentials,
43+
new RenewingTokenCredential(accessToken),
4544
AzureSession.Instance.ClientFactory.GetCustomHandlers());
4645

4746
var tenants = new GenericPageEnumerable<TenantIdDescription>(subscriptionClient.Tenants.List, subscriptionClient.Tenants.ListNext, ulong.MaxValue, 0).ToList();
@@ -71,7 +70,7 @@ public IList<AzureSubscription> ListAllSubscriptionsForTenant(IAccessToken acces
7170
{
7271
using (var subscriptionClient = AzureSession.Instance.ClientFactory.CreateCustomArmClient<SubscriptionClient>(
7372
environment.GetEndpointAsUri(AzureEnvironment.Endpoint.ResourceManager),
74-
new TokenCredentials(accessToken.AccessToken) as ServiceClientCredentials,
73+
new RenewingTokenCredential(accessToken),
7574
AzureSession.Instance.ClientFactory.GetCustomHandlers()))
7675
{
7776
return (subscriptionClient.ListAllSubscriptions()?
@@ -83,7 +82,7 @@ public AzureSubscription GetSubscriptionById(string subscriptionId, IAccessToken
8382
{
8483
using (var subscriptionClient = AzureSession.Instance.ClientFactory.CreateCustomArmClient<SubscriptionClient>(
8584
environment.GetEndpointAsUri(AzureEnvironment.Endpoint.ResourceManager),
86-
new TokenCredentials(accessToken.AccessToken) as ServiceClientCredentials,
85+
new RenewingTokenCredential(accessToken),
8786
AzureSession.Instance.ClientFactory.GetCustomHandlers()))
8887
{
8988
var subscription = subscriptionClient.Subscriptions.Get(subscriptionId);

src/Accounts/Accounts/Models/Version2021_01_01/SubscriptionClientWrapper.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ public IList<AzureTenant> ListAccountTenants(IAccessToken accessToken, IAzureEnv
4242
{
4343
subscriptionClient = AzureSession.Instance.ClientFactory.CreateCustomArmClient<SubscriptionClient>(
4444
environment.GetEndpointAsUri(AzureEnvironment.Endpoint.ResourceManager),
45-
new TokenCredentials(accessToken.AccessToken) as ServiceClientCredentials,
45+
new RenewingTokenCredential(accessToken),
4646
AzureSession.Instance.ClientFactory.GetCustomHandlers());
4747

4848
var tenants = new GenericPageEnumerable<TenantIdDescription>(subscriptionClient.Tenants.List, subscriptionClient.Tenants.ListNext, ulong.MaxValue, 0).ToList();
@@ -72,7 +72,7 @@ public IList<AzureSubscription> ListAllSubscriptionsForTenant(IAccessToken acces
7272
{
7373
using (var subscriptionClient = AzureSession.Instance.ClientFactory.CreateCustomArmClient<SubscriptionClient>(
7474
environment.GetEndpointAsUri(AzureEnvironment.Endpoint.ResourceManager),
75-
new TokenCredentials(accessToken.AccessToken) as ServiceClientCredentials,
75+
new RenewingTokenCredential(accessToken),
7676
AzureSession.Instance.ClientFactory.GetCustomHandlers()))
7777
{
7878
return subscriptionClient.ListAllSubscriptions()?
@@ -84,7 +84,7 @@ public AzureSubscription GetSubscriptionById(string subscriptionId, IAccessToken
8484
{
8585
using (var subscriptionClient = AzureSession.Instance.ClientFactory.CreateCustomArmClient<SubscriptionClient>(
8686
environment.GetEndpointAsUri(AzureEnvironment.Endpoint.ResourceManager),
87-
new TokenCredentials(accessToken.AccessToken) as ServiceClientCredentials,
87+
new RenewingTokenCredential(accessToken),
8888
AzureSession.Instance.ClientFactory.GetCustomHandlers()))
8989
{
9090
var subscription = subscriptionClient.Subscriptions.Get(subscriptionId);

src/Accounts/Accounts/Tenant/GetAzureRMTenant.cs

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,10 @@
1717
using Microsoft.Azure.Commands.Common.Authentication.Models;
1818
using Microsoft.Azure.Commands.Profile.Models;
1919
using Microsoft.Azure.Commands.ResourceManager.Common;
20-
using Microsoft.WindowsAzure.Commands.Common;
20+
using System.Collections.Concurrent;
2121
using System.Linq;
2222
using System.Management.Automation;
23+
using System.Threading.Tasks;
2324

2425
namespace Microsoft.Azure.Commands.Profile
2526
{
@@ -36,11 +37,26 @@ public class GetAzureRMTenantCommand : AzureRMCmdlet
3637
[ValidateNotNullOrEmpty]
3738
public string TenantId { get; set; }
3839

40+
3941
public override void ExecuteCmdlet()
4042
{
4143
var profileClient = new RMProfileClient(AzureRmProfileProvider.Instance.GetProfile<AzureRmProfile>());
44+
profileClient.WarningLog = (message) => _tasks.Enqueue(new Task(() => this.WriteWarning(message)));
45+
46+
var tenants = profileClient.ListTenants(TenantId).Select((t) => new PSAzureTenant(t));
47+
HandleActions();
48+
WriteObject(tenants, enumerateCollection: true);
49+
}
4250

43-
WriteObject(profileClient.ListTenants(TenantId).Select((t) => new PSAzureTenant(t)), enumerateCollection: true);
51+
private ConcurrentQueue<Task> _tasks = new ConcurrentQueue<Task>();
52+
53+
private void HandleActions()
54+
{
55+
Task task;
56+
while (_tasks.TryDequeue(out task))
57+
{
58+
task.RunSynchronously();
59+
}
4460
}
4561
}
4662
}

src/Accounts/Accounts/Token/GetAzureRmAccessToken.cs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,6 @@ public override void ExecuteCmdlet()
120120
{
121121
var tokenParts = accessToken.AccessToken.Split('.');
122122
var decodedToken = Base64UrlHelper.DecodeToString(tokenParts[1]);
123-
124123
var tokenDocument = JsonDocument.Parse(decodedToken);
125124
int expSeconds = tokenDocument.RootElement.EnumerateObject()
126125
.Where(p => p.Name == "exp")
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
// ----------------------------------------------------------------------------------
2+
//
3+
// Copyright Microsoft Corporation
4+
// Licensed under the Apache License, Version 2.0 (the "License");
5+
// you may not use this file except in compliance with the License.
6+
// You may obtain a copy of the License at
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
// Unless required by applicable law or agreed to in writing, software
9+
// distributed under the License is distributed on an "AS IS" BASIS,
10+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
// See the License for the specific language governing permissions and
12+
// limitations under the License.
13+
// ----------------------------------------------------------------------------------
14+
15+
using System.Collections.Generic;
16+
using System.Net.Http;
17+
using System.Threading.Tasks;
18+
19+
namespace Microsoft.Azure.Commands.Common.Utilities
20+
{
21+
internal static class HttpRequestMessageExtension
22+
{
23+
internal static HttpRequestMessage CloneAndDispose(this HttpRequestMessage original, System.Uri requestUri = null, System.Net.Http.HttpMethod method = null)
24+
{
25+
using (original)
26+
{
27+
return original.Clone(requestUri, method);
28+
}
29+
}
30+
31+
internal static Task<HttpRequestMessage> CloneWithContentAndDispose(this HttpRequestMessage original, System.Uri requestUri = null, System.Net.Http.HttpMethod method = null)
32+
{
33+
using (original)
34+
{
35+
return original.CloneWithContent(requestUri, method);
36+
}
37+
}
38+
39+
/// <summary>
40+
/// Clones an HttpRequestMessage (without the content)
41+
/// </summary>
42+
/// <param name="original">Original HttpRequestMessage (Will be diposed before returning)</param>
43+
/// <returns>A clone of the HttpRequestMessage</returns>
44+
internal static HttpRequestMessage Clone(this HttpRequestMessage original, System.Uri requestUri = null, System.Net.Http.HttpMethod method = null)
45+
{
46+
var clone = new HttpRequestMessage
47+
{
48+
Method = method ?? original.Method,
49+
RequestUri = requestUri ?? original.RequestUri,
50+
Version = original.Version,
51+
};
52+
53+
foreach (KeyValuePair<string, object> prop in original.Properties)
54+
{
55+
clone.Properties.Add(prop);
56+
}
57+
58+
foreach (KeyValuePair<string, IEnumerable<string>> header in original.Headers)
59+
{
60+
/*
61+
**temporarily skip cloning telemetry related headers**
62+
clone.Headers.TryAddWithoutValidation(header.Key, header.Value);
63+
*/
64+
if (!"x-ms-unique-id".Equals(header.Key) && !"x-ms-client-request-id".Equals(header.Key) && !"CommandName".Equals(header.Key) && !"FullCommandName".Equals(header.Key) && !"ParameterSetName".Equals(header.Key) && !"User-Agent".Equals(header.Key))
65+
{
66+
clone.Headers.TryAddWithoutValidation(header.Key, header.Value);
67+
}
68+
}
69+
70+
return clone;
71+
}
72+
73+
/// <summary>
74+
/// Clones an HttpRequestMessage (including the content stream and content headers)
75+
/// </summary>
76+
/// <param name="original">Original HttpRequestMessage (Will be diposed before returning)</param>
77+
/// <returns>A clone of the HttpRequestMessage</returns>
78+
internal static async Task<HttpRequestMessage> CloneWithContent(this HttpRequestMessage original, System.Uri requestUri = null, System.Net.Http.HttpMethod method = null)
79+
{
80+
var clone = original.Clone(requestUri, method);
81+
var stream = new System.IO.MemoryStream();
82+
if (original.Content != null)
83+
{
84+
await original.Content.CopyToAsync(stream).ConfigureAwait(false);
85+
stream.Position = 0;
86+
clone.Content = new StreamContent(stream);
87+
if (original.Content.Headers != null)
88+
{
89+
foreach (var h in original.Content.Headers)
90+
{
91+
clone.Content.Headers.Add(h.Key, h.Value);
92+
}
93+
}
94+
}
95+
return clone;
96+
}
97+
}
98+
}

0 commit comments

Comments
 (0)