Skip to content

Commit 90af4d5

Browse files
authored
Merge pull request #885 from blancqua/blancqua/www-auth-fix
Fix Www-Authenticate not being read
2 parents a8ff743 + 85846f8 commit 90af4d5

File tree

2 files changed

+105
-128
lines changed

2 files changed

+105
-128
lines changed

msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/AzureArcManagedIdentitySource.java

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,7 @@
1414
import java.nio.file.Files;
1515
import java.nio.file.Path;
1616
import java.nio.file.Paths;
17-
import java.util.Collections;
18-
import java.util.HashMap;
17+
import java.util.*;
1918

2019
class AzureArcManagedIdentitySource extends AbstractManagedIdentitySource{
2120

@@ -26,6 +25,7 @@ class AzureArcManagedIdentitySource extends AbstractManagedIdentitySource{
2625
private static final String LINUX_PATH = "/var/opt/azcmagent/tokens/";
2726
private static final String FILE_EXTENSION = ".key";
2827
private static final int MAX_FILE_SIZE_BYTES = 4096;
28+
private static final String WWW_AUTHENTICATE_HEADER = "WWW-Authenticate";
2929

3030
private final URI MSI_ENDPOINT;
3131

@@ -92,20 +92,20 @@ public ManagedIdentityResponse handleResponse(
9292
ManagedIdentityParameters parameters,
9393
IHttpResponse response) {
9494

95-
LOG.info("[Managed Identity] Response received. Status code: {response.StatusCode}");
95+
LOG.info("[Managed Identity] Response received. Status code: {}", response.statusCode());
9696

9797
if (response.statusCode() == HttpURLConnection.HTTP_UNAUTHORIZED) {
98-
if(!response.headers().containsKey("WWW-Authenticate")) {
99-
LOG.error("[Managed Identity] WWW-Authenticate header is expected but not found.");
100-
throw new MsalServiceException(MsalErrorMessage.MANAGED_IDENTITY_NO_CHALLENGE_ERROR, MsalError.MANAGED_IDENTITY_REQUEST_FAILED,
101-
ManagedIdentitySourceType.AZURE_ARC);
102-
}
103-
104-
String challenge = response.headers().get("WWW-Authenticate").get(0);
98+
String challenge =
99+
readChallengeFrom(response)
100+
.orElseGet(() -> {
101+
LOG.error("[Managed Identity] {} is expected but not found.", WWW_AUTHENTICATE_HEADER);
102+
throw new MsalServiceException(MsalErrorMessage.MANAGED_IDENTITY_NO_CHALLENGE_ERROR, MsalError.MANAGED_IDENTITY_REQUEST_FAILED,
103+
ManagedIdentitySourceType.AZURE_ARC);
104+
});
105105
String[] splitChallenge = challenge.split("=");
106106

107107
if (splitChallenge.length != 2) {
108-
LOG.error("[Managed Identity] The WWW-Authenticate header for Azure arc managed identity is not an expected format.");
108+
LOG.error("[Managed Identity] The {} header for Azure arc managed identity is not an expected format.", WWW_AUTHENTICATE_HEADER);
109109
throw new MsalServiceException(MsalErrorMessage.MANAGED_IDENTITY_INVALID_CHALLENGE, MsalError.MANAGED_IDENTITY_REQUEST_FAILED,
110110
ManagedIdentitySourceType.AZURE_ARC);
111111
}
@@ -150,6 +150,16 @@ public ManagedIdentityResponse handleResponse(
150150
return super.handleResponse(parameters, response);
151151
}
152152

153+
private Optional<String> readChallengeFrom(IHttpResponse response) {
154+
return response.headers()
155+
.entrySet()
156+
.stream()
157+
.filter(entry -> WWW_AUTHENTICATE_HEADER.equalsIgnoreCase(entry.getKey()))
158+
.map(Map.Entry::getValue)
159+
.flatMap(Collection::stream)
160+
.findFirst();
161+
}
162+
153163
private void validateFile(Path path) {
154164
String osName = System.getProperty("os.name").toLowerCase();
155165
if (!(osName.contains("windows") || osName.contains("linux"))) {
@@ -170,7 +180,7 @@ private void validateFile(Path path) {
170180
ManagedIdentitySourceType.AZURE_ARC);
171181
}
172182

173-
LOG.error("[Managed Identity] Path passed validation.");
183+
LOG.info("[Managed Identity] Path passed validation.");
174184
}
175185

176186
private boolean isValidWindowsPath(Path path) {

msal4j-sdk/src/test/java/com/microsoft/aad/msal4j/ManagedIdentityTests.java

Lines changed: 83 additions & 116 deletions
Original file line numberDiff line numberDiff line change
@@ -4,29 +4,29 @@
44
package com.microsoft.aad.msal4j;
55

66
import com.nimbusds.oauth2.sdk.util.URLUtils;
7-
import org.apache.http.HttpStatus;
8-
import org.junit.jupiter.api.BeforeAll;
9-
import org.junit.jupiter.api.BeforeEach;
7+
import org.junit.jupiter.api.Nested;
108
import org.junit.jupiter.api.Test;
119
import org.junit.jupiter.api.TestInstance;
1210
import org.junit.jupiter.api.extension.ExtendWith;
1311
import org.junit.jupiter.params.ParameterizedTest;
1412
import org.junit.jupiter.params.provider.MethodSource;
13+
import org.junit.jupiter.params.provider.ValueSource;
1514
import org.mockito.junit.jupiter.MockitoExtension;
1615

1716
import java.net.SocketException;
18-
import java.net.URISyntaxException;
1917
import java.nio.file.Path;
2018
import java.nio.file.Paths;
21-
import java.time.Instant;
22-
import java.time.temporal.ChronoUnit;
23-
import java.util.Collections;
2419
import java.util.HashMap;
2520
import java.util.List;
2621
import java.util.Map;
2722
import java.util.concurrent.CompletableFuture;
2823
import java.util.concurrent.ExecutionException;
2924

25+
import static com.microsoft.aad.msal4j.ManagedIdentitySourceType.*;
26+
import static com.microsoft.aad.msal4j.MsalError.*;
27+
import static com.microsoft.aad.msal4j.MsalErrorMessage.*;
28+
import static java.util.Collections.*;
29+
import static org.apache.http.HttpStatus.*;
3030
import static org.junit.jupiter.api.Assertions.*;
3131
import static org.mockito.Mockito.*;
3232

@@ -77,8 +77,8 @@ private HttpRequest expectedRequest(ManagedIdentitySourceType source, String res
7777
case APP_SERVICE: {
7878
endpoint = appServiceEndpoint;
7979

80-
queryParameters.put("api-version", Collections.singletonList("2019-08-01"));
81-
queryParameters.put("resource", Collections.singletonList(resource));
80+
queryParameters.put("api-version", singletonList("2019-08-01"));
81+
queryParameters.put("resource", singletonList(resource));
8282

8383
headers.put("X-IDENTITY-HEADER", "secret");
8484
break;
@@ -89,43 +89,43 @@ private HttpRequest expectedRequest(ManagedIdentitySourceType source, String res
8989
headers.put("ContentType", "application/x-www-form-urlencoded");
9090
headers.put("Metadata", "true");
9191

92-
bodyParameters.put("resource", Collections.singletonList(resource));
92+
bodyParameters.put("resource", singletonList(resource));
9393

94-
queryParameters.put("resource", Collections.singletonList(resource));
94+
queryParameters.put("resource", singletonList(resource));
9595
return new HttpRequest(HttpMethod.GET, computeUri(endpoint, queryParameters), headers, URLUtils.serializeParameters(bodyParameters));
9696
}
9797
case IMDS: {
9898
endpoint = IMDS_ENDPOINT;
99-
queryParameters.put("api-version", Collections.singletonList("2018-02-01"));
100-
queryParameters.put("resource", Collections.singletonList(resource));
99+
queryParameters.put("api-version", singletonList("2018-02-01"));
100+
queryParameters.put("resource", singletonList(resource));
101101
headers.put("Metadata", "true");
102102
break;
103103
}
104104
case AZURE_ARC: {
105105
endpoint = azureArcEndpoint;
106106

107-
queryParameters.put("api-version", Collections.singletonList("2019-11-01"));
108-
queryParameters.put("resource", Collections.singletonList(resource));
107+
queryParameters.put("api-version", singletonList("2019-11-01"));
108+
queryParameters.put("resource", singletonList(resource));
109109

110110
headers.put("Metadata", "true");
111111
break;
112112
}
113113
case SERVICE_FABRIC:
114114
endpoint = serviceFabricEndpoint;
115-
queryParameters.put("api-version", Collections.singletonList("2019-07-01-preview"));
116-
queryParameters.put("resource", Collections.singletonList(resource));
115+
queryParameters.put("api-version", singletonList("2019-07-01-preview"));
116+
queryParameters.put("resource", singletonList(resource));
117117
break;
118118
}
119119

120120
switch (id.getIdType()) {
121121
case CLIENT_ID:
122-
queryParameters.put("client_id", Collections.singletonList(id.getUserAssignedId()));
122+
queryParameters.put("client_id", singletonList(id.getUserAssignedId()));
123123
break;
124124
case RESOURCE_ID:
125-
queryParameters.put("mi_res_id", Collections.singletonList(id.getUserAssignedId()));
125+
queryParameters.put("mi_res_id", singletonList(id.getUserAssignedId()));
126126
break;
127127
case OBJECT_ID:
128-
queryParameters.put("object_id", Collections.singletonList(id.getUserAssignedId()));
128+
queryParameters.put("object_id", singletonList(id.getUserAssignedId()));
129129
break;
130130
}
131131

@@ -511,43 +511,6 @@ void managedIdentity_RequestFailed_UnreachableNetwork(ManagedIdentitySourceType
511511
verify(httpClientMock, times(1)).send(any());
512512
}
513513

514-
@Test
515-
void azureArcManagedIdentity_MissingAuthHeader() throws Exception {
516-
IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(ManagedIdentitySourceType.AZURE_ARC, azureArcEndpoint);
517-
ManagedIdentityApplication.setEnvironmentVariables(environmentVariables);
518-
DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class);
519-
520-
HttpResponse response = new HttpResponse();
521-
response.statusCode(HttpStatus.SC_UNAUTHORIZED);
522-
523-
when(httpClientMock.send(any())).thenReturn(response);
524-
525-
miApp = ManagedIdentityApplication
526-
.builder(ManagedIdentityId.systemAssigned())
527-
.httpClient(httpClientMock)
528-
.build();
529-
530-
// Clear caching to avoid cross test pollution.
531-
miApp.tokenCache().accessTokens.clear();
532-
533-
try {
534-
miApp.acquireTokenForManagedIdentity(
535-
ManagedIdentityParameters.builder(resource)
536-
.build()).get();
537-
} catch (Exception exception) {
538-
assert(exception.getCause() instanceof MsalServiceException);
539-
540-
MsalServiceException miException = (MsalServiceException) exception.getCause();
541-
assertEquals(ManagedIdentitySourceType.AZURE_ARC.name(), miException.managedIdentitySource());
542-
assertEquals(MsalError.MANAGED_IDENTITY_REQUEST_FAILED, miException.errorCode());
543-
assertEquals(MsalErrorMessage.MANAGED_IDENTITY_NO_CHALLENGE_ERROR, miException.getMessage());
544-
return;
545-
}
546-
547-
fail("MsalServiceException is expected but not thrown.");
548-
verify(httpClientMock, times(1)).send(any());
549-
}
550-
551514
@ParameterizedTest
552515
@MethodSource("com.microsoft.aad.msal4j.ManagedIdentityTestDataProvider#createDataError")
553516
void managedIdentity_SharedCache(ManagedIdentitySourceType source, String endpoint) throws Exception {
@@ -589,81 +552,85 @@ void managedIdentity_SharedCache(ManagedIdentitySourceType source, String endpoi
589552
verify(httpClientMock, times(1)).send(any());
590553
}
591554

592-
@Test
593-
void azureArcManagedIdentity_InvalidAuthHeader() throws Exception {
594-
IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(ManagedIdentitySourceType.AZURE_ARC, azureArcEndpoint);
595-
ManagedIdentityApplication.setEnvironmentVariables(environmentVariables);
596-
DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class);
597-
598-
HttpResponse response = new HttpResponse();
599-
response.statusCode(HttpStatus.SC_UNAUTHORIZED);
600-
response.headers().put("WWW-Authenticate", Collections.singletonList("xyz"));
555+
@Nested
556+
class AzureArc {
601557

602-
when(httpClientMock.send(any())).thenReturn(response);
558+
@Test
559+
void missingAuthHeader() throws Exception {
560+
mockHttpResponse(emptyMap());
603561

604-
miApp = ManagedIdentityApplication
605-
.builder(ManagedIdentityId.systemAssigned())
606-
.httpClient(httpClientMock)
607-
.build();
608-
609-
// Clear caching to avoid cross test pollution.
610-
miApp.tokenCache().accessTokens.clear();
562+
assertMsalServiceException(MANAGED_IDENTITY_REQUEST_FAILED, MANAGED_IDENTITY_NO_CHALLENGE_ERROR);
563+
}
611564

612-
try {
613-
miApp.acquireTokenForManagedIdentity(
614-
ManagedIdentityParameters.builder(resource)
615-
.build()).get();
616-
} catch (Exception exception) {
617-
assert(exception.getCause() instanceof MsalServiceException);
565+
@ParameterizedTest
566+
@ValueSource(strings = {"WWW-Authenticate", "Www-Authenticate"})
567+
void invalidAuthHeader(String authHeaderKey) throws Exception {
568+
mockHttpResponse(singletonMap(authHeaderKey, singletonList("xyz")));
618569

619-
MsalServiceException miException = (MsalServiceException) exception.getCause();
620-
assertEquals(ManagedIdentitySourceType.AZURE_ARC.name(), miException.managedIdentitySource());
621-
assertEquals(MsalError.MANAGED_IDENTITY_REQUEST_FAILED, miException.errorCode());
622-
assertEquals(MsalErrorMessage.MANAGED_IDENTITY_INVALID_CHALLENGE, miException.getMessage());
623-
return;
570+
assertMsalServiceException(MANAGED_IDENTITY_REQUEST_FAILED,
571+
MANAGED_IDENTITY_INVALID_CHALLENGE);
624572
}
625573

626-
fail("MsalServiceException is expected but not thrown.");
627-
verify(httpClientMock, times(1)).send(any());
628-
}
574+
@ParameterizedTest
575+
@ValueSource(strings = {"WWW-Authenticate", "Www-Authenticate"})
576+
void validPathWithMissingFile(String authHeaderKey)
577+
throws Exception {
578+
Path validPathWithMissingFile = Paths.get(
579+
System.getenv("ProgramData") + "/AzureConnectedMachineAgent/Tokens/secret.key");
629580

630-
@Test
631-
void azureArcManagedIdentityAuthheaderValidationTest() throws Exception {
632-
IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(ManagedIdentitySourceType.AZURE_ARC, azureArcEndpoint);
633-
ManagedIdentityApplication.setEnvironmentVariables(environmentVariables);
634-
DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class);
581+
mockHttpResponse(singletonMap(authHeaderKey, singletonList("Basic realm=" + validPathWithMissingFile)));
635582

636-
//Both a missing file and an invalid path structure should throw an exception
637-
Path validPathWithMissingFile = Paths.get(System.getenv("ProgramData")+ "/AzureConnectedMachineAgent/Tokens/secret.key");
638-
Path invalidPathWithRealFile = Paths.get(this.getClass().getResource("/msi-azure-arc-secret.txt").toURI());
583+
assertMsalServiceException(MANAGED_IDENTITY_FILE_READ_ERROR,
584+
MANAGED_IDENTITY_INVALID_FILEPATH);
585+
}
639586

640-
// Mock 401 response that returns WWW-Authenticate header
641-
HttpResponse response = new HttpResponse();
642-
response.statusCode(HttpStatus.SC_UNAUTHORIZED);
643-
response.headers().put("WWW-Authenticate", Collections.singletonList("Basic realm=" + validPathWithMissingFile));
587+
@ParameterizedTest
588+
@ValueSource(strings = {"WWW-Authenticate", "Www-Authenticate"})
589+
void invalidPathWithRealFile(String authHeaderKey)
590+
throws Exception {
591+
Path invalidPathWithRealFile = Paths.get(
592+
this.getClass().getResource("/msi-azure-arc-secret.txt").toURI());
644593

645-
when(httpClientMock.send(expectedRequest(ManagedIdentitySourceType.AZURE_ARC, resource))).thenReturn(response);
594+
mockHttpResponse(singletonMap(authHeaderKey, singletonList("Basic realm=" + invalidPathWithRealFile)));
646595

647-
miApp = ManagedIdentityApplication
648-
.builder(ManagedIdentityId.systemAssigned())
649-
.httpClient(httpClientMock)
650-
.build();
596+
assertMsalServiceException(MANAGED_IDENTITY_FILE_READ_ERROR,
597+
MANAGED_IDENTITY_INVALID_FILEPATH);
598+
}
651599

652-
// Clear caching to avoid cross test pollution.
653-
miApp.tokenCache().accessTokens.clear();
600+
private void mockHttpResponse(Map<String, ? extends List<String>> responseHeaders) throws Exception {
601+
IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(AZURE_ARC, azureArcEndpoint);
602+
ManagedIdentityApplication.setEnvironmentVariables(environmentVariables);
603+
DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class);
654604

655-
CompletableFuture<IAuthenticationResult> future = miApp.acquireTokenForManagedIdentity(ManagedIdentityParameters.builder(resource).build());
605+
HttpResponse response = new HttpResponse();
606+
response.statusCode(SC_UNAUTHORIZED);
607+
response.headers().putAll(responseHeaders);
656608

657-
ExecutionException ex = assertThrows(ExecutionException.class, future::get);
658-
assertTrue(ex.getCause() instanceof MsalServiceException);
659-
assertTrue(ex.getMessage().contains(MsalErrorMessage.MANAGED_IDENTITY_INVALID_FILEPATH));
609+
when(httpClientMock.send(
610+
expectedRequest(AZURE_ARC, resource))).thenReturn(
611+
response);
660612

661-
response.headers().put("WWW-Authenticate", Collections.singletonList("Basic realm=" + invalidPathWithRealFile));
613+
miApp = ManagedIdentityApplication
614+
.builder(ManagedIdentityId.systemAssigned())
615+
.httpClient(httpClientMock)
616+
.build();
662617

663-
future = miApp.acquireTokenForManagedIdentity(ManagedIdentityParameters.builder(resource).build());
618+
// Clear caching to avoid cross test pollution.
619+
miApp.tokenCache().accessTokens.clear();
620+
}
664621

665-
ex = assertThrows(ExecutionException.class, future::get);
666-
assertTrue(ex.getCause() instanceof MsalServiceException);
667-
assertTrue(ex.getMessage().contains(MsalErrorMessage.MANAGED_IDENTITY_INVALID_FILEPATH));
622+
private void assertMsalServiceException(String errorCode, String message) throws Exception {
623+
CompletableFuture<IAuthenticationResult> future =
624+
miApp.acquireTokenForManagedIdentity(
625+
ManagedIdentityParameters.builder(resource).build());
626+
627+
ExecutionException ex = assertThrows(ExecutionException.class, future::get);
628+
assertInstanceOf(MsalServiceException.class, ex.getCause());
629+
MsalServiceException msalException = (MsalServiceException) ex.getCause();
630+
assertEquals(AZURE_ARC.name(),
631+
msalException.managedIdentitySource());
632+
assertEquals(errorCode, msalException.errorCode());
633+
assertTrue(ex.getMessage().contains(message));
634+
}
668635
}
669636
}

0 commit comments

Comments
 (0)