Skip to content

Commit 7d9ce19

Browse files
authored
Merge pull request #858 from AzureAD/avdunn/mi-refreshon
Set refreshOn in MI flows to half of token lifetime
2 parents 6f431a3 + c46e5d2 commit 7d9ce19

File tree

4 files changed

+59
-42
lines changed

4 files changed

+59
-42
lines changed

msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/AcquireTokenByManagedIdentitySupplier.java

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ class AcquireTokenByManagedIdentitySupplier extends AuthenticationResultSupplier
1313

1414
private static final Logger LOG = LoggerFactory.getLogger(AcquireTokenByManagedIdentitySupplier.class);
1515

16+
private static final int TWO_HOURS = 2*3600;
17+
1618
private ManagedIdentityParameters managedIdentityParameters;
1719

1820
AcquireTokenByManagedIdentitySupplier(ManagedIdentityApplication managedIdentityApplication, MsalRequest msalRequest) {
@@ -93,15 +95,27 @@ private AuthenticationResult fetchNewAccessTokenAndSaveToCache(TokenRequestExecu
9395
}
9496

9597
private AuthenticationResult createFromManagedIdentityResponse(ManagedIdentityResponse managedIdentityResponse) {
96-
long expiresOn = Long.valueOf(managedIdentityResponse.expiresOn);
97-
long refreshOn = expiresOn > 2 * 3600 ? (expiresOn / 2) : 0L;
98+
long expiresOn = Long.parseLong(managedIdentityResponse.expiresOn);
99+
long refreshOn = calculateRefreshOn(expiresOn);
100+
AuthenticationResultMetadata metadata = AuthenticationResultMetadata.builder()
101+
.refreshOn(refreshOn)
102+
.build();
98103

99104
return AuthenticationResult.builder()
100105
.accessToken(managedIdentityResponse.getAccessToken())
101106
.scopes(managedIdentityParameters.resource())
102107
.expiresOn(expiresOn)
103108
.extExpiresOn(0)
104109
.refreshOn(refreshOn)
110+
.metadata(metadata)
105111
.build();
106112
}
113+
114+
private long calculateRefreshOn(long expiresOn){
115+
long timestampSeconds = System.currentTimeMillis() / 1000;
116+
long expiresIn = expiresOn - timestampSeconds;
117+
118+
//The refreshOn value should be half the value of the token lifetime, if the lifetime is greater than two hours
119+
return expiresIn > TWO_HOURS ? (expiresIn / 2) + timestampSeconds : 0;
120+
}
107121
}

msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/ManagedIdentityApplication.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ public class ManagedIdentityApplication extends AbstractApplicationBase implemen
2525
static TokenCache sharedTokenCache = new TokenCache();
2626

2727
@Getter(value = AccessLevel.PUBLIC)
28-
static ManagedIdentitySourceType managedIdentitySource = ManagedIdentityClient.getManagedIdentitySource();
28+
ManagedIdentitySourceType managedIdentitySource = ManagedIdentityClient.getManagedIdentitySource();
2929

3030
@Getter(value = AccessLevel.PACKAGE)
3131
static IEnvironmentVariables environmentVariables;

msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/ManagedIdentityClient.java

Lines changed: 6 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -12,36 +12,24 @@
1212
class ManagedIdentityClient {
1313
private static final Logger LOG = LoggerFactory.getLogger(ManagedIdentityClient.class);
1414

15-
private static ManagedIdentitySourceType managedIdentitySourceType;
16-
17-
protected static void resetManagedIdentitySourceType() {
18-
managedIdentitySourceType = ManagedIdentitySourceType.NONE;
19-
}
20-
2115
static ManagedIdentitySourceType getManagedIdentitySource() {
22-
if (managedIdentitySourceType != null && managedIdentitySourceType != ManagedIdentitySourceType.NONE) {
23-
return managedIdentitySourceType;
24-
}
25-
2616
IEnvironmentVariables environmentVariables = AbstractManagedIdentitySource.getEnvironmentVariables();
2717

2818
if (!StringHelper.isNullOrBlank(environmentVariables.getEnvironmentVariable(Constants.IDENTITY_ENDPOINT)) &&
2919
!StringHelper.isNullOrBlank(environmentVariables.getEnvironmentVariable(Constants.IDENTITY_HEADER))) {
3020
if (!StringHelper.isNullOrBlank(environmentVariables.getEnvironmentVariable(Constants.IDENTITY_SERVER_THUMBPRINT))) {
31-
managedIdentitySourceType = ManagedIdentitySourceType.SERVICE_FABRIC;
21+
return ManagedIdentitySourceType.SERVICE_FABRIC;
3222
} else {
33-
managedIdentitySourceType = ManagedIdentitySourceType.APP_SERVICE;
23+
return ManagedIdentitySourceType.APP_SERVICE;
3424
}
3525
} else if (!StringHelper.isNullOrBlank(environmentVariables.getEnvironmentVariable(Constants.MSI_ENDPOINT))) {
36-
managedIdentitySourceType = ManagedIdentitySourceType.CLOUD_SHELL;
26+
return ManagedIdentitySourceType.CLOUD_SHELL;
3727
} else if (!StringHelper.isNullOrBlank(environmentVariables.getEnvironmentVariable(Constants.IDENTITY_ENDPOINT)) &&
3828
!StringHelper.isNullOrBlank(environmentVariables.getEnvironmentVariable(Constants.IMDS_ENDPOINT))) {
39-
managedIdentitySourceType = ManagedIdentitySourceType.AZURE_ARC;
29+
return ManagedIdentitySourceType.AZURE_ARC;
4030
} else {
41-
managedIdentitySourceType = ManagedIdentitySourceType.DEFAULT_TO_IMDS;
31+
return ManagedIdentitySourceType.DEFAULT_TO_IMDS;
4232
}
43-
44-
return managedIdentitySourceType;
4533
}
4634

4735
AbstractManagedIdentitySource managedIdentitySource;
@@ -64,11 +52,7 @@ ManagedIdentityResponse getManagedIdentityResponse(ManagedIdentityParameters par
6452
private static AbstractManagedIdentitySource createManagedIdentitySource(MsalRequest msalRequest,
6553
ServiceBundle serviceBundle) {
6654

67-
if (managedIdentitySourceType == null || managedIdentitySourceType == ManagedIdentitySourceType.NONE) {
68-
managedIdentitySourceType = getManagedIdentitySource();
69-
}
70-
71-
switch (managedIdentitySourceType) {
55+
switch (getManagedIdentitySource()) {
7256
case SERVICE_FABRIC:
7357
return ServiceFabricManagedIdentitySource.create(msalRequest, serviceBundle);
7458
case APP_SERVICE:

msal4j-sdk/src/test/java/com/microsoft/aad/msal4j/ManagedIdentityTests.java

Lines changed: 36 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ class ManagedIdentityTests {
4444
private static ManagedIdentityApplication miApp;
4545

4646
private String getSuccessfulResponse(String resource) {
47-
long expiresOn = Instant.now().plus(1, ChronoUnit.HOURS).getEpochSecond();
47+
long expiresOn = (System.currentTimeMillis() / 1000) + (24 * 3600);//A long-lived, 24 hour token
4848
return "{\"access_token\":\"accesstoken\",\"expires_on\":\"" + expiresOn + "\",\"resource\":\"" + resource + "\",\"token_type\":" +
4949
"\"Bearer\",\"client_id\":\"client_id\"}";
5050
}
@@ -155,18 +155,22 @@ private HttpResponse expectedResponse(int statusCode, String response) {
155155
void managedIdentity_GetManagedIdentitySource(ManagedIdentitySourceType source, String endpoint, ManagedIdentitySourceType expectedSource) {
156156
IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(source, endpoint);
157157
ManagedIdentityApplication.setEnvironmentVariables(environmentVariables);
158-
ManagedIdentityClient.resetManagedIdentitySourceType();
159158

160-
ManagedIdentitySourceType managedIdentitySourceType = ManagedIdentityClient.getManagedIdentitySource();
161-
assertEquals(expectedSource, managedIdentitySourceType);
159+
miApp = ManagedIdentityApplication
160+
.builder(ManagedIdentityId.systemAssigned())
161+
.build();
162+
163+
ManagedIdentitySourceType miClientSourceType = ManagedIdentityClient.getManagedIdentitySource();
164+
ManagedIdentitySourceType miAppSourceType = miApp.managedIdentitySource;
165+
assertEquals(expectedSource, miClientSourceType);
166+
assertEquals(expectedSource, miAppSourceType);
162167
}
163168

164169
@ParameterizedTest
165170
@MethodSource("com.microsoft.aad.msal4j.ManagedIdentityTestDataProvider#createData")
166171
void managedIdentityTest_SystemAssigned_SuccessfulResponse(ManagedIdentitySourceType source, String endpoint, String resource) throws Exception {
167172
IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(source, endpoint);
168173
ManagedIdentityApplication.setEnvironmentVariables(environmentVariables);
169-
ManagedIdentityClient.resetManagedIdentitySourceType();
170174
DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class);
171175

172176
when(httpClientMock.send(expectedRequest(source, resource))).thenReturn(expectedResponse(200, getSuccessfulResponse(resource)));
@@ -201,7 +205,6 @@ void managedIdentityTest_SystemAssigned_SuccessfulResponse(ManagedIdentitySource
201205
void managedIdentityTest_UserAssigned_SuccessfulResponse(ManagedIdentitySourceType source, String endpoint, ManagedIdentityId id) throws Exception {
202206
IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(source, endpoint);
203207
ManagedIdentityApplication.setEnvironmentVariables(environmentVariables);
204-
ManagedIdentityClient.resetManagedIdentitySourceType();
205208
DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class);
206209

207210
when(httpClientMock.send(expectedRequest(source, resource, id))).thenReturn(expectedResponse(200, getSuccessfulResponse(resource)));
@@ -222,12 +225,38 @@ void managedIdentityTest_UserAssigned_SuccessfulResponse(ManagedIdentitySourceTy
222225
verify(httpClientMock, times(1)).send(any());
223226
}
224227

228+
@Test
229+
void managedIdentityTest_RefreshOnHalfOfExpiresOn() throws Exception {
230+
//All managed identity flows use the same AcquireTokenByManagedIdentitySupplier where refreshOn is set,
231+
// so any of the MI options should let us verify that it's being set correctly
232+
IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(ManagedIdentitySourceType.APP_SERVICE, appServiceEndpoint);
233+
ManagedIdentityApplication.setEnvironmentVariables(environmentVariables);
234+
DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class);
235+
236+
when(httpClientMock.send(expectedRequest(ManagedIdentitySourceType.APP_SERVICE, resource))).thenReturn(expectedResponse(200, getSuccessfulResponse(resource)));
237+
238+
miApp = ManagedIdentityApplication
239+
.builder(ManagedIdentityId.systemAssigned())
240+
.httpClient(httpClientMock)
241+
.build();
242+
243+
AuthenticationResult result = (AuthenticationResult) miApp.acquireTokenForManagedIdentity(
244+
ManagedIdentityParameters.builder(resource)
245+
.build()).get();
246+
247+
long timestampSeconds = (System.currentTimeMillis() / 1000);
248+
249+
assertNotNull(result.accessToken());
250+
assertEquals((result.expiresOn() - timestampSeconds)/2, result.refreshOn() - timestampSeconds);
251+
252+
verify(httpClientMock, times(1)).send(any());
253+
}
254+
225255
@ParameterizedTest
226256
@MethodSource("com.microsoft.aad.msal4j.ManagedIdentityTestDataProvider#createDataUserAssignedNotSupported")
227257
void managedIdentityTest_UserAssigned_NotSupported(ManagedIdentitySourceType source, String endpoint, ManagedIdentityId id) throws Exception {
228258
IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(source, endpoint);
229259
ManagedIdentityApplication.setEnvironmentVariables(environmentVariables);
230-
ManagedIdentityClient.resetManagedIdentitySourceType();
231260
DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class);
232261

233262
miApp = ManagedIdentityApplication
@@ -264,7 +293,6 @@ void managedIdentityTest_DifferentScopes_RequestsNewToken(ManagedIdentitySourceT
264293

265294
IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(source, endpoint);
266295
ManagedIdentityApplication.setEnvironmentVariables(environmentVariables);
267-
ManagedIdentityClient.resetManagedIdentitySourceType();
268296
DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class);
269297

270298
when(httpClientMock.send(expectedRequest(source, resource))).thenReturn(expectedResponse(200, getSuccessfulResponse(resource)));
@@ -298,7 +326,6 @@ void managedIdentityTest_DifferentScopes_RequestsNewToken(ManagedIdentitySourceT
298326
void managedIdentityTest_WrongScopes(ManagedIdentitySourceType source, String endpoint, String resource) throws Exception {
299327
IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(source, endpoint);
300328
ManagedIdentityApplication.setEnvironmentVariables(environmentVariables);
301-
ManagedIdentityClient.resetManagedIdentitySourceType();
302329
DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class);
303330

304331
if (environmentVariables.getEnvironmentVariable("SourceType").equals(ManagedIdentitySourceType.CLOUD_SHELL.toString())) {
@@ -337,7 +364,6 @@ void managedIdentityTest_WrongScopes(ManagedIdentitySourceType source, String en
337364
void managedIdentityTest_Retry(ManagedIdentitySourceType source, String endpoint, String resource) throws Exception {
338365
IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(source, endpoint);
339366
ManagedIdentityApplication.setEnvironmentVariables(environmentVariables);
340-
ManagedIdentityClient.resetManagedIdentitySourceType();
341367
DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class);
342368

343369
miApp = ManagedIdentityApplication
@@ -388,7 +414,6 @@ void managedIdentityTest_Retry(ManagedIdentitySourceType source, String endpoint
388414
void managedIdentity_RequestFailed_NoPayload(ManagedIdentitySourceType source, String endpoint) throws Exception {
389415
IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(source, endpoint);
390416
ManagedIdentityApplication.setEnvironmentVariables(environmentVariables);
391-
ManagedIdentityClient.resetManagedIdentitySourceType();
392417
DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class);
393418

394419
when(httpClientMock.send(expectedRequest(source, resource))).thenReturn(expectedResponse(500, ""));
@@ -423,7 +448,6 @@ void managedIdentity_RequestFailed_NoPayload(ManagedIdentitySourceType source, S
423448
void managedIdentity_RequestFailed_NullResponse(ManagedIdentitySourceType source, String endpoint) throws Exception {
424449
IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(source, endpoint);
425450
ManagedIdentityApplication.setEnvironmentVariables(environmentVariables);
426-
ManagedIdentityClient.resetManagedIdentitySourceType();
427451
DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class);
428452

429453
when(httpClientMock.send(expectedRequest(source, resource))).thenReturn(expectedResponse(200, ""));
@@ -458,7 +482,6 @@ void managedIdentity_RequestFailed_NullResponse(ManagedIdentitySourceType source
458482
void managedIdentity_RequestFailed_UnreachableNetwork(ManagedIdentitySourceType source, String endpoint) throws Exception {
459483
IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(source, endpoint);
460484
ManagedIdentityApplication.setEnvironmentVariables(environmentVariables);
461-
ManagedIdentityClient.resetManagedIdentitySourceType();
462485
DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class);
463486

464487
when(httpClientMock.send(expectedRequest(source, resource))).thenThrow(new SocketException("A socket operation was attempted to an unreachable network."));
@@ -492,7 +515,6 @@ void managedIdentity_RequestFailed_UnreachableNetwork(ManagedIdentitySourceType
492515
void azureArcManagedIdentity_MissingAuthHeader() throws Exception {
493516
IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(ManagedIdentitySourceType.AZURE_ARC, azureArcEndpoint);
494517
ManagedIdentityApplication.setEnvironmentVariables(environmentVariables);
495-
ManagedIdentityClient.resetManagedIdentitySourceType();
496518
DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class);
497519

498520
HttpResponse response = new HttpResponse();
@@ -531,7 +553,6 @@ void azureArcManagedIdentity_MissingAuthHeader() throws Exception {
531553
void managedIdentity_SharedCache(ManagedIdentitySourceType source, String endpoint) throws Exception {
532554
IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(source, endpoint);
533555
ManagedIdentityApplication.setEnvironmentVariables(environmentVariables);
534-
ManagedIdentityClient.resetManagedIdentitySourceType();
535556
DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class);
536557

537558
when(httpClientMock.send(expectedRequest(source, resource))).thenReturn(expectedResponse(200, getSuccessfulResponse(resource)));
@@ -572,7 +593,6 @@ void managedIdentity_SharedCache(ManagedIdentitySourceType source, String endpoi
572593
void azureArcManagedIdentity_InvalidAuthHeader() throws Exception {
573594
IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(ManagedIdentitySourceType.AZURE_ARC, azureArcEndpoint);
574595
ManagedIdentityApplication.setEnvironmentVariables(environmentVariables);
575-
ManagedIdentityClient.resetManagedIdentitySourceType();
576596
DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class);
577597

578598
HttpResponse response = new HttpResponse();
@@ -611,7 +631,6 @@ void azureArcManagedIdentity_InvalidAuthHeader() throws Exception {
611631
void azureArcManagedIdentityAuthheaderValidationTest() throws Exception {
612632
IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(ManagedIdentitySourceType.AZURE_ARC, azureArcEndpoint);
613633
ManagedIdentityApplication.setEnvironmentVariables(environmentVariables);
614-
ManagedIdentityClient.resetManagedIdentitySourceType();
615634
DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class);
616635

617636
//Both a missing file and an invalid path structure should throw an exception

0 commit comments

Comments
 (0)