Skip to content

Let users configure HTTP client #129

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Nov 22, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

package com.microsoft.aad.msal4j;

import org.apache.http.Header;
import org.apache.http.client.methods.CloseableHttpResponse;
import org.apache.http.client.methods.HttpGet;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.client.methods.HttpRequestBase;
import org.apache.http.entity.ContentType;
import org.apache.http.entity.StringEntity;
import org.apache.http.impl.client.CloseableHttpClient;
import org.apache.http.impl.client.HttpClients;
import org.apache.http.util.EntityUtils;

import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

class ApacheHttpClientAdapter implements IHttpClient {

private CloseableHttpClient httpClient;

ApacheHttpClientAdapter(){
this.httpClient = HttpClients.createDefault();
}

@Override
public IHttpResponse send(HttpRequest httpRequest) throws Exception {

HttpRequestBase request = buildApacheRequestFromMsalRequest(httpRequest);
CloseableHttpResponse response = httpClient.execute(request);

return buildMsalResponseFromApacheResponse(response);
}


private HttpRequestBase buildApacheRequestFromMsalRequest(HttpRequest httpRequest){

if(httpRequest.httpMethod() == HttpMethod.GET){
return builGetRequest(httpRequest);
} else if(httpRequest.httpMethod() == HttpMethod.POST){
return buildPostRequest(httpRequest);
} else {
throw new IllegalArgumentException("HttpRequest method should be either GET or POST");
}
}

private HttpGet builGetRequest(HttpRequest httpRequest){
HttpGet httpGet = new HttpGet(httpRequest.url().toString());

for(Map.Entry<String, String> entry: httpRequest.headers().entrySet()){
httpGet.setHeader(entry.getKey(), entry.getValue());
}

return httpGet;
}

private HttpPost buildPostRequest(HttpRequest httpRequest){

HttpPost httpPost = new HttpPost(httpRequest.url().toString());
for(Map.Entry<String, String> entry: httpRequest.headers().entrySet()){
httpPost.setHeader(entry.getKey(), entry.getValue());
}

String contentTypeHeaderValue = httpRequest.headerValue("Content-Type");
ContentType contentType = ContentType.getByMimeType(contentTypeHeaderValue);
StringEntity stringEntity = new StringEntity(httpRequest.body(), contentType);

httpPost.setEntity(stringEntity);
return httpPost;
}

private IHttpResponse buildMsalResponseFromApacheResponse(CloseableHttpResponse apacheResponse)
throws IOException {

IHttpResponse httpResponse = new HttpResponse();
((HttpResponse) httpResponse).statusCode(apacheResponse.getStatusLine().getStatusCode());

Map<String, List<String>> headers = new HashMap<>();
for(Header header: apacheResponse.getAllHeaders()){
headers.put(header.getName(), Collections.singletonList(header.getValue()));
}
((HttpResponse) httpResponse).headers(headers);

String responseBody = EntityUtils.toString(apacheResponse.getEntity(), "UTF-8");
((HttpResponse) httpResponse).body(responseBody);
return httpResponse;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

package com.microsoft.aad.msal4j;

import labapi.LabResponse;
import labapi.LabUserProvider;
import labapi.NationalCloud;
import org.testng.Assert;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;

import java.util.Collections;

public class HttpClientIT {
private LabUserProvider labUserProvider;

@BeforeClass
public void setUp() {
labUserProvider = LabUserProvider.getInstance();
}

@Test
public void acquireToken_okHttpClient() throws Exception {

LabResponse labResponse = getManagedUserAccountWithPassword();
assertAcquireTokenCommon(labResponse, new OkHttpClientAdapter());
}

@Test
public void acquireToken_apacheHttpClient() throws Exception {

LabResponse labResponse = getManagedUserAccountWithPassword();
assertAcquireTokenCommon(labResponse, new ApacheHttpClientAdapter());
}

private void assertAcquireTokenCommon(LabResponse labResponse, IHttpClient httpClient)
throws Exception{
PublicClientApplication pca = PublicClientApplication.builder(
labResponse.getAppId()).
authority(TestConstants.ORGANIZATIONS_AUTHORITY).
httpClient(httpClient).
build();

IAuthenticationResult result = pca.acquireToken(UserNamePasswordParameters.
builder(Collections.singleton(TestConstants.GRAPH_DEFAULT_SCOPE),
labResponse.getUser().getUpn(),
labResponse.getUser().getPassword().toCharArray())
.build())
.get();

Assert.assertNotNull(result);
Assert.assertNotNull(result.accessToken());
Assert.assertNotNull(result.idToken());
Assert.assertEquals(labResponse.getUser().getUpn(), result.account().username());
}

private LabResponse getManagedUserAccountWithPassword(){
LabResponse labResponse = labUserProvider.getDefaultUser(
NationalCloud.AZURE_CLOUD,
false);
labUserProvider.getUserPassword(labResponse.getUser());

return labResponse;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

package com.microsoft.aad.msal4j;

import okhttp3.Headers;
import okhttp3.MediaType;
import okhttp3.OkHttpClient;
import okhttp3.Request;
import okhttp3.RequestBody;
import okhttp3.Response;
import okhttp3.ResponseBody;

import java.io.IOException;

class OkHttpClientAdapter implements IHttpClient{

private OkHttpClient client;

OkHttpClientAdapter(){
this.client = new OkHttpClient();
}

@Override
public IHttpResponse send(HttpRequest httpRequest) throws IOException {

Request request = buildOkRequestFromMsalRequest(httpRequest);

Response okHttpResponse= client.newCall(request).execute();
return buildMsalResponseFromOkResponse(okHttpResponse);
}

private Request buildOkRequestFromMsalRequest(HttpRequest httpRequest){

if(httpRequest.httpMethod() == HttpMethod.GET){
return buildGetRequest(httpRequest);
} else if(httpRequest.httpMethod() == HttpMethod.POST){
return buildPostRequest(httpRequest);
} else {
throw new IllegalArgumentException("HttpRequest method should be either GET or POST");
}
}

private Request buildGetRequest(HttpRequest httpRequest){
Headers headers = Headers.of(httpRequest.headers());

return new Request.Builder()
.url(httpRequest.url())
.headers(headers)
.build();
}

private Request buildPostRequest(HttpRequest httpRequest){
Headers headers = Headers.of(httpRequest.headers());
String contentType = httpRequest.headerValue("Content-Type");
MediaType type = MediaType.parse(contentType);

RequestBody requestBody = RequestBody.create(type, httpRequest.body());

return new Request.Builder()
.url(httpRequest.url())
.post(requestBody)
.headers(headers)
.build();
}

private IHttpResponse buildMsalResponseFromOkResponse(Response okHttpResponse) throws IOException{

IHttpResponse httpResponse = new HttpResponse();
((HttpResponse) httpResponse).statusCode(okHttpResponse.code());

ResponseBody body = okHttpResponse.body();
if(body != null){
((HttpResponse) httpResponse).body(body.string());
}

Headers headers = okHttpResponse.headers();
if(headers != null){
((HttpResponse) httpResponse).headers(headers.toMultimap());
}
return httpResponse;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ public void acquireTokenWithUsernamePassword_ADFSv2() throws Exception{
assertAcquireTokenCommon(labResponse, password);
}

public void assertAcquireTokenCommon(LabResponse labResponse, String password)
private void assertAcquireTokenCommon(LabResponse labResponse, String password)
throws Exception{
PublicClientApplication pca = PublicClientApplication.builder(
labResponse.getAppId()).
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ private IClientCredential getClientCredentialFromKeyStore() {
key = (PrivateKey) keystore.getKey(CERTIFICATE_ALIAS, null);
publicCertificate = (X509Certificate) keystore.getCertificate(
CERTIFICATE_ALIAS);

} catch (Exception e){
throw new RuntimeException("Error getting certificate from keystore: " + e.getMessage());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,18 +62,21 @@ private static String getInstanceDiscoveryEndpoint(String host) {

private static InstanceDiscoveryResponse sendInstanceDiscoveryRequest
(URL authorityUrl, MsalRequest msalRequest,
ServiceBundle serviceBundle) throws Exception {
ServiceBundle serviceBundle) {

String instanceDiscoveryRequestUrl = getInstanceDiscoveryEndpoint(authorityUrl.getAuthority()) +
INSTANCE_DISCOVERY_REQUEST_PARAMETERS_TEMPLATE.replace("{authorizeEndpoint}",
getAuthorizeEndpoint(authorityUrl.getAuthority(),
Authority.getTenant(authorityUrl, Authority.detectAuthorityType(authorityUrl))));

String json = HttpHelper.executeHttpRequest
(log, HttpMethod.GET, instanceDiscoveryRequestUrl, msalRequest.headers().getReadonlyHeaderMap(),
null, msalRequest.requestContext(), serviceBundle);
HttpRequest httpRequest = new HttpRequest(
HttpMethod.GET,
instanceDiscoveryRequestUrl,
msalRequest.headers().getReadonlyHeaderMap());

return JsonHelper.convertJsonToObject(json, InstanceDiscoveryResponse.class);
IHttpResponse httpResponse= HttpHelper.executeHttpRequest(httpRequest, msalRequest.requestContext(), serviceBundle);

return JsonHelper.convertJsonToObject(httpResponse.body(), InstanceDiscoveryResponse.class);
}

private static void validate(InstanceDiscoveryResponse instanceDiscoveryResponse) {
Expand Down
48 changes: 32 additions & 16 deletions src/main/java/com/microsoft/aad/msal4j/ClientApplicationBase.java
Original file line number Diff line number Diff line change
Expand Up @@ -57,15 +57,13 @@ abstract class ClientApplicationBase implements IClientApplicationBase {
@Getter(AccessLevel.PACKAGE)
private Consumer<List<HashMap<String, String>>> telemetryConsumer;

@Override
public Proxy proxy() {
return this.serviceBundle.getProxy();
}
@Accessors(fluent = true)
@Getter
public Proxy proxy;

@Override
public SSLSocketFactory sslSocketFactory() {
return this.serviceBundle.getSslSocketFactory();
}
@Accessors(fluent = true)
@Getter
public SSLSocketFactory sslSocketFactory;

@Accessors(fluent = true)
@Getter
Expand Down Expand Up @@ -159,18 +157,24 @@ AuthenticationResult acquireTokenCommon(MsalRequest msalRequest, Authority reque
headers.getHeaderCorrelationIdValue()));
}

TokenRequest request = new TokenRequest(requestAuthority, msalRequest, serviceBundle);
TokenRequestExecutor requestExecutor = new TokenRequestExecutor(
requestAuthority,
msalRequest,
serviceBundle);

AuthenticationResult result = request.executeOauthRequestAndProcessResponse();
AuthenticationResult result = requestExecutor.executeTokenRequest();

if(authenticationAuthority.authorityType.equals(AuthorityType.B2C)){
tokenCache.saveTokens(request, result, authenticationAuthority.host);
tokenCache.saveTokens(requestExecutor, result, authenticationAuthority.host);
} else {
InstanceDiscoveryMetadataEntry instanceDiscoveryMetadata =
AadInstanceDiscovery.GetMetadataEntry
(requestAuthority.canonicalAuthorityUrl(), validateAuthority, msalRequest, serviceBundle);
AadInstanceDiscovery.GetMetadataEntry(
requestAuthority.canonicalAuthorityUrl(),
validateAuthority,
msalRequest,
serviceBundle);

tokenCache.saveTokens(request, result, instanceDiscoveryMetadata.preferredCache);
tokenCache.saveTokens(requestExecutor, result, instanceDiscoveryMetadata.preferredCache);
}

return result;
Expand Down Expand Up @@ -226,6 +230,7 @@ abstract static class Builder<T extends Builder<T>> {
private ExecutorService executorService;
private Proxy proxy;
private SSLSocketFactory sslSocketFactory;
private IHttpClient httpClient;
private Consumer<List<HashMap<String, String>>> telemetryConsumer;
private Boolean onlySendFailureTelemetry = false;
private ITokenCacheAccessAspect tokenCacheAccessAspect;
Expand Down Expand Up @@ -344,6 +349,14 @@ public T proxy(Proxy val) {
return self();
}


public T httpClient(IHttpClient val){
validateNotNull("httpClient", val);

httpClient = val;
return self();
}

/**
* Sets SSLSocketFactory to be used by the client application for all network communication.
*
Expand Down Expand Up @@ -403,10 +416,13 @@ private static Authority createDefaultAADAuthority() {
correlationId = builder.correlationId;
logPii = builder.logPii;
telemetryConsumer = builder.telemetryConsumer;
proxy = builder.proxy;
sslSocketFactory = builder.sslSocketFactory;
serviceBundle = new ServiceBundle(
builder.executorService,
builder.proxy,
builder.sslSocketFactory,
builder.httpClient == null ?
new DefaultHttpClient(builder.proxy, builder.sslSocketFactory) :
builder.httpClient,
new TelemetryManager(telemetryConsumer, builder.onlySendFailureTelemetry));
authenticationAuthority = builder.authenticationAuthority;
tokenCache = new TokenCache(builder.tokenCacheAccessAspect);
Expand Down
Loading