Skip to content

Commit 9d8ab0c

Browse files
authored
Merge pull request #791 from AzureAD/avdunn/cert-service-fabric
Validate Service Fabric certs
2 parents 1f16b0c + 8e2945d commit 9d8ab0c

File tree

5 files changed

+256
-11
lines changed

5 files changed

+256
-11
lines changed

msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/DefaultHttpClient.java

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,14 @@
2121
import java.util.Map;
2222

2323
class DefaultHttpClient implements IHttpClient {
24-
private final static Logger LOG = LoggerFactory.getLogger(DefaultHttpClient.class);
24+
private static final Logger LOG = LoggerFactory.getLogger(DefaultHttpClient.class);
2525

26-
private final Proxy proxy;
27-
private final SSLSocketFactory sslSocketFactory;
26+
final Proxy proxy;
27+
final SSLSocketFactory sslSocketFactory;
2828

2929
//By default, rely on the timeout behavior of the services requests are sent to
30-
private int connectTimeout = 0;
31-
private int readTimeout = 0;
30+
int connectTimeout = 0;
31+
int readTimeout = 0;
3232

3333
DefaultHttpClient(Proxy proxy, SSLSocketFactory sslSocketFactory, Integer connectTimeout, Integer readTimeout) {
3434
this.proxy = proxy;
@@ -77,7 +77,7 @@ private HttpResponse executeHttpPost(HttpRequest httpRequest) throws Exception {
7777
}
7878
}
7979

80-
private HttpURLConnection openConnection(final URL finalURL)
80+
HttpURLConnection openConnection(final URL finalURL)
8181
throws IOException {
8282
URLConnection connection;
8383

Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
package com.microsoft.aad.msal4j;
5+
6+
import javax.net.ssl.HostnameVerifier;
7+
import javax.net.ssl.HttpsURLConnection;
8+
import javax.net.ssl.SSLSocketFactory;
9+
import java.io.IOException;
10+
import java.net.HttpURLConnection;
11+
import java.net.Proxy;
12+
import java.net.URL;
13+
import java.net.URLConnection;
14+
import javax.net.ssl.SSLContext;
15+
import javax.net.ssl.SSLSession;
16+
import javax.net.ssl.TrustManager;
17+
import javax.net.ssl.X509TrustManager;
18+
import java.security.KeyManagementException;
19+
import java.security.MessageDigest;
20+
import java.security.NoSuchAlgorithmException;
21+
import java.security.cert.Certificate;
22+
import java.security.cert.CertificateEncodingException;
23+
import java.security.cert.CertificateException;
24+
import java.security.cert.X509Certificate;
25+
26+
/** An extension for the default HttpClient which is meant to perform any extra HTTP behavior needed for a managed identity flow.
27+
* <p>
28+
* Currently the only extra behavior is the Service Fabric flow, where we must add a certificate thumbprint to the HTTP connection.
29+
*/
30+
class DefaultHttpClientManagedIdentity extends DefaultHttpClient {
31+
32+
public static final HostnameVerifier ALL_HOSTS_ACCEPT_HOSTNAME_VERIFIER;
33+
34+
static {
35+
ALL_HOSTS_ACCEPT_HOSTNAME_VERIFIER = new HostnameVerifier() {
36+
@SuppressWarnings("BadHostnameVerifier")
37+
@Override
38+
public boolean verify(String hostname, SSLSession session) {
39+
return true;
40+
}
41+
};
42+
}
43+
44+
DefaultHttpClientManagedIdentity(Proxy proxy, SSLSocketFactory sslSocketFactory, Integer connectTimeout, Integer readTimeout) {
45+
super(proxy, sslSocketFactory, connectTimeout, readTimeout);
46+
}
47+
48+
@Override
49+
HttpURLConnection openConnection(final URL finalURL)
50+
throws IOException {
51+
URLConnection connection;
52+
53+
if (proxy != null) {
54+
connection = finalURL.openConnection(proxy);
55+
} else {
56+
connection = finalURL.openConnection();
57+
}
58+
59+
connection.setConnectTimeout(connectTimeout);
60+
connection.setReadTimeout(readTimeout);
61+
62+
if (connection instanceof HttpURLConnection) {
63+
return (HttpURLConnection) connection;
64+
} else {
65+
HttpsURLConnection httpsConnection = (HttpsURLConnection) connection;
66+
67+
if (sslSocketFactory != null) {
68+
httpsConnection.setSSLSocketFactory(sslSocketFactory);
69+
}
70+
71+
if (System.getenv(Constants.IDENTITY_SERVER_THUMBPRINT) != null) {
72+
addTrustedCertificateThumbprint(httpsConnection, System.getenv(Constants.IDENTITY_SERVER_THUMBPRINT));
73+
}
74+
75+
return httpsConnection;
76+
}
77+
}
78+
79+
/**
80+
*
81+
* Pins the specified HTTPS URL Connection to work against a specific server-side certificate with
82+
* the specified thumbprint only.
83+
*
84+
* @param httpsUrlConnection The https url connection to configure
85+
* @param certificateThumbprint The thumbprint of the certificate
86+
*/
87+
public static void addTrustedCertificateThumbprint(HttpsURLConnection httpsUrlConnection,
88+
String certificateThumbprint) {
89+
//We expect the connection to work against a specific server side certificate only, so it's safe to disable the
90+
// host name verification.
91+
if (httpsUrlConnection.getHostnameVerifier() != ALL_HOSTS_ACCEPT_HOSTNAME_VERIFIER) {
92+
httpsUrlConnection.setHostnameVerifier(ALL_HOSTS_ACCEPT_HOSTNAME_VERIFIER);
93+
}
94+
95+
// Create a Trust manager that trusts only certificate with specified thumbprint.
96+
TrustManager[] certificateTrust = new TrustManager[]{new X509TrustManager() {
97+
public X509Certificate[] getAcceptedIssuers() {
98+
return new X509Certificate[]{};
99+
}
100+
101+
public void checkClientTrusted(X509Certificate[] certificates, String authenticationType)
102+
throws CertificateException {
103+
throw new CertificateException("No client side certificate configured.");
104+
}
105+
106+
public void checkServerTrusted(X509Certificate[] certificates, String authenticationType)
107+
throws CertificateException {
108+
if (certificates == null || certificates.length == 0) {
109+
throw new CertificateException("Did not receive any certificate from the server.");
110+
}
111+
112+
for (X509Certificate x509Certificate : certificates) {
113+
String sslCertificateThumbprint = extractCertificateThumbprint(x509Certificate);
114+
if (certificateThumbprint.equalsIgnoreCase(sslCertificateThumbprint)) {
115+
return;
116+
}
117+
}
118+
throw new RuntimeException("Thumbprint of certificates received did not match the expected thumbprint.");
119+
}
120+
}
121+
};
122+
123+
SSLSocketFactory sslSocketFactory;
124+
try {
125+
SSLContext sslContext = SSLContext.getInstance("TLS");
126+
sslContext.init(null, certificateTrust, null);
127+
sslSocketFactory = sslContext.getSocketFactory();
128+
} catch (NoSuchAlgorithmException | KeyManagementException e) {
129+
throw new RuntimeException("Error Creating SSL Context", e);
130+
}
131+
132+
// Pin the connection to a specific certificate with specified thumbprint.
133+
if (httpsUrlConnection.getSSLSocketFactory() != sslSocketFactory) {
134+
httpsUrlConnection.setSSLSocketFactory(sslSocketFactory);
135+
}
136+
}
137+
138+
private static String extractCertificateThumbprint(Certificate certificate) {
139+
try {
140+
StringBuilder thumbprint = new StringBuilder();
141+
MessageDigest messageDigest = MessageDigest.getInstance("SHA-1");
142+
143+
byte[] encodedCertificate;
144+
145+
try {
146+
encodedCertificate = certificate.getEncoded();
147+
} catch (CertificateEncodingException e) {
148+
throw new RuntimeException(e);
149+
}
150+
151+
byte[] updatedDigest = messageDigest.digest(encodedCertificate);
152+
153+
for (byte b : updatedDigest) {
154+
int unsignedByte = b & 0xff;
155+
156+
if (unsignedByte < 16) {
157+
thumbprint.append("0");
158+
}
159+
thumbprint.append(Integer.toHexString(unsignedByte));
160+
}
161+
return thumbprint.toString();
162+
} catch (NoSuchAlgorithmException e) {
163+
throw new MsalClientException("NoSuchAlgorithmException when extracting certificate thumbprint: ", e.getMessage());
164+
}
165+
}
166+
167+
}

msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/HttpHelper.java

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,44 @@ public IHttpResponse executeHttpRequest(HttpRequest httpRequest,
6767
return httpResponse;
6868
}
6969

70+
//Overloaded version of the more commonly used HTTP executor. It does not use ServiceBundle, allowing an HTTP call to be
71+
// made only with more bespoke request-level parameters rather than those from the app-level ServiceBundle
72+
IHttpResponse executeHttpRequest(HttpRequest httpRequest,
73+
RequestContext requestContext,
74+
TelemetryManager telemetryManager,
75+
IHttpClient httpClient) {
76+
checkForThrottling(requestContext);
77+
78+
HttpEvent httpEvent = new HttpEvent(); // for tracking http telemetry
79+
IHttpResponse httpResponse;
80+
81+
try (TelemetryHelper telemetryHelper = telemetryManager.createTelemetryHelper(
82+
requestContext.telemetryRequestId(),
83+
requestContext.clientId(),
84+
httpEvent,
85+
false)) {
86+
87+
addRequestInfoToTelemetry(httpRequest, httpEvent);
88+
89+
try {
90+
httpResponse = executeHttpRequestWithRetries(httpRequest, httpClient);
91+
92+
} catch (Exception e) {
93+
httpEvent.setOauthErrorCode(AuthenticationErrorCode.UNKNOWN);
94+
throw new MsalClientException(e);
95+
}
96+
97+
addResponseInfoToTelemetry(httpResponse, httpEvent);
98+
99+
if (httpResponse.headers() != null) {
100+
HttpHelper.verifyReturnedCorrelationId(httpRequest, httpResponse);
101+
}
102+
}
103+
processThrottlingInstructions(httpResponse, requestContext);
104+
105+
return httpResponse;
106+
}
107+
70108
private String getRequestThumbprint(RequestContext requestContext) {
71109
StringBuilder sb = new StringBuilder();
72110
sb.append(requestContext.clientId() + POINT_DELIMITER);

msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/ServiceFabricManagedIdentitySource.java

Lines changed: 44 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import org.slf4j.Logger;
77
import org.slf4j.LoggerFactory;
88

9+
import java.net.SocketException;
910
import java.net.URI;
1011
import java.net.URISyntaxException;
1112
import java.util.Collections;
@@ -22,6 +23,12 @@ class ServiceFabricManagedIdentitySource extends AbstractManagedIdentitySource {
2223
private final ManagedIdentityIdType idType;
2324
private final String userAssignedId;
2425

26+
//Service Fabric requires a special check for an environment variable containing a certificate thumbprint used for validating requests.
27+
//No other flow need this and an app developer may not be aware of it, so it was decided that for the Service Fabric flow we will simply override
28+
// any HttpClient that may have been set by the app developer with our own client which performs the validation logic.
29+
private final IHttpClient httpClient = new DefaultHttpClientManagedIdentity(null, null, null, null);
30+
private final HttpHelper httpHelper = new HttpHelper(httpClient);
31+
2532
@Override
2633
public void createManagedIdentityRequest(String resource) {
2734
managedIdentityRequest.baseEndpoint = msiEndpoint;
@@ -53,21 +60,54 @@ private ServiceFabricManagedIdentitySource(MsalRequest msalRequest, ServiceBundl
5360
this.userAssignedId = ((ManagedIdentityApplication) msalRequest.application()).getManagedIdentityId().getUserAssignedId();
5461
}
5562

63+
@Override
64+
public ManagedIdentityResponse getManagedIdentityResponse(
65+
ManagedIdentityParameters parameters) {
66+
67+
createManagedIdentityRequest(parameters.resource);
68+
IHttpResponse response;
69+
70+
try {
71+
72+
HttpRequest httpRequest = managedIdentityRequest.method.equals(HttpMethod.GET) ?
73+
new HttpRequest(HttpMethod.GET,
74+
managedIdentityRequest.computeURI().toString(),
75+
managedIdentityRequest.headers) :
76+
new HttpRequest(HttpMethod.POST,
77+
managedIdentityRequest.computeURI().toString(),
78+
managedIdentityRequest.headers,
79+
managedIdentityRequest.getBodyAsString());
80+
81+
response = httpHelper.executeHttpRequest(httpRequest, managedIdentityRequest.requestContext(), serviceBundle.getTelemetryManager(),
82+
httpClient);
83+
} catch (URISyntaxException e) {
84+
throw new RuntimeException(e);
85+
} catch (MsalClientException e) {
86+
if (e.getCause() instanceof SocketException) {
87+
throw new MsalServiceException(e.getMessage(), MsalError.MANAGED_IDENTITY_UNREACHABLE_NETWORK, managedIdentitySourceType);
88+
}
89+
90+
throw e;
91+
}
92+
93+
return handleResponse(parameters, response);
94+
}
95+
5696
static AbstractManagedIdentitySource create(MsalRequest msalRequest, ServiceBundle serviceBundle) {
5797

5898
IEnvironmentVariables environmentVariables = getEnvironmentVariables((ManagedIdentityParameters) msalRequest.requestContext().apiParameters());
59-
String msiEndpoint = environmentVariables.getEnvironmentVariable(Constants.MSI_ENDPOINT);
60-
String identityHeader = environmentVariables.getEnvironmentVariable(Constants.IDENTITY_ENDPOINT);
99+
String identityEndpoint = environmentVariables.getEnvironmentVariable(Constants.IDENTITY_ENDPOINT);
100+
String identityHeader = environmentVariables.getEnvironmentVariable(Constants.IDENTITY_HEADER);
61101
String identityServerThumbprint = environmentVariables.getEnvironmentVariable(Constants.IDENTITY_SERVER_THUMBPRINT);
62102

63103

64-
if (StringHelper.isNullOrBlank(msiEndpoint) || StringHelper.isNullOrBlank(identityHeader) || StringHelper.isNullOrBlank(identityServerThumbprint))
104+
if (StringHelper.isNullOrBlank(identityEndpoint) || StringHelper.isNullOrBlank(identityHeader) || StringHelper.isNullOrBlank(identityServerThumbprint))
65105
{
66106
LOG.info("[Managed Identity] Service fabric managed identity is unavailable.");
67107
return null;
68108
}
69109

70-
return new ServiceFabricManagedIdentitySource(msalRequest, serviceBundle, validateAndGetUri(msiEndpoint), identityHeader);
110+
return new ServiceFabricManagedIdentitySource(msalRequest, serviceBundle, validateAndGetUri(identityEndpoint), identityHeader);
71111
}
72112

73113
private static URI validateAndGetUri(String msiEndpoint)

msal4j-sdk/src/test/java/com/microsoft/aad/msal4j/ManagedIdentityTests.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -360,7 +360,7 @@ void managedIdentityTest_Retry(ManagedIdentitySourceType source, String endpoint
360360
return;
361361
}
362362

363-
fail("MsalManagedIdentityException is expected but not thrown.");
363+
fail("MsalServiceException is expected but not thrown.");
364364
}
365365

366366
@ParameterizedTest

0 commit comments

Comments
 (0)