Skip to content

Commit d1096c3

Browse files
committed
Removed async execution of HeadBucket and attached it to the completableFuture of main request
1 parent a712820 commit d1096c3

File tree

5 files changed

+111
-52
lines changed

5 files changed

+111
-52
lines changed

services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/crossregion/S3CrossRegionAsyncClient.java

Lines changed: 73 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -24,21 +24,18 @@
2424
import java.util.concurrent.CompletableFuture;
2525
import java.util.concurrent.ConcurrentHashMap;
2626
import java.util.function.Function;
27-
import java.util.function.Supplier;
2827
import software.amazon.awssdk.annotations.SdkInternalApi;
29-
import software.amazon.awssdk.awscore.exception.AwsServiceException;
3028
import software.amazon.awssdk.regions.Region;
3129
import software.amazon.awssdk.services.s3.DelegatingS3AsyncClient;
3230
import software.amazon.awssdk.services.s3.S3AsyncClient;
33-
import software.amazon.awssdk.services.s3.model.HeadBucketRequest;
3431
import software.amazon.awssdk.services.s3.model.S3Exception;
3532
import software.amazon.awssdk.services.s3.model.S3Request;
3633
import software.amazon.awssdk.utils.CompletableFutureUtils;
3734

3835
@SdkInternalApi
3936
public final class S3CrossRegionAsyncClient extends DelegatingS3AsyncClient {
4037

41-
private final Map<String, CompletableFuture<Region>> bucketToRegionCache = new ConcurrentHashMap<>();
38+
private final Map<String, Region> bucketToRegionCache = new ConcurrentHashMap<>();
4239

4340
public S3CrossRegionAsyncClient(S3AsyncClient s3Client) {
4441
super(s3Client);
@@ -57,46 +54,85 @@ protected <T extends S3Request, ReturnT> CompletableFuture<ReturnT> invokeOperat
5754

5855
if (bucketToRegionCache.containsKey(bucketName)) {
5956
return operation.apply(requestWithDecoratedEndpointProvider(request,
60-
regionSupplier(bucketName),
57+
() -> bucketToRegionCache.get(bucketName),
6158
serviceClientConfiguration().endpointProvider().get()));
6259
}
63-
return operation.apply(request).thenApply(CompletableFuture::completedFuture)
64-
.exceptionally(exception -> {
65-
if (isS3RedirectException(exception.getCause())) {
66-
bucketToRegionCache.remove(bucketName);
67-
getBucketRegionFromException((S3Exception) exception.getCause())
68-
.ifPresent(
69-
region -> bucketToRegionCache.put(bucketName,
70-
CompletableFuture.completedFuture(Region.of(region))));
71-
return operation.apply(
72-
requestWithDecoratedEndpointProvider(request,
73-
regionSupplier(bucketName),
74-
serviceClientConfiguration().endpointProvider().get()));
75-
}
76-
return CompletableFutureUtils.failedFuture(exception);
77-
}).thenCompose(Function.identity());
78-
}
7960

61+
CompletableFuture<ReturnT> returnFuture = new CompletableFuture<>();
62+
operation.apply(request)
63+
.whenComplete((r, t) -> {
64+
if (t != null) {
65+
if (isS3RedirectException(t.getCause())) {
66+
bucketToRegionCache.remove(bucketName);
67+
Optional<String> bucketRegionFromException =
68+
getBucketRegionFromException((S3Exception) t.getCause());
8069

81-
private Supplier<Region> regionSupplier(String bucket) {
82-
CompletableFuture<Region> completableFuture = bucketToRegionCache.computeIfAbsent(bucket, this::regionCompletableFuture);
83-
return () -> completableFuture.join();
70+
if (bucketRegionFromException.isPresent()) {
71+
sendRequestWithRightRegion(request, operation, bucketName, returnFuture,
72+
bucketRegionFromException);
73+
} else {
74+
fetchRegionAndSendRequest(request, operation, bucketName, returnFuture);
75+
}
76+
return;
77+
}
78+
returnFuture.completeExceptionally(t);
79+
return;
80+
}
81+
returnFuture.complete(r);
82+
});
83+
return returnFuture;
8484
}
8585

86-
private CompletableFuture<Region> regionCompletableFuture(String bucketName) {
87-
return CompletableFuture.supplyAsync(() -> {
88-
try {
89-
((S3AsyncClient) delegate()).headBucket(HeadBucketRequest.builder().bucket(bucketName).build()).join();
90-
} catch (Exception exception) {
91-
if (isS3RedirectException(exception.getCause())) {
92-
String region = getBucketRegionFromException((S3Exception) exception.getCause())
93-
.orElseThrow(() -> AwsServiceException.create("Region name not found in Redirect error",
94-
exception));
95-
return Region.of(region);
86+
private <T extends S3Request, ReturnT> void fetchRegionAndSendRequest(T request,
87+
Function<T, CompletableFuture<ReturnT>> operation,
88+
String bucketName,
89+
CompletableFuture<ReturnT> returnFuture) {
90+
91+
// // TODO: will fix the casts with separate PR
92+
((S3AsyncClient) delegate()).headBucket(b -> b.bucket(bucketName)).whenComplete((response,
93+
throwable) -> {
94+
if (throwable != null) {
95+
if (isS3RedirectException(throwable.getCause())) {
96+
bucketToRegionCache.remove(bucketName);
97+
Optional<String> bucketRegion = getBucketRegionFromException((S3Exception) throwable.getCause());
98+
99+
if (bucketRegion.isPresent()) {
100+
bucketToRegionCache.put(bucketName, Region.of(bucketRegion.get()));
101+
sendRequestWithRightRegion(request, operation, bucketName, returnFuture, bucketRegion);
102+
} else {
103+
returnFuture.completeExceptionally(throwable);
104+
}
105+
} else {
106+
returnFuture.completeExceptionally(throwable);
96107
}
97-
throw exception;
108+
} else {
109+
CompletableFuture<ReturnT> newFuture = operation.apply(request);
110+
CompletableFutureUtils.forwardResultTo(newFuture, returnFuture);
111+
CompletableFutureUtils.forwardExceptionTo(returnFuture, newFuture);
98112
}
99-
return ((S3AsyncClient) delegate()).serviceClientConfiguration().region();
100113
});
101114
}
102-
}
115+
116+
private <T extends S3Request, ReturnT> void sendRequestWithRightRegion(T request,
117+
Function<T, CompletableFuture<ReturnT>> operation,
118+
String bucketName,
119+
CompletableFuture<ReturnT> returnFuture,
120+
Optional<String> bucketRegionFromException) {
121+
String region = bucketRegionFromException.get();
122+
bucketToRegionCache.put(bucketName, Region.of(region));
123+
doSendRequestWithRightRegion(request, operation, returnFuture, region);
124+
}
125+
126+
private <T extends S3Request, ReturnT> void doSendRequestWithRightRegion(T request,
127+
Function<T, CompletableFuture<ReturnT>> operation,
128+
CompletableFuture<ReturnT> returnFuture,
129+
String region) {
130+
CompletableFuture<ReturnT> newFuture = operation.apply(
131+
requestWithDecoratedEndpointProvider(request,
132+
() -> Region.of(region),
133+
serviceClientConfiguration().endpointProvider().get()));
134+
CompletableFutureUtils.forwardResultTo(newFuture, returnFuture);
135+
// forward exception
136+
CompletableFutureUtils.forwardExceptionTo(returnFuture, newFuture);
137+
}
138+
}

services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/crossregion/S3CrossRegionSyncClient.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
import java.util.Optional;
2424
import java.util.concurrent.ConcurrentHashMap;
2525
import java.util.function.Function;
26-
import java.util.function.Supplier;
2726
import software.amazon.awssdk.annotations.SdkInternalApi;
2827
import software.amazon.awssdk.regions.Region;
2928
import software.amazon.awssdk.services.s3.DelegatingS3Client;

services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/crossregion/S3CrossRegionAsyncClientRedirectTest.java

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ public void setup() {
4949
@Override
5050
protected void stubRedirectSuccessSuccess() {
5151
when(mockDelegateAsyncClient.listObjects(any(ListObjectsRequest.class)))
52-
.thenReturn(CompletableFutureUtils.failedFuture(redirectException(301, CROSS_REGION, null, null)))
52+
.thenReturn(CompletableFutureUtils.failedFuture(new CompletionException(redirectException(301, CROSS_REGION, null, null))))
5353
.thenReturn(CompletableFuture.completedFuture(ListObjectsResponse.builder().contents(S3_OBJECTS).build()))
5454
.thenReturn(CompletableFuture.completedFuture(ListObjectsResponse.builder().contents(S3_OBJECTS).build()));
5555
}
@@ -76,7 +76,8 @@ protected void stubServiceClientConfiguration() {
7676
@Override
7777
protected void stubClientAPICallWithFirstRedirectThenSuccessWithRegionInErrorResponse() {
7878
when(mockDelegateAsyncClient.listObjects(any(ListObjectsRequest.class)))
79-
.thenReturn(CompletableFutureUtils.failedFuture(redirectException(301, CROSS_REGION, null, null)))
79+
.thenReturn(CompletableFutureUtils.failedFuture(new CompletionException(redirectException(301, CROSS_REGION, null,
80+
null))))
8081
.thenReturn(CompletableFuture.completedFuture(ListObjectsResponse.builder().contents(S3_OBJECTS).build()
8182
));
8283
}
@@ -94,37 +95,40 @@ protected ListBucketsResponse noBucketCallToService() throws Throwable {
9495
@Override
9596
protected void stubApiWithNoBucketField() {
9697
when(mockDelegateAsyncClient.listBuckets(any(ListBucketsRequest.class)))
97-
.thenReturn(CompletableFutureUtils.failedFuture(redirectException(301, CROSS_REGION, null, "Redirect")))
98+
.thenReturn(CompletableFutureUtils.failedFuture(new CompletionException(redirectException(301, CROSS_REGION, null,
99+
"Redirect"))))
98100
.thenReturn(CompletableFuture.completedFuture(ListBucketsResponse.builder().build()
99101
));
100102
}
101103

102104
@Override
103105
protected void stubHeadBucketRedirect() {
104106
when(mockDelegateAsyncClient.headBucket(any(HeadBucketRequest.class)))
105-
.thenThrow(new CompletionException(redirectException(301,CROSS_REGION, null, null)));
107+
.thenReturn(CompletableFutureUtils.failedFuture(new CompletionException(redirectException(301,CROSS_REGION, null, null))));
106108
when(mockDelegateAsyncClient.headBucket(any(Consumer.class)))
107-
.thenThrow(new CompletionException(redirectException(301,CROSS_REGION, null, null)));
109+
.thenReturn(CompletableFutureUtils.failedFuture(new CompletionException(redirectException(301,CROSS_REGION, null, null))));
108110
}
109111

110112
@Override
111113
protected void stubRedirectWithNoRegionAndThenSuccess() {
112114
when(mockDelegateAsyncClient.listObjects(any(ListObjectsRequest.class)))
113-
.thenReturn(CompletableFutureUtils.failedFuture(redirectException(301, null, null, null)))
115+
.thenReturn(CompletableFutureUtils.failedFuture(new CompletionException(redirectException(301, null, null, null))))
114116
.thenReturn(CompletableFuture.completedFuture(ListObjectsResponse.builder().contents(S3_OBJECTS).build()))
115117
.thenReturn(CompletableFuture.completedFuture(ListObjectsResponse.builder().contents(S3_OBJECTS).build()));
116118
}
117119

118120
@Override
119121
protected void stubRedirectThenError() {
120122
when(mockDelegateAsyncClient.listObjects(any(ListObjectsRequest.class)))
121-
.thenReturn(CompletableFutureUtils.failedFuture(redirectException(301, CROSS_REGION, null, null)))
122-
.thenReturn(CompletableFutureUtils.failedFuture(redirectException(400, null, "InvalidArgument", "Invalid id")));
123+
.thenReturn(CompletableFutureUtils.failedFuture(new CompletionException(redirectException(301, CROSS_REGION, null,
124+
null))))
125+
.thenReturn(CompletableFutureUtils.failedFuture(new CompletionException(redirectException(400, null,
126+
"InvalidArgument", "Invalid id"))));
123127
}
124128

125129
@Override
126130
protected void verifyHeadBucketServiceCall(int times) {
127-
verify(mockDelegateAsyncClient, times(times)).headBucket(any(HeadBucketRequest.class));
131+
verify(mockDelegateAsyncClient, times(times)).headBucket(any(Consumer.class));
128132
}
129133

130134
@Override

services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/crossregion/S3CrossRegionAsyncClientTest.java

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import java.util.List;
2727
import java.util.concurrent.CompletableFuture;
2828
import java.util.concurrent.CompletionException;
29+
import java.util.concurrent.ExecutionException;
2930
import java.util.function.Consumer;
3031
import java.util.stream.Collectors;
3132
import java.util.stream.Stream;
@@ -34,6 +35,7 @@
3435
import org.junit.jupiter.params.ParameterizedTest;
3536
import org.junit.jupiter.params.provider.Arguments;
3637
import org.junit.jupiter.params.provider.MethodSource;
38+
import software.amazon.awssdk.core.ResponseBytes;
3739
import software.amazon.awssdk.core.async.AsyncResponseTransformer;
3840
import software.amazon.awssdk.core.exception.SdkClientException;
3941
import software.amazon.awssdk.core.interceptor.Context;
@@ -53,6 +55,7 @@
5355
import software.amazon.awssdk.services.s3.endpoints.internal.DefaultS3EndpointProvider;
5456
import software.amazon.awssdk.services.s3.internal.crossregion.endpointprovider.BucketEndpointProvider;
5557
import software.amazon.awssdk.services.s3.model.GetObjectRequest;
58+
import software.amazon.awssdk.services.s3.model.GetObjectResponse;
5659
import software.amazon.awssdk.services.s3.model.ListObjectsV2Response;
5760
import software.amazon.awssdk.services.s3.model.S3Exception;
5861
import software.amazon.awssdk.services.s3.paginators.ListObjectsV2Publisher;
@@ -139,7 +142,8 @@ void crossRegionClient_createdWithWrapping_SuccessfullyIntercepts(Consumer<MockA
139142
}
140143

141144
@Test
142-
void crossRegionClient_CallsHeadObject_when_regionNameNotPresentInFallBackCall() {
145+
void crossRegionClient_CallsHeadObject_when_regionNameNotPresentInFallBackCall(){
146+
mockAsyncHttpClient.reset();
143147
mockAsyncHttpClient.stubResponses(customHttpResponse(301, null ),
144148
customHttpResponse(301, CROSS_REGION ),
145149
successHttpResponse(), successHttpResponse());
@@ -188,7 +192,7 @@ void crossRegionClient_CallsHeadObjectErrors_shouldTerminateTheAPI() {
188192

189193
assertThatExceptionOfType(CompletionException.class)
190194
.isThrownBy(() -> crossRegionClient.getObject(r -> r.bucket(BUCKET).key(KEY), AsyncResponseTransformer.toBytes()).join())
191-
.withMessageContaining("Endpoint resolution failed");
195+
.withMessageContaining("software.amazon.awssdk.services.s3.model.S3Exception: null (Service: S3, Status Code: 400, Request ID: null)");
192196

193197
List<SdkHttpRequest> requests = mockAsyncHttpClient.getRequests();
194198
assertThat(requests).hasSize(2);
@@ -213,8 +217,8 @@ void crossRegionClient_CallsHeadObjectWithNoRegion_shouldTerminateHeadBucketAPI(
213217

214218
assertThatExceptionOfType(CompletionException.class)
215219
.isThrownBy(() -> crossRegionClient.getObject(r -> r.bucket(BUCKET).key(KEY), AsyncResponseTransformer.toBytes()).join())
216-
.withMessageContaining("Endpoint resolution failed")
217-
.withCauseInstanceOf(SdkClientException.class).withRootCauseExactlyInstanceOf(S3Exception.class);
220+
.withMessageContaining("software.amazon.awssdk.services.s3.model.S3Exception: null (Service: S3, Status Code: 301, Request ID: null)")
221+
.withCauseInstanceOf(S3Exception.class).withRootCauseExactlyInstanceOf(S3Exception.class);
218222

219223
List<SdkHttpRequest> requests = mockAsyncHttpClient.getRequests();
220224
assertThat(requests).hasSize(2);
@@ -228,6 +232,22 @@ void crossRegionClient_CallsHeadObjectWithNoRegion_shouldTerminateHeadBucketAPI(
228232
SdkHttpMethod.HEAD));
229233
}
230234

235+
236+
@Test
237+
void crossRegionClient_cancelsTheThread_when_futureIsCancelled(){
238+
mockAsyncHttpClient.reset();
239+
mockAsyncHttpClient.stubResponses(customHttpResponse(301, null ),
240+
customHttpResponse(301, CROSS_REGION ),
241+
successHttpResponse(), successHttpResponse());
242+
S3AsyncClient crossRegionClient =
243+
clientBuilder().endpointOverride(null).region(OVERRIDE_CONFIGURED_REGION).serviceConfiguration(c -> c.crossRegionAccessEnabled(true)).build();
244+
CompletableFuture<ResponseBytes<GetObjectResponse>> completableFuture = crossRegionClient.getObject(r -> r.bucket(BUCKET).key(KEY)
245+
, AsyncResponseTransformer.toBytes());
246+
247+
completableFuture.cancel(true);
248+
assertThat(completableFuture.isCancelled()).isTrue();
249+
}
250+
231251
private S3AsyncClientBuilder clientBuilder() {
232252
return S3AsyncClient.builder()
233253
.httpClient(mockAsyncHttpClient)

services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/crossregion/S3DecoratorRedirectTestBase.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ void apiCallFailure_when_CallFailsAfterRedirection() {
9999
stubRedirectThenError();
100100
assertThatExceptionOfType(S3Exception.class)
101101
.isThrownBy(() -> apiCallToService())
102-
.withMessage("Invalid id (Service: S3, Status Code: 400, Request ID: 1, Extended Request ID: A1)");
102+
.withMessageContaining("Invalid id (Service: S3, Status Code: 400, Request ID: 1, Extended Request ID: A1)");
103103
verifyHeadBucketServiceCall(0);
104104
}
105105

0 commit comments

Comments
 (0)