Skip to content

Adds cross region client logic for decorating endpoint provider #4026

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.function.Function;
import java.util.stream.Stream;
import software.amazon.awssdk.annotations.SdkPublicApi;
Expand Down Expand Up @@ -121,16 +122,19 @@ private MethodSpec invokeMethod() {

TypeVariableName responseTypeVariableName = STREAMING_TYPE_VARIABLE;

ParameterizedTypeName responseFutureTypeName = ParameterizedTypeName.get(ClassName.get(CompletableFuture.class),
responseTypeVariableName);

ParameterizedTypeName functionTypeName = ParameterizedTypeName
.get(ClassName.get(Function.class), requestTypeVariableName, responseTypeVariableName);
.get(ClassName.get(Function.class), requestTypeVariableName, responseFutureTypeName);

return MethodSpec.methodBuilder("invokeOperation")
.addModifiers(PROTECTED)
.addParameter(requestTypeVariableName, "request")
.addParameter(functionTypeName, "operation")
.addTypeVariable(requestTypeVariableName)
.addTypeVariable(responseTypeVariableName)
.returns(responseTypeVariableName)
.returns(responseFutureTypeName)
.addStatement("return operation.apply(request)")
.build();
}
Expand Down Expand Up @@ -213,12 +217,9 @@ protected MethodSpec.Builder paginatedMethodBody(MethodSpec.Builder builder, Ope
String methodName = PaginatorUtils.getPaginatedMethodName(opModel.getMethodName());
return builder.addModifiers(PUBLIC)
.addAnnotation(Override.class)
.addStatement("return invokeOperation($N, request -> delegate.$N(request))",
opModel.getInput().getVariableName(),
methodName);
.addStatement("return delegate.$N($N)", methodName, opModel.getInput().getVariableName());
}


@Override
protected MethodSpec.Builder utilitiesOperationBody(MethodSpec.Builder builder) {
return builder.addAnnotation(Override.class).addStatement("return delegate.$N()", UtilitiesMethod.METHOD_NAME);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,9 +148,7 @@ protected MethodSpec.Builder paginatedMethodBody(MethodSpec.Builder builder, Ope
String methodName = PaginatorUtils.getPaginatedMethodName(opModel.getMethodName());
return builder.addModifiers(PUBLIC)
.addAnnotation(Override.class)
.addStatement("return invokeOperation($N, request -> delegate.$N(request))",
opModel.getInput().getVariableName(),
methodName);
.addStatement("return delegate.$N($N)", methodName, opModel.getInput().getVariableName());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -409,8 +409,7 @@ public CompletableFuture<PaginatedOperationWithResultKeyResponse> paginatedOpera
@Override
public PaginatedOperationWithResultKeyPublisher paginatedOperationWithResultKeyPaginator(
PaginatedOperationWithResultKeyRequest paginatedOperationWithResultKeyRequest) {
return invokeOperation(paginatedOperationWithResultKeyRequest,
request -> delegate.paginatedOperationWithResultKeyPaginator(request));
return delegate.paginatedOperationWithResultKeyPaginator(paginatedOperationWithResultKeyRequest);
}

/**
Expand Down Expand Up @@ -515,8 +514,7 @@ public CompletableFuture<PaginatedOperationWithoutResultKeyResponse> paginatedOp
@Override
public PaginatedOperationWithoutResultKeyPublisher paginatedOperationWithoutResultKeyPaginator(
PaginatedOperationWithoutResultKeyRequest paginatedOperationWithoutResultKeyRequest) {
return invokeOperation(paginatedOperationWithoutResultKeyRequest,
request -> delegate.paginatedOperationWithoutResultKeyPaginator(request));
return delegate.paginatedOperationWithoutResultKeyPaginator(paginatedOperationWithoutResultKeyRequest);
}

/**
Expand Down Expand Up @@ -677,7 +675,8 @@ public SdkClient delegate() {
return this.delegate;
}

protected <T extends JsonRequest, ReturnT> ReturnT invokeOperation(T request, Function<T, ReturnT> operation) {
protected <T extends JsonRequest, ReturnT> CompletableFuture<ReturnT> invokeOperation(T request,
Function<T, CompletableFuture<ReturnT>> operation) {
return operation.apply(request);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -401,8 +401,7 @@ public PaginatedOperationWithResultKeyIterable paginatedOperationWithResultKeyPa
public PaginatedOperationWithResultKeyIterable paginatedOperationWithResultKeyPaginator(
PaginatedOperationWithResultKeyRequest paginatedOperationWithResultKeyRequest) throws AwsServiceException,
SdkClientException, JsonException {
return invokeOperation(paginatedOperationWithResultKeyRequest,
request -> delegate.paginatedOperationWithResultKeyPaginator(request));
return delegate.paginatedOperationWithResultKeyPaginator(paginatedOperationWithResultKeyRequest);
}

/**
Expand Down Expand Up @@ -504,8 +503,7 @@ public PaginatedOperationWithoutResultKeyResponse paginatedOperationWithoutResul
public PaginatedOperationWithoutResultKeyIterable paginatedOperationWithoutResultKeyPaginator(
PaginatedOperationWithoutResultKeyRequest paginatedOperationWithoutResultKeyRequest) throws AwsServiceException,
SdkClientException, JsonException {
return invokeOperation(paginatedOperationWithoutResultKeyRequest,
request -> delegate.paginatedOperationWithoutResultKeyPaginator(request));
return delegate.paginatedOperationWithoutResultKeyPaginator(paginatedOperationWithoutResultKeyRequest);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,32 +16,79 @@
package software.amazon.awssdk.services.s3.internal.crossregion;

import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.function.Function;
import software.amazon.awssdk.annotations.SdkInternalApi;
import software.amazon.awssdk.awscore.AwsRequestOverrideConfiguration;
import software.amazon.awssdk.endpoints.Endpoint;
import software.amazon.awssdk.services.s3.DelegatingS3AsyncClient;
import software.amazon.awssdk.services.s3.S3AsyncClient;
import software.amazon.awssdk.services.s3.endpoints.S3EndpointParams;
import software.amazon.awssdk.services.s3.endpoints.S3EndpointProvider;
import software.amazon.awssdk.services.s3.model.S3Request;

@SdkInternalApi
public final class S3CrossRegionAsyncClient extends DelegatingS3AsyncClient {

public S3CrossRegionAsyncClient(S3AsyncClient s3Client) {
super(s3Client);
}

@Override
protected <T extends S3Request, ReturnT> ReturnT invokeOperation(T request, Function<T, ReturnT> operation) {
Optional<String> bucket = request.getValueForField("bucket", String.class);
protected <T extends S3Request, ReturnT> CompletableFuture<ReturnT>
invokeOperation(T request, Function<T, CompletableFuture<ReturnT>> operation) {

Optional<String> bucket = request.getValueForField("Bucket", String.class);

if (!bucket.isPresent()) {
return operation.apply(request);
}

//TODO: add modifyRequest logic
return operation.apply(request);
return operation.apply(requestWithDecoratedEndpointProvider(request, bucket.get()))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does handleOperationFailure do? Should we just return the future from apply instead?

CompletableFuture<T> result = operation.apply(...);
result.whenComplete((r, t) -> handleOperationFailure(...));
return result;

The reason for this is that if we return a new future (from whenComplete), then we have to handle the case where the caller cancels the future; however "real" S3 client should already have set that up the future returned from apply

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another thought, if the result gets cancelled by the caller, then we should handle that case in whenComplete as well since we probably don't want to remove that cache entry in that case.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense. Added this comment to the caching task description.

.whenComplete((r, t) -> handleOperationFailure(t, bucket.get()));
}

private void handleOperationFailure(Throwable t, String bucket) {
//TODO: handle failure case
}

//Cannot avoid unchecked cast without upstream changes to supply builder function
@SuppressWarnings("unchecked")
private <T extends S3Request> T requestWithDecoratedEndpointProvider(T request, String bucket) {
return (T) request.toBuilder()
.overrideConfiguration(getOrCreateConfigWithEndpointProvider(request, bucket))
.build();
}

//TODO: optimize shared sync/async code
private AwsRequestOverrideConfiguration getOrCreateConfigWithEndpointProvider(S3Request request, String bucket) {
AwsRequestOverrideConfiguration requestOverrideConfig =
request.overrideConfiguration().orElseGet(() -> AwsRequestOverrideConfiguration.builder().build());

S3EndpointProvider delegateEndpointProvider = (S3EndpointProvider)
requestOverrideConfig.endpointProvider().orElseGet(() -> serviceClientConfiguration().endpointProvider().get());

return requestOverrideConfig.toBuilder()
.endpointProvider(BucketEndpointProvider.create(delegateEndpointProvider, bucket))
.build();
}

//TODO: add cross region logic
static final class BucketEndpointProvider implements S3EndpointProvider {
private final S3EndpointProvider delegate;
private final String bucket;

private BucketEndpointProvider(S3EndpointProvider delegate, String bucket) {
this.delegate = delegate;
this.bucket = bucket;
}

public static BucketEndpointProvider create(S3EndpointProvider delegate, String bucket) {
return new BucketEndpointProvider(delegate, bucket);
}

@Override
public CompletableFuture<Endpoint> resolveEndpoint(S3EndpointParams endpointParams) {
return delegate.resolveEndpoint(endpointParams);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,26 +16,31 @@
package software.amazon.awssdk.services.s3.internal.crossregion;

import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.function.Function;
import software.amazon.awssdk.annotations.SdkInternalApi;
import software.amazon.awssdk.awscore.AwsRequestOverrideConfiguration;
import software.amazon.awssdk.endpoints.Endpoint;
import software.amazon.awssdk.services.s3.DelegatingS3Client;
import software.amazon.awssdk.services.s3.S3Client;
import software.amazon.awssdk.services.s3.endpoints.S3EndpointParams;
import software.amazon.awssdk.services.s3.endpoints.S3EndpointProvider;
import software.amazon.awssdk.services.s3.model.S3Request;

@SdkInternalApi
public final class S3CrossRegionSyncClient extends DelegatingS3Client {

public S3CrossRegionSyncClient(S3Client s3Client) {
super(s3Client);
}

@Override
protected <T extends S3Request, ReturnT> ReturnT invokeOperation(T request, Function<T, ReturnT> operation) {
Optional<String> bucket = request.getValueForField("bucket", String.class);

Optional<String> bucket = request.getValueForField("Bucket", String.class);

if (bucket.isPresent()) {
try {
operation.apply(request); //TODO: add modifyRequest logic
return operation.apply(requestWithDecoratedEndpointProvider(request, bucket.get()));
} catch (Exception e) {
handleOperationFailure(e, bucket.get());
}
Expand All @@ -47,4 +52,43 @@ protected <T extends S3Request, ReturnT> ReturnT invokeOperation(T request, Func
private void handleOperationFailure(Throwable t, String bucket) {
//TODO: handle failure case
}

@SuppressWarnings("unchecked")
private <T extends S3Request> T requestWithDecoratedEndpointProvider(T request, String bucket) {
return (T) request.toBuilder()
.overrideConfiguration(getOrCreateConfigWithEndpointProvider(request, bucket))
.build();
}

//TODO: optimize shared sync/async code
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we have a backlog task in case we forget?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't think we would since it's quite annoying and I'm pretty sure we wouldn't forget, but just in case I added a task.

private AwsRequestOverrideConfiguration getOrCreateConfigWithEndpointProvider(S3Request request, String bucket) {
AwsRequestOverrideConfiguration requestOverrideConfig =
request.overrideConfiguration().orElseGet(() -> AwsRequestOverrideConfiguration.builder().build());

S3EndpointProvider delegateEndpointProvider = (S3EndpointProvider)
requestOverrideConfig.endpointProvider().orElseGet(() -> serviceClientConfiguration().endpointProvider().get());

return requestOverrideConfig.toBuilder()
.endpointProvider(BucketEndpointProvider.create(delegateEndpointProvider, bucket))
.build();
}

static final class BucketEndpointProvider implements S3EndpointProvider {
private final S3EndpointProvider delegate;
private final String bucket;

private BucketEndpointProvider(S3EndpointProvider delegate, String bucket) {
this.delegate = delegate;
this.bucket = bucket;
}

public static BucketEndpointProvider create(S3EndpointProvider delegate, String bucket) {
return new BucketEndpointProvider(delegate, bucket);
}

@Override
public CompletableFuture<Endpoint> resolveEndpoint(S3EndpointParams endpointParams) {
return delegate.resolveEndpoint(endpointParams);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at
*
* http://aws.amazon.com/apache2.0
*
* or in the "license" file accompanying this file. This file is distributed
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*/

package software.amazon.awssdk.services.s3.internal.crossregion;

import static org.assertj.core.api.Assertions.assertThat;

import java.net.URI;
import java.util.concurrent.CompletableFuture;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import software.amazon.awssdk.core.async.AsyncResponseTransformer;
import software.amazon.awssdk.core.interceptor.Context;
import software.amazon.awssdk.core.interceptor.ExecutionAttributes;
import software.amazon.awssdk.core.interceptor.ExecutionInterceptor;
import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute;
import software.amazon.awssdk.endpoints.EndpointProvider;
import software.amazon.awssdk.http.AbortableInputStream;
import software.amazon.awssdk.http.HttpExecuteResponse;
import software.amazon.awssdk.http.SdkHttpResponse;
import software.amazon.awssdk.services.s3.S3AsyncClient;
import software.amazon.awssdk.services.s3.S3Client;
import software.amazon.awssdk.services.s3.endpoints.internal.DefaultS3EndpointProvider;
import software.amazon.awssdk.services.s3.model.GetObjectRequest;
import software.amazon.awssdk.services.s3.model.ListObjectsV2Response;
import software.amazon.awssdk.services.s3.paginators.ListObjectsV2Iterable;
import software.amazon.awssdk.services.s3.paginators.ListObjectsV2Publisher;
import software.amazon.awssdk.testutils.service.http.MockAsyncHttpClient;
import software.amazon.awssdk.utils.StringInputStream;

class S3CrossRegionAsyncClientTest {

private static final String RESPONSE = "<Res>response</Res>";
private static final String BUCKET = "bucket";
private static final String KEY = "key";
private static final String TOKEN = "token";

private final MockAsyncHttpClient mockAsyncHttpClient = new MockAsyncHttpClient();
private CaptureInterceptor captureInterceptor;
private S3AsyncClient s3Client;

@BeforeEach
public void before() {
mockAsyncHttpClient.stubNextResponse(
HttpExecuteResponse.builder()
.response(SdkHttpResponse.builder().statusCode(200).build())
.responseBody(AbortableInputStream.create(new StringInputStream(RESPONSE)))
.build());

captureInterceptor = new CaptureInterceptor();
s3Client = S3AsyncClient.builder()
.httpClient(mockAsyncHttpClient)
.endpointOverride(URI.create("http://localhost"))
.overrideConfiguration(c -> c.addExecutionInterceptor(captureInterceptor))
.build();
}

@Test
public void standardOp_crossRegionClient_noOverrideConfig_SuccessfullyIntercepts() {
S3AsyncClient crossRegionClient = new S3CrossRegionAsyncClient(s3Client);
crossRegionClient.getObject(r -> r.bucket(BUCKET).key(KEY), AsyncResponseTransformer.toBytes());
assertThat(captureInterceptor.endpointProvider).isInstanceOf(S3CrossRegionAsyncClient.BucketEndpointProvider.class);
}

@Test
public void standardOp_crossRegionClient_existingOverrideConfig_SuccessfullyIntercepts() {
S3AsyncClient crossRegionClient = new S3CrossRegionAsyncClient(s3Client);
GetObjectRequest request = GetObjectRequest.builder()
.bucket(BUCKET)
.key(KEY)
.overrideConfiguration(o -> o.putHeader("someheader", "somevalue"))
.build();
crossRegionClient.getObject(request, AsyncResponseTransformer.toBytes());
assertThat(captureInterceptor.endpointProvider).isInstanceOf(S3CrossRegionAsyncClient.BucketEndpointProvider.class);
assertThat(mockAsyncHttpClient.getLastRequest().headers().get("someheader")).isNotNull();
}

//TODO: handle paginated calls - the paginated publisher calls should also be decorated
@Test
public void paginatedOp_crossRegionClient_DoesNotIntercept() throws Exception {
S3AsyncClient crossRegionClient = new S3CrossRegionAsyncClient(s3Client);
ListObjectsV2Publisher publisher =
crossRegionClient.listObjectsV2Paginator(r -> r.bucket(BUCKET).continuationToken(TOKEN).build());
CompletableFuture<Void> future = publisher.subscribe(ListObjectsV2Response::contents);
future.get();
assertThat(captureInterceptor.endpointProvider).isInstanceOf(DefaultS3EndpointProvider.class);
}

private static final class CaptureInterceptor implements ExecutionInterceptor {

private EndpointProvider endpointProvider;

@Override
public void beforeMarshalling(Context.BeforeMarshalling context, ExecutionAttributes executionAttributes) {
endpointProvider = executionAttributes.getAttribute(SdkInternalExecutionAttribute.ENDPOINT_PROVIDER);
}
}
}
Loading