|
13 | 13 | // ----------------------------------------------------------------------------------
|
14 | 14 |
|
15 | 15 | using Microsoft.Azure.Commands.Common.Authentication.Abstractions;
|
16 |
| -using Microsoft.Azure.Commands.Common.Authentication.Properties; |
17 |
| -using Microsoft.Rest.Azure; |
18 | 16 | using System;
|
19 |
| -using System.Collections.Generic; |
20 |
| -using System.Net.Http; |
21 |
| -using System.Text; |
22 |
| -using System.Threading; |
23 | 17 |
|
24 | 18 | namespace Microsoft.Azure.Commands.Common.Authentication
|
25 | 19 | {
|
26 |
| - public class ManagedServiceAccessToken : IRenewableToken |
| 20 | + public class ManagedServiceAccessToken : ManagedServiceAccessTokenBase<ManagedServiceTokenInfo> |
27 | 21 | {
|
28 |
| - IAzureAccount _account; |
29 |
| - string _tenant; |
30 |
| - string _resourceId; |
31 |
| - IHttpOperations<ManagedServiceTokenInfo> _tokenGetter; |
32 |
| - DateTime _expiration = DateTime.UtcNow; |
33 |
| - string _accessToken; |
34 |
| - |
35 | 22 | public ManagedServiceAccessToken(IAzureAccount account, IAzureEnvironment environment, string resourceId, string tenant = "Common")
|
| 23 | + : base(account, environment, resourceId, tenant) |
36 | 24 | {
|
37 |
| - if (account == null || string.IsNullOrEmpty(account.Id) || !account.IsPropertySet(AzureAccount.Property.MSILoginUri)) |
38 |
| - { |
39 |
| - throw new ArgumentNullException(nameof(account)); |
40 |
| - } |
41 |
| - |
42 |
| - if (string.IsNullOrWhiteSpace(tenant)) |
43 |
| - { |
44 |
| - throw new ArgumentNullException(nameof(tenant)); |
45 |
| - } |
46 |
| - |
47 |
| - if (environment == null) |
48 |
| - { |
49 |
| - throw new ArgumentNullException(nameof(environment)); |
50 |
| - } |
51 |
| - |
52 |
| - _account = account; |
53 |
| - _resourceId = GetResource(resourceId, environment); |
54 |
| - var idType = GetIdentityType(account); |
55 |
| - foreach (var uri in BuildTokenUri(_account.GetProperty(AzureAccount.Property.MSILoginUri), account, idType, _resourceId)) |
56 |
| - { |
57 |
| - RequestUris.Enqueue(uri); |
58 |
| - } |
59 |
| - |
60 |
| - if (account.IsPropertySet(AzureAccount.Property.MSILoginUriBackup)) |
61 |
| - { |
62 |
| - foreach (var uri in BuildTokenUri(_account.GetProperty(AzureAccount.Property.MSILoginUriBackup), account, idType, _resourceId)) |
63 |
| - { |
64 |
| - RequestUris.Enqueue(uri); |
65 |
| - } |
66 |
| - } |
67 |
| - |
68 |
| - _tenant = tenant; |
69 |
| - IHttpOperationsFactory factory; |
70 |
| - if (!AzureSession.Instance.TryGetComponent(HttpClientOperationsFactory.Name, out factory)) |
71 |
| - { |
72 |
| - factory = HttpClientOperationsFactory.Create(); |
73 |
| - } |
74 |
| - |
75 |
| - _tokenGetter = factory.GetHttpOperations<ManagedServiceTokenInfo>(true).WithHeader("Metadata", new[] { "true" }); |
76 |
| - if (account.IsPropertySet(AzureAccount.Property.MSILoginSecret)) |
77 |
| - { |
78 |
| - _tokenGetter = _tokenGetter.WithHeader("Secret", new[] { account.GetProperty(AzureAccount.Property.MSILoginSecret) }); |
79 |
| - } |
80 |
| - } |
81 |
| - |
82 |
| - public string AccessToken |
83 |
| - { |
84 |
| - get |
85 |
| - { |
86 |
| - try |
87 |
| - { |
88 |
| - GetOrRenewAuthentication(); |
89 |
| - } |
90 |
| - catch (CloudException httpException) |
91 |
| - { |
92 |
| - throw new InvalidOperationException(string.Format(Resources.MSITokenRequestFailed, _resourceId, httpException?.Request?.RequestUri?.ToString()), httpException); |
93 |
| - } |
94 |
| - |
95 |
| - return _accessToken; |
96 |
| - } |
97 |
| - } |
98 |
| - |
99 |
| - public Queue<string> RequestUris { get; } = new Queue<string>(); |
100 |
| - |
101 |
| - public string LoginType => "ManagedService"; |
102 |
| - |
103 |
| - public string TenantId => _tenant; |
104 |
| - |
105 |
| - public string UserId => _account.Id; |
106 |
| - |
107 |
| - public DateTimeOffset ExpiresOn |
108 |
| - { |
109 |
| - get |
110 |
| - { |
111 |
| - return _expiration; |
112 |
| - } |
113 |
| - } |
114 |
| - |
115 |
| - public void AuthorizeRequest(Action<string, string> authTokenSetter) |
116 |
| - { |
117 |
| - authTokenSetter("Bearer", AccessToken); |
118 |
| - } |
119 |
| - |
120 |
| - void GetOrRenewAuthentication() |
121 |
| - { |
122 |
| - if (_expiration - DateTime.UtcNow < ManagedServiceTokenInfo.TimeoutThreshold) |
123 |
| - { |
124 |
| - ManagedServiceTokenInfo info = null; |
125 |
| - while (info == null && RequestUris.Count > 0) |
126 |
| - { |
127 |
| - var currentRequestUri = RequestUris.Dequeue(); |
128 |
| - try |
129 |
| - { |
130 |
| - info = _tokenGetter.GetAsync(currentRequestUri, CancellationToken.None).ConfigureAwait(false).GetAwaiter().GetResult(); |
131 |
| - // if a request was succesful, we should not check any other Uris |
132 |
| - RequestUris.Clear(); |
133 |
| - RequestUris.Enqueue(currentRequestUri); |
134 |
| - } |
135 |
| - catch (Exception e) when ( (e is CloudException || e is HttpRequestException) && RequestUris.Count > 0) |
136 |
| - { |
137 |
| - // skip to the next uri |
138 |
| - } |
139 |
| - } |
140 |
| - |
141 |
| - SetToken(info); |
142 |
| - } |
143 | 25 | }
|
144 | 26 |
|
145 |
| - void SetToken(ManagedServiceTokenInfo info) |
| 27 | + protected override void SetToken(ManagedServiceTokenInfo info) |
146 | 28 | {
|
147 | 29 | if (info != null)
|
148 | 30 | {
|
149 |
| - _expiration = DateTime.UtcNow + TimeSpan.FromSeconds(info.ExpiresIn); |
150 |
| - _accessToken = info.AccessToken; |
151 |
| - } |
152 |
| - } |
153 |
| - |
154 |
| - static IdentityType GetIdentityType(IAzureAccount account) |
155 |
| - { |
156 |
| - if (account == null || string.IsNullOrWhiteSpace(account.Id) || account.Id.Contains("@")) |
157 |
| - { |
158 |
| - return IdentityType.SystemAssigned; |
159 |
| - } |
160 |
| - |
161 |
| - if (account.Id.Contains("/")) |
162 |
| - { |
163 |
| - return IdentityType.Resource; |
164 |
| - } |
165 |
| - |
166 |
| - return IdentityType.ClientId; |
167 |
| - } |
168 |
| - |
169 |
| - static string GetResource(string endpointOrResource, IAzureEnvironment environment) |
170 |
| - { |
171 |
| - return environment.GetEndpoint(endpointOrResource) ?? endpointOrResource; |
172 |
| - } |
173 |
| - |
174 |
| - static IEnumerable<string> BuildTokenUri(string baseUri, IAzureAccount account, IdentityType identityType, string resourceId) |
175 |
| - { |
176 |
| - UriBuilder builder = new UriBuilder(baseUri); |
177 |
| - builder.Query = BuildTokenQuery(account, identityType, resourceId); |
178 |
| - yield return builder.Uri.ToString(); |
179 |
| - |
180 |
| - if (identityType == IdentityType.ClientId) |
181 |
| - { |
182 |
| - builder = new UriBuilder(baseUri); |
183 |
| - builder.Query = BuildTokenQuery(account, IdentityType.ObjectId, resourceId); |
184 |
| - yield return builder.Uri.ToString(); |
| 31 | + Expiration = DateTimeOffset.Now + TimeSpan.FromSeconds(info.ExpiresIn); |
| 32 | + accessToken = info.AccessToken; |
185 | 33 | }
|
186 | 34 | }
|
187 |
| - |
188 |
| - static string BuildTokenQuery(IAzureAccount account, IdentityType idType, string resource) |
189 |
| - { |
190 |
| - StringBuilder query = new StringBuilder($"resource={Uri.EscapeDataString(resource)}"); |
191 |
| - switch (idType) |
192 |
| - { |
193 |
| - case IdentityType.Resource: |
194 |
| - query.Append($"&msi_res_id={Uri.EscapeDataString(account.Id)}"); |
195 |
| - break; |
196 |
| - case IdentityType.ClientId: |
197 |
| - query.Append($"&client_id={Uri.EscapeDataString(account.Id)}"); |
198 |
| - break; |
199 |
| - case IdentityType.ObjectId: |
200 |
| - query.Append($"&object_id={Uri.EscapeDataString(account.Id)}"); |
201 |
| - break; |
202 |
| - } |
203 |
| - |
204 |
| - query.Append("&api-version=2018-02-01"); |
205 |
| - return query.ToString(); |
206 |
| - } |
207 |
| - |
208 |
| - enum IdentityType |
209 |
| - { |
210 |
| - Resource, |
211 |
| - ClientId, |
212 |
| - ObjectId, |
213 |
| - SystemAssigned |
214 |
| - } |
215 | 35 | }
|
216 | 36 | }
|
0 commit comments