Skip to content

Commit 8479e03

Browse files
committed
Refactor retry logic and add special behavior for IMDS
1 parent b02c8f2 commit 8479e03

12 files changed

+248
-73
lines changed

msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/AbstractClientApplicationBase.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -568,7 +568,8 @@ public T correlationId(String val) {
568568
new TelemetryManager(telemetryConsumer, builder.onlySendFailureTelemetry),
569569
new HttpHelper(builder.httpClient == null ?
570570
new DefaultHttpClient(builder.proxy, builder.sslSocketFactory, builder.connectTimeoutForDefaultHttpClient, builder.readTimeoutForDefaultHttpClient) :
571-
builder.httpClient)
571+
builder.httpClient,
572+
new DefaultRetryPolicy())
572573
);
573574

574575
if (aadAadInstanceDiscoveryResponse != null) {
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
package com.microsoft.aad.msal4j;
5+
6+
/**
7+
* Default retry policy for most MSAL Java flows
8+
*/
9+
class DefaultRetryPolicy implements IRetryPolicy {
10+
private static final int RETRY_NUM = 1;
11+
private static int RETRY_DELAY_MS = 1000;
12+
13+
@Override
14+
public boolean isRetryable(IHttpResponse httpResponse) {
15+
return httpResponse.statusCode() >= 500 &&
16+
httpResponse.statusCode() < 600 &&
17+
HttpHelper.getRetryAfterHeader(httpResponse) == null;
18+
}
19+
20+
@Override
21+
public int getMaxRetryCount(IHttpResponse httpResponse) {
22+
return RETRY_NUM;
23+
}
24+
25+
@Override
26+
public int getRetryDelayMs(IHttpResponse httpResponse) {
27+
return RETRY_DELAY_MS;
28+
}
29+
30+
//Package-private methods to allow much quicker testing. The delay values should be treated as constants in any non-test scenario.
31+
static void setRetryDelayMs(int retryDelayMs) {
32+
RETRY_DELAY_MS = retryDelayMs;
33+
}
34+
35+
static void resetToDefaults() {
36+
RETRY_DELAY_MS = 1000;
37+
}
38+
}

msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/HttpHelper.java

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -18,18 +18,18 @@ class HttpHelper implements IHttpHelper {
1818

1919
private static final Logger log = LoggerFactory.getLogger(HttpHelper.class);
2020
public static final String RETRY_AFTER_HEADER = "Retry-After";
21-
public static final int RETRY_NUM = 2;
22-
public static final int RETRY_DELAY_MS = 1000;
2321

2422
public static final int HTTP_STATUS_200 = 200;
2523
public static final int HTTP_STATUS_400 = 400;
2624
public static final int HTTP_STATUS_429 = 429;
2725
public static final int HTTP_STATUS_500 = 500;
2826

2927
private IHttpClient httpClient;
28+
private IRetryPolicy retryPolicy;
3029

31-
HttpHelper(IHttpClient httpClient) {
30+
HttpHelper(IHttpClient httpClient, IRetryPolicy retryPolicy) {
3231
this.httpClient = httpClient;
32+
this.retryPolicy = retryPolicy != null ? retryPolicy : new DefaultRetryPolicy();
3333
}
3434

3535
public IHttpResponse executeHttpRequest(HttpRequest httpRequest,
@@ -141,20 +141,19 @@ private String getRequestThumbprint(RequestContext requestContext) {
141141
return StringHelper.createSha256Hash(sb.toString());
142142
}
143143

144-
boolean isRetryable(IHttpResponse httpResponse) {
145-
return httpResponse.statusCode() >= HTTP_STATUS_500 &&
146-
getRetryAfterHeader(httpResponse) == null;
147-
}
148-
149144
IHttpResponse executeHttpRequestWithRetries(HttpRequest httpRequest, IHttpClient httpClient)
150145
throws Exception {
151-
IHttpResponse httpResponse = null;
152-
for (int i = 0; i < RETRY_NUM; i++) {
146+
IHttpResponse httpResponse = httpClient.send(httpRequest);
147+
148+
int retryCount = 0;
149+
int maxRetries = retryPolicy.getMaxRetryCount(httpResponse);
150+
151+
while (retryPolicy.isRetryable(httpResponse) && retryCount < maxRetries) {
152+
Thread.sleep(retryPolicy.getRetryDelayMs(httpResponse));
153+
154+
retryCount++;
155+
153156
httpResponse = httpClient.send(httpRequest);
154-
if (!isRetryable(httpResponse)) {
155-
break;
156-
}
157-
Thread.sleep(RETRY_DELAY_MS);
158157
}
159158

160159
return httpResponse;
@@ -191,7 +190,7 @@ private void processThrottlingInstructions(IHttpResponse httpResponse, RequestCo
191190
}
192191
}
193192

194-
private Integer getRetryAfterHeader(IHttpResponse httpResponse) {
193+
static Integer getRetryAfterHeader(IHttpResponse httpResponse) {
195194

196195
if (httpResponse.headers() != null) {
197196
TreeMap<String, List<String>> headers = new TreeMap<>(String.CASE_INSENSITIVE_ORDER);
@@ -279,4 +278,8 @@ private static void verifyReturnedCorrelationId(final HttpRequest httpRequest,
279278
log.info(msg);
280279
}
281280
}
281+
282+
void setRetryPolicy(IRetryPolicy retryPolicy) {
283+
this.retryPolicy = retryPolicy;
284+
}
282285
}

msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/HttpHelperManagedIdentity.java

Lines changed: 0 additions & 52 deletions
This file was deleted.

msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/IMDSManagedIdentitySource.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,12 @@ public IMDSManagedIdentitySource(MsalRequest msalRequest,
3636
super(msalRequest, serviceBundle, ManagedIdentitySourceType.IMDS);
3737
IEnvironmentVariables environmentVariables = getEnvironmentVariables();
3838

39+
//IMDS uses a different retry policy than the default used in other MI flows
40+
IHttpHelper httpHelper = serviceBundle.getHttpHelper();
41+
if (httpHelper instanceof HttpHelper) {
42+
((HttpHelper) httpHelper).setRetryPolicy(new IMDSRetryPolicy());
43+
}
44+
3945
if (!StringHelper.isNullOrBlank(environmentVariables.getEnvironmentVariable(Constants.AZURE_POD_IDENTITY_AUTHORITY_HOST))){
4046
LOG.info(String.format("[Managed Identity] Environment variable AZURE_POD_IDENTITY_AUTHORITY_HOST for IMDS returned endpoint: %s", environmentVariables.getEnvironmentVariable(Constants.AZURE_POD_IDENTITY_AUTHORITY_HOST)));
4147
try {
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
package com.microsoft.aad.msal4j;
5+
6+
//IMDS uses a different try policy than other MI flows, see https://github.com/AzureAD/microsoft-authentication-library-for-dotnet/blob/main/docs/imds_retry_based_on_errors.md
7+
class IMDSRetryPolicy extends ManagedIdentityRetryPolicy {
8+
private static final int LINEAR_RETRY_NUM = 7;
9+
private static int LINEAR_RETRY_DELAY_MS = 10000; // 10 seconds
10+
private static final int EXPONENTIAL_RETRY_NUM = 3;
11+
private static int EXPONENTIAL_RETRY_DELAY_MS = 1000; // 1 second
12+
13+
private int currentRetryCount;
14+
private int lastStatusCode;
15+
16+
@Override
17+
public boolean isRetryable(IHttpResponse httpResponse) {
18+
currentRetryCount++;
19+
lastStatusCode = httpResponse.statusCode();
20+
21+
return (lastStatusCode >= 500 && lastStatusCode < 600) ||
22+
lastStatusCode == 404 || // Not Found
23+
lastStatusCode == 408 || // Request Timeout
24+
lastStatusCode == 410 || // Gone
25+
lastStatusCode == 429; // Too Many Requests
26+
}
27+
28+
@Override
29+
public int getMaxRetryCount(IHttpResponse httpResponse) {
30+
return (httpResponse.statusCode() == 410) ? LINEAR_RETRY_NUM : EXPONENTIAL_RETRY_NUM;
31+
}
32+
33+
@Override
34+
public int getRetryDelayMs(IHttpResponse httpResponse) {
35+
// Use exponential backoff for non-410 status codes
36+
if (lastStatusCode == 410) {
37+
return LINEAR_RETRY_DELAY_MS;
38+
} else {
39+
return (int) (Math.pow(2, currentRetryCount) * EXPONENTIAL_RETRY_DELAY_MS);
40+
}
41+
}
42+
43+
//Package-private methods to allow much quicker testing. The delay values should be treated as constants in any non-test scenario.
44+
static void setRetryDelayMs(int retryDelayMs) {
45+
LINEAR_RETRY_DELAY_MS = retryDelayMs;
46+
EXPONENTIAL_RETRY_DELAY_MS = retryDelayMs;
47+
}
48+
49+
static void resetToDefaults() {
50+
LINEAR_RETRY_DELAY_MS = 10000;
51+
EXPONENTIAL_RETRY_DELAY_MS = 1000;
52+
}
53+
}
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
package com.microsoft.aad.msal4j;
2+
3+
/**
4+
* Interface for HTTP request retry policies
5+
*/
6+
interface IRetryPolicy {
7+
/**
8+
* Determines whether a request should be retried based on the HTTP response
9+
* @param httpResponse The HTTP response to evaluate
10+
* @return true if retry should be attempted, false otherwise
11+
*/
12+
boolean isRetryable(IHttpResponse httpResponse);
13+
14+
/**
15+
* Gets the number of retries to attempt based on the HTTP response
16+
* @return maximum retry count
17+
*/
18+
int getMaxRetryCount(IHttpResponse httpResponse);
19+
20+
/**
21+
* Gets the delay in milliseconds between retry attempts
22+
* @return delay in milliseconds
23+
*/
24+
int getRetryDelayMs(IHttpResponse httpResponse);
25+
}

msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/ManagedIdentityApplication.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,10 @@ private ManagedIdentityApplication(Builder builder) {
3737
super.serviceBundle = new ServiceBundle(
3838
builder.executorService,
3939
new TelemetryManager(telemetryConsumer, builder.onlySendFailureTelemetry),
40-
new HttpHelperManagedIdentity(builder.httpClient == null ?
40+
new HttpHelper(builder.httpClient == null ?
4141
new DefaultHttpClient(builder.proxy, builder.sslSocketFactory, builder.connectTimeoutForDefaultHttpClient, builder.readTimeoutForDefaultHttpClient) :
42-
builder.httpClient)
42+
builder.httpClient,
43+
new ManagedIdentityRetryPolicy())
4344
);
4445
log = LoggerFactory.getLogger(ManagedIdentityApplication.class);
4546

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
package com.microsoft.aad.msal4j;
5+
6+
/**
7+
* Retry policy for most Managed Identity scenarios
8+
*/
9+
class ManagedIdentityRetryPolicy implements IRetryPolicy {
10+
private static final int RETRY_NUM = 3;
11+
private static int RETRY_DELAY_MS = 1000;
12+
13+
@Override
14+
public boolean isRetryable(IHttpResponse httpResponse) {
15+
int statusCode = httpResponse.statusCode();
16+
17+
return statusCode == 404 || // Not Found
18+
statusCode == 408 || // Request Timeout
19+
statusCode == 429 || // Too Many Requests
20+
statusCode == 500 || // Internal Server Error
21+
statusCode == 503 || // Service Unavailable
22+
statusCode == 504; // Gateway Timeout
23+
}
24+
25+
@Override
26+
public int getMaxRetryCount(IHttpResponse httpResponse) {
27+
return RETRY_NUM;
28+
}
29+
30+
@Override
31+
public int getRetryDelayMs(IHttpResponse httpResponse) {
32+
return RETRY_DELAY_MS;
33+
}
34+
35+
//Package-private methods to allow much quicker testing. The delay values should be treated as constants in any non-test scenario.
36+
static void setRetryDelayMs(int retryDelayMs) {
37+
RETRY_DELAY_MS = retryDelayMs;
38+
}
39+
40+
static void resetToDefaults() {
41+
RETRY_DELAY_MS = 1000;
42+
}
43+
}

msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/ServiceFabricManagedIdentitySource.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ class ServiceFabricManagedIdentitySource extends AbstractManagedIdentitySource {
2525
//No other flow need this and an app developer may not be aware of it, so it was decided that for the Service Fabric flow we will simply override
2626
// any HttpClient that may have been set by the app developer with our own client which performs the validation logic.
2727
private static IHttpClient httpClient = new DefaultHttpClientManagedIdentity(null, null, null, null);
28-
private static HttpHelper httpHelper = new HttpHelperManagedIdentity(httpClient);
28+
private static HttpHelper httpHelper = new HttpHelper(httpClient, new ManagedIdentityRetryPolicy());
2929

3030
@Override
3131
public void createManagedIdentityRequest(String resource) {
@@ -122,6 +122,6 @@ private static URI validateAndGetUri(String msiEndpoint)
122122
//However, unit tests often need to mock HttpClient and need a way to inject the mocked object into this class.
123123
static void setHttpClient(IHttpClient client) {
124124
httpClient = client;
125-
httpHelper = new HttpHelperManagedIdentity(httpClient);
125+
httpHelper = new HttpHelper(httpClient, new ManagedIdentityRetryPolicy());
126126
}
127127
}

0 commit comments

Comments
 (0)