Skip to content

Commit 2a15452

Browse files
authored
Adds cross region client logic for decorating endpoint provider (#4026)
* Adds cross region client logic for decorating endpoint provider
1 parent 8fcef4e commit 2a15452

File tree

11 files changed

+608
-27
lines changed

11 files changed

+608
-27
lines changed

codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/DelegatingAsyncClientClass.java

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
import java.util.ArrayList;
3434
import java.util.Comparator;
3535
import java.util.List;
36+
import java.util.concurrent.CompletableFuture;
3637
import java.util.function.Function;
3738
import java.util.stream.Stream;
3839
import software.amazon.awssdk.annotations.SdkPublicApi;
@@ -121,16 +122,19 @@ private MethodSpec invokeMethod() {
121122

122123
TypeVariableName responseTypeVariableName = STREAMING_TYPE_VARIABLE;
123124

125+
ParameterizedTypeName responseFutureTypeName = ParameterizedTypeName.get(ClassName.get(CompletableFuture.class),
126+
responseTypeVariableName);
127+
124128
ParameterizedTypeName functionTypeName = ParameterizedTypeName
125-
.get(ClassName.get(Function.class), requestTypeVariableName, responseTypeVariableName);
129+
.get(ClassName.get(Function.class), requestTypeVariableName, responseFutureTypeName);
126130

127131
return MethodSpec.methodBuilder("invokeOperation")
128132
.addModifiers(PROTECTED)
129133
.addParameter(requestTypeVariableName, "request")
130134
.addParameter(functionTypeName, "operation")
131135
.addTypeVariable(requestTypeVariableName)
132136
.addTypeVariable(responseTypeVariableName)
133-
.returns(responseTypeVariableName)
137+
.returns(responseFutureTypeName)
134138
.addStatement("return operation.apply(request)")
135139
.build();
136140
}
@@ -213,12 +217,9 @@ protected MethodSpec.Builder paginatedMethodBody(MethodSpec.Builder builder, Ope
213217
String methodName = PaginatorUtils.getPaginatedMethodName(opModel.getMethodName());
214218
return builder.addModifiers(PUBLIC)
215219
.addAnnotation(Override.class)
216-
.addStatement("return invokeOperation($N, request -> delegate.$N(request))",
217-
opModel.getInput().getVariableName(),
218-
methodName);
220+
.addStatement("return delegate.$N($N)", methodName, opModel.getInput().getVariableName());
219221
}
220222

221-
222223
@Override
223224
protected MethodSpec.Builder utilitiesOperationBody(MethodSpec.Builder builder) {
224225
return builder.addAnnotation(Override.class).addStatement("return delegate.$N()", UtilitiesMethod.METHOD_NAME);

codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/DelegatingSyncClientClass.java

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -148,9 +148,7 @@ protected MethodSpec.Builder paginatedMethodBody(MethodSpec.Builder builder, Ope
148148
String methodName = PaginatorUtils.getPaginatedMethodName(opModel.getMethodName());
149149
return builder.addModifiers(PUBLIC)
150150
.addAnnotation(Override.class)
151-
.addStatement("return invokeOperation($N, request -> delegate.$N(request))",
152-
opModel.getInput().getVariableName(),
153-
methodName);
151+
.addStatement("return delegate.$N($N)", methodName, opModel.getInput().getVariableName());
154152
}
155153

156154
@Override

codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-abstract-async-client-class.java

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -409,8 +409,7 @@ public CompletableFuture<PaginatedOperationWithResultKeyResponse> paginatedOpera
409409
@Override
410410
public PaginatedOperationWithResultKeyPublisher paginatedOperationWithResultKeyPaginator(
411411
PaginatedOperationWithResultKeyRequest paginatedOperationWithResultKeyRequest) {
412-
return invokeOperation(paginatedOperationWithResultKeyRequest,
413-
request -> delegate.paginatedOperationWithResultKeyPaginator(request));
412+
return delegate.paginatedOperationWithResultKeyPaginator(paginatedOperationWithResultKeyRequest);
414413
}
415414

416415
/**
@@ -515,8 +514,7 @@ public CompletableFuture<PaginatedOperationWithoutResultKeyResponse> paginatedOp
515514
@Override
516515
public PaginatedOperationWithoutResultKeyPublisher paginatedOperationWithoutResultKeyPaginator(
517516
PaginatedOperationWithoutResultKeyRequest paginatedOperationWithoutResultKeyRequest) {
518-
return invokeOperation(paginatedOperationWithoutResultKeyRequest,
519-
request -> delegate.paginatedOperationWithoutResultKeyPaginator(request));
517+
return delegate.paginatedOperationWithoutResultKeyPaginator(paginatedOperationWithoutResultKeyRequest);
520518
}
521519

522520
/**
@@ -677,7 +675,8 @@ public SdkClient delegate() {
677675
return this.delegate;
678676
}
679677

680-
protected <T extends JsonRequest, ReturnT> ReturnT invokeOperation(T request, Function<T, ReturnT> operation) {
678+
protected <T extends JsonRequest, ReturnT> CompletableFuture<ReturnT> invokeOperation(T request,
679+
Function<T, CompletableFuture<ReturnT>> operation) {
681680
return operation.apply(request);
682681
}
683682

codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-abstract-sync-client-class.java

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -401,8 +401,7 @@ public PaginatedOperationWithResultKeyIterable paginatedOperationWithResultKeyPa
401401
public PaginatedOperationWithResultKeyIterable paginatedOperationWithResultKeyPaginator(
402402
PaginatedOperationWithResultKeyRequest paginatedOperationWithResultKeyRequest) throws AwsServiceException,
403403
SdkClientException, JsonException {
404-
return invokeOperation(paginatedOperationWithResultKeyRequest,
405-
request -> delegate.paginatedOperationWithResultKeyPaginator(request));
404+
return delegate.paginatedOperationWithResultKeyPaginator(paginatedOperationWithResultKeyRequest);
406405
}
407406

408407
/**
@@ -504,8 +503,7 @@ public PaginatedOperationWithoutResultKeyResponse paginatedOperationWithoutResul
504503
public PaginatedOperationWithoutResultKeyIterable paginatedOperationWithoutResultKeyPaginator(
505504
PaginatedOperationWithoutResultKeyRequest paginatedOperationWithoutResultKeyRequest) throws AwsServiceException,
506505
SdkClientException, JsonException {
507-
return invokeOperation(paginatedOperationWithoutResultKeyRequest,
508-
request -> delegate.paginatedOperationWithoutResultKeyPaginator(request));
506+
return delegate.paginatedOperationWithoutResultKeyPaginator(paginatedOperationWithoutResultKeyRequest);
509507
}
510508

511509
/**

services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/crossregion/S3CrossRegionAsyncClient.java

Lines changed: 52 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,32 +16,79 @@
1616
package software.amazon.awssdk.services.s3.internal.crossregion;
1717

1818
import java.util.Optional;
19+
import java.util.concurrent.CompletableFuture;
1920
import java.util.function.Function;
2021
import software.amazon.awssdk.annotations.SdkInternalApi;
22+
import software.amazon.awssdk.awscore.AwsRequestOverrideConfiguration;
23+
import software.amazon.awssdk.endpoints.Endpoint;
2124
import software.amazon.awssdk.services.s3.DelegatingS3AsyncClient;
2225
import software.amazon.awssdk.services.s3.S3AsyncClient;
26+
import software.amazon.awssdk.services.s3.endpoints.S3EndpointParams;
27+
import software.amazon.awssdk.services.s3.endpoints.S3EndpointProvider;
2328
import software.amazon.awssdk.services.s3.model.S3Request;
2429

2530
@SdkInternalApi
2631
public final class S3CrossRegionAsyncClient extends DelegatingS3AsyncClient {
27-
2832
public S3CrossRegionAsyncClient(S3AsyncClient s3Client) {
2933
super(s3Client);
3034
}
3135

3236
@Override
33-
protected <T extends S3Request, ReturnT> ReturnT invokeOperation(T request, Function<T, ReturnT> operation) {
34-
Optional<String> bucket = request.getValueForField("bucket", String.class);
37+
protected <T extends S3Request, ReturnT> CompletableFuture<ReturnT>
38+
invokeOperation(T request, Function<T, CompletableFuture<ReturnT>> operation) {
39+
40+
Optional<String> bucket = request.getValueForField("Bucket", String.class);
3541

3642
if (!bucket.isPresent()) {
3743
return operation.apply(request);
3844
}
3945

40-
//TODO: add modifyRequest logic
41-
return operation.apply(request);
46+
return operation.apply(requestWithDecoratedEndpointProvider(request, bucket.get()))
47+
.whenComplete((r, t) -> handleOperationFailure(t, bucket.get()));
4248
}
4349

4450
private void handleOperationFailure(Throwable t, String bucket) {
4551
//TODO: handle failure case
4652
}
53+
54+
//Cannot avoid unchecked cast without upstream changes to supply builder function
55+
@SuppressWarnings("unchecked")
56+
private <T extends S3Request> T requestWithDecoratedEndpointProvider(T request, String bucket) {
57+
return (T) request.toBuilder()
58+
.overrideConfiguration(getOrCreateConfigWithEndpointProvider(request, bucket))
59+
.build();
60+
}
61+
62+
//TODO: optimize shared sync/async code
63+
private AwsRequestOverrideConfiguration getOrCreateConfigWithEndpointProvider(S3Request request, String bucket) {
64+
AwsRequestOverrideConfiguration requestOverrideConfig =
65+
request.overrideConfiguration().orElseGet(() -> AwsRequestOverrideConfiguration.builder().build());
66+
67+
S3EndpointProvider delegateEndpointProvider = (S3EndpointProvider)
68+
requestOverrideConfig.endpointProvider().orElseGet(() -> serviceClientConfiguration().endpointProvider().get());
69+
70+
return requestOverrideConfig.toBuilder()
71+
.endpointProvider(BucketEndpointProvider.create(delegateEndpointProvider, bucket))
72+
.build();
73+
}
74+
75+
//TODO: add cross region logic
76+
static final class BucketEndpointProvider implements S3EndpointProvider {
77+
private final S3EndpointProvider delegate;
78+
private final String bucket;
79+
80+
private BucketEndpointProvider(S3EndpointProvider delegate, String bucket) {
81+
this.delegate = delegate;
82+
this.bucket = bucket;
83+
}
84+
85+
public static BucketEndpointProvider create(S3EndpointProvider delegate, String bucket) {
86+
return new BucketEndpointProvider(delegate, bucket);
87+
}
88+
89+
@Override
90+
public CompletableFuture<Endpoint> resolveEndpoint(S3EndpointParams endpointParams) {
91+
return delegate.resolveEndpoint(endpointParams);
92+
}
93+
}
4794
}

services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/crossregion/S3CrossRegionSyncClient.java

Lines changed: 47 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,26 +16,31 @@
1616
package software.amazon.awssdk.services.s3.internal.crossregion;
1717

1818
import java.util.Optional;
19+
import java.util.concurrent.CompletableFuture;
1920
import java.util.function.Function;
2021
import software.amazon.awssdk.annotations.SdkInternalApi;
22+
import software.amazon.awssdk.awscore.AwsRequestOverrideConfiguration;
23+
import software.amazon.awssdk.endpoints.Endpoint;
2124
import software.amazon.awssdk.services.s3.DelegatingS3Client;
2225
import software.amazon.awssdk.services.s3.S3Client;
26+
import software.amazon.awssdk.services.s3.endpoints.S3EndpointParams;
27+
import software.amazon.awssdk.services.s3.endpoints.S3EndpointProvider;
2328
import software.amazon.awssdk.services.s3.model.S3Request;
2429

2530
@SdkInternalApi
2631
public final class S3CrossRegionSyncClient extends DelegatingS3Client {
27-
2832
public S3CrossRegionSyncClient(S3Client s3Client) {
2933
super(s3Client);
3034
}
3135

3236
@Override
3337
protected <T extends S3Request, ReturnT> ReturnT invokeOperation(T request, Function<T, ReturnT> operation) {
34-
Optional<String> bucket = request.getValueForField("bucket", String.class);
38+
39+
Optional<String> bucket = request.getValueForField("Bucket", String.class);
3540

3641
if (bucket.isPresent()) {
3742
try {
38-
operation.apply(request); //TODO: add modifyRequest logic
43+
return operation.apply(requestWithDecoratedEndpointProvider(request, bucket.get()));
3944
} catch (Exception e) {
4045
handleOperationFailure(e, bucket.get());
4146
}
@@ -47,4 +52,43 @@ protected <T extends S3Request, ReturnT> ReturnT invokeOperation(T request, Func
4752
private void handleOperationFailure(Throwable t, String bucket) {
4853
//TODO: handle failure case
4954
}
55+
56+
@SuppressWarnings("unchecked")
57+
private <T extends S3Request> T requestWithDecoratedEndpointProvider(T request, String bucket) {
58+
return (T) request.toBuilder()
59+
.overrideConfiguration(getOrCreateConfigWithEndpointProvider(request, bucket))
60+
.build();
61+
}
62+
63+
//TODO: optimize shared sync/async code
64+
private AwsRequestOverrideConfiguration getOrCreateConfigWithEndpointProvider(S3Request request, String bucket) {
65+
AwsRequestOverrideConfiguration requestOverrideConfig =
66+
request.overrideConfiguration().orElseGet(() -> AwsRequestOverrideConfiguration.builder().build());
67+
68+
S3EndpointProvider delegateEndpointProvider = (S3EndpointProvider)
69+
requestOverrideConfig.endpointProvider().orElseGet(() -> serviceClientConfiguration().endpointProvider().get());
70+
71+
return requestOverrideConfig.toBuilder()
72+
.endpointProvider(BucketEndpointProvider.create(delegateEndpointProvider, bucket))
73+
.build();
74+
}
75+
76+
static final class BucketEndpointProvider implements S3EndpointProvider {
77+
private final S3EndpointProvider delegate;
78+
private final String bucket;
79+
80+
private BucketEndpointProvider(S3EndpointProvider delegate, String bucket) {
81+
this.delegate = delegate;
82+
this.bucket = bucket;
83+
}
84+
85+
public static BucketEndpointProvider create(S3EndpointProvider delegate, String bucket) {
86+
return new BucketEndpointProvider(delegate, bucket);
87+
}
88+
89+
@Override
90+
public CompletableFuture<Endpoint> resolveEndpoint(S3EndpointParams endpointParams) {
91+
return delegate.resolveEndpoint(endpointParams);
92+
}
93+
}
5094
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
/*
2+
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License").
5+
* You may not use this file except in compliance with the License.
6+
* A copy of the License is located at
7+
*
8+
* http://aws.amazon.com/apache2.0
9+
*
10+
* or in the "license" file accompanying this file. This file is distributed
11+
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
12+
* express or implied. See the License for the specific language governing
13+
* permissions and limitations under the License.
14+
*/
15+
16+
package software.amazon.awssdk.services.s3.internal.crossregion;
17+
18+
import static org.assertj.core.api.Assertions.assertThat;
19+
20+
import java.net.URI;
21+
import java.util.concurrent.CompletableFuture;
22+
import org.junit.jupiter.api.BeforeEach;
23+
import org.junit.jupiter.api.Test;
24+
import software.amazon.awssdk.core.async.AsyncResponseTransformer;
25+
import software.amazon.awssdk.core.interceptor.Context;
26+
import software.amazon.awssdk.core.interceptor.ExecutionAttributes;
27+
import software.amazon.awssdk.core.interceptor.ExecutionInterceptor;
28+
import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute;
29+
import software.amazon.awssdk.endpoints.EndpointProvider;
30+
import software.amazon.awssdk.http.AbortableInputStream;
31+
import software.amazon.awssdk.http.HttpExecuteResponse;
32+
import software.amazon.awssdk.http.SdkHttpResponse;
33+
import software.amazon.awssdk.services.s3.S3AsyncClient;
34+
import software.amazon.awssdk.services.s3.S3Client;
35+
import software.amazon.awssdk.services.s3.endpoints.internal.DefaultS3EndpointProvider;
36+
import software.amazon.awssdk.services.s3.model.GetObjectRequest;
37+
import software.amazon.awssdk.services.s3.model.ListObjectsV2Response;
38+
import software.amazon.awssdk.services.s3.paginators.ListObjectsV2Iterable;
39+
import software.amazon.awssdk.services.s3.paginators.ListObjectsV2Publisher;
40+
import software.amazon.awssdk.testutils.service.http.MockAsyncHttpClient;
41+
import software.amazon.awssdk.utils.StringInputStream;
42+
43+
class S3CrossRegionAsyncClientTest {
44+
45+
private static final String RESPONSE = "<Res>response</Res>";
46+
private static final String BUCKET = "bucket";
47+
private static final String KEY = "key";
48+
private static final String TOKEN = "token";
49+
50+
private final MockAsyncHttpClient mockAsyncHttpClient = new MockAsyncHttpClient();
51+
private CaptureInterceptor captureInterceptor;
52+
private S3AsyncClient s3Client;
53+
54+
@BeforeEach
55+
public void before() {
56+
mockAsyncHttpClient.stubNextResponse(
57+
HttpExecuteResponse.builder()
58+
.response(SdkHttpResponse.builder().statusCode(200).build())
59+
.responseBody(AbortableInputStream.create(new StringInputStream(RESPONSE)))
60+
.build());
61+
62+
captureInterceptor = new CaptureInterceptor();
63+
s3Client = S3AsyncClient.builder()
64+
.httpClient(mockAsyncHttpClient)
65+
.endpointOverride(URI.create("http://localhost"))
66+
.overrideConfiguration(c -> c.addExecutionInterceptor(captureInterceptor))
67+
.build();
68+
}
69+
70+
@Test
71+
public void standardOp_crossRegionClient_noOverrideConfig_SuccessfullyIntercepts() {
72+
S3AsyncClient crossRegionClient = new S3CrossRegionAsyncClient(s3Client);
73+
crossRegionClient.getObject(r -> r.bucket(BUCKET).key(KEY), AsyncResponseTransformer.toBytes());
74+
assertThat(captureInterceptor.endpointProvider).isInstanceOf(S3CrossRegionAsyncClient.BucketEndpointProvider.class);
75+
}
76+
77+
@Test
78+
public void standardOp_crossRegionClient_existingOverrideConfig_SuccessfullyIntercepts() {
79+
S3AsyncClient crossRegionClient = new S3CrossRegionAsyncClient(s3Client);
80+
GetObjectRequest request = GetObjectRequest.builder()
81+
.bucket(BUCKET)
82+
.key(KEY)
83+
.overrideConfiguration(o -> o.putHeader("someheader", "somevalue"))
84+
.build();
85+
crossRegionClient.getObject(request, AsyncResponseTransformer.toBytes());
86+
assertThat(captureInterceptor.endpointProvider).isInstanceOf(S3CrossRegionAsyncClient.BucketEndpointProvider.class);
87+
assertThat(mockAsyncHttpClient.getLastRequest().headers().get("someheader")).isNotNull();
88+
}
89+
90+
//TODO: handle paginated calls - the paginated publisher calls should also be decorated
91+
@Test
92+
public void paginatedOp_crossRegionClient_DoesNotIntercept() throws Exception {
93+
S3AsyncClient crossRegionClient = new S3CrossRegionAsyncClient(s3Client);
94+
ListObjectsV2Publisher publisher =
95+
crossRegionClient.listObjectsV2Paginator(r -> r.bucket(BUCKET).continuationToken(TOKEN).build());
96+
CompletableFuture<Void> future = publisher.subscribe(ListObjectsV2Response::contents);
97+
future.get();
98+
assertThat(captureInterceptor.endpointProvider).isInstanceOf(DefaultS3EndpointProvider.class);
99+
}
100+
101+
private static final class CaptureInterceptor implements ExecutionInterceptor {
102+
103+
private EndpointProvider endpointProvider;
104+
105+
@Override
106+
public void beforeMarshalling(Context.BeforeMarshalling context, ExecutionAttributes executionAttributes) {
107+
endpointProvider = executionAttributes.getAttribute(SdkInternalExecutionAttribute.ENDPOINT_PROVIDER);
108+
}
109+
}
110+
}

0 commit comments

Comments
 (0)