Skip to content

Commit 04b6730

Browse files
authored
Remove identity join in async client for endpoint discovery (#3947)
1 parent fd683d8 commit 04b6730

File tree

3 files changed

+73
-51
lines changed

3 files changed

+73
-51
lines changed

codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/AsyncClientClass.java

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
import com.squareup.javapoet.ParameterizedTypeName;
3636
import com.squareup.javapoet.TypeName;
3737
import com.squareup.javapoet.TypeSpec;
38+
import com.squareup.javapoet.WildcardTypeName;
3839
import java.net.URI;
3940
import java.nio.ByteBuffer;
4041
import java.util.ArrayList;
@@ -80,6 +81,7 @@
8081
import software.amazon.awssdk.core.endpointdiscovery.EndpointDiscoveryRefreshCache;
8182
import software.amazon.awssdk.core.endpointdiscovery.EndpointDiscoveryRequest;
8283
import software.amazon.awssdk.core.metrics.CoreMetric;
84+
import software.amazon.awssdk.identity.spi.AwsCredentialsIdentity;
8385
import software.amazon.awssdk.metrics.MetricCollector;
8486
import software.amazon.awssdk.metrics.MetricPublisher;
8587
import software.amazon.awssdk.metrics.NoOpMetricCollector;
@@ -374,24 +376,29 @@ protected MethodSpec.Builder operationBody(MethodSpec.Builder builder, Operation
374376
}
375377
}
376378

377-
builder.addStatement("$T cachedEndpoint = null", URI.class);
379+
builder.addStatement("$T<$T> endpointFuture = $T.completedFuture(null)",
380+
CompletableFuture.class, URI.class, CompletableFuture.class);
378381
builder.beginControlFlow("if (endpointDiscoveryEnabled)");
379382

380-
builder.addCode("$T key = $N.overrideConfiguration()", String.class, opModel.getInput().getVariableName())
383+
ParameterizedTypeName identityFutureTypeName = ParameterizedTypeName.get(ClassName.get(CompletableFuture.class),
384+
WildcardTypeName.subtypeOf(AwsCredentialsIdentity.class));
385+
builder.addCode("$T identityFuture = $N.overrideConfiguration()", identityFutureTypeName,
386+
opModel.getInput().getVariableName())
381387
.addCode(" .flatMap($T::credentialsIdentityProvider)", AwsRequestOverrideConfiguration.class)
382388
.addCode(" .orElseGet(() -> clientConfiguration.option($T.CREDENTIALS_IDENTITY_PROVIDER))",
383389
AwsClientOption.class)
384-
// TODO: avoid join inside async
385-
.addCode(" .resolveIdentity().join().accessKeyId();");
390+
.addCode(" .resolveIdentity();");
386391

387-
builder.addCode("$1T endpointDiscoveryRequest = $1T.builder()", EndpointDiscoveryRequest.class)
388-
.addCode(" .required($L)", opModel.getInputShape().getEndpointDiscovery().isRequired())
389-
.addCode(" .defaultEndpoint(clientConfiguration.option($T.ENDPOINT))", SdkClientOption.class)
390-
.addCode(" .overrideConfiguration($N.overrideConfiguration().orElse(null))",
392+
builder.addCode("endpointFuture = identityFuture.thenApply(credentials -> {")
393+
.addCode(" $1T endpointDiscoveryRequest = $1T.builder()", EndpointDiscoveryRequest.class)
394+
.addCode(" .required($L)", opModel.getInputShape().getEndpointDiscovery().isRequired())
395+
.addCode(" .defaultEndpoint(clientConfiguration.option($T.ENDPOINT))", SdkClientOption.class)
396+
.addCode(" .overrideConfiguration($N.overrideConfiguration().orElse(null))",
391397
opModel.getInput().getVariableName())
392-
.addCode(" .build();");
398+
.addCode(" .build();")
399+
.addCode(" return endpointDiscoveryCache.get(credentials.accessKeyId(), endpointDiscoveryRequest);")
400+
.addCode("});");
393401

394-
builder.addStatement("cachedEndpoint = endpointDiscoveryCache.get(key, endpointDiscoveryRequest)");
395402
builder.endControlFlow();
396403
}
397404

codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/specs/JsonProtocolSpec.java

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -240,8 +240,9 @@ public CodeBlock asyncExecutionHandler(IntermediateModel intermediateModel, Oper
240240
: pojoResponseType;
241241
TypeName executeFutureValueType = executeFutureValueType(opModel, poetExtensions);
242242

243-
builder.add("\n\n$T<$T> executeFuture = clientHandler.execute(new $T<$T, $T>()\n",
244-
CompletableFuture.class, executeFutureValueType, ClientExecutionParams.class, requestType, responseType)
243+
builder.add("\n\n$T<$T> executeFuture = ", CompletableFuture.class, executeFutureValueType)
244+
.add(opModel.getEndpointDiscovery() != null ? "endpointFuture.thenCompose(cachedEndpoint -> " : "")
245+
.add("clientHandler.execute(new $T<$T, $T>()\n", ClientExecutionParams.class, requestType, responseType)
245246
.add(".withOperationName(\"$N\")\n", opModel.getOperationName())
246247
.add(".withMarshaller($L)\n", asyncMarshaller(model, opModel, marshaller, protocolFactory))
247248
.add(asyncRequestBody(opModel))
@@ -257,8 +258,9 @@ public CodeBlock asyncExecutionHandler(IntermediateModel intermediateModel, Oper
257258
.add(HttpChecksumRequiredTrait.putHttpChecksumAttribute(opModel))
258259
.add(HttpChecksumTrait.create(opModel))
259260
.add(NoneAuthTypeRequestTrait.create(opModel))
260-
.add(".withInput($L)$L);",
261-
opModel.getInput().getVariableName(), asyncResponseTransformerVariable(isStreaming, isRestJson, opModel));
261+
.add(".withInput($L)$L)",
262+
opModel.getInput().getVariableName(), asyncResponseTransformerVariable(isStreaming, isRestJson, opModel))
263+
.add(opModel.getEndpointDiscovery() != null ? ");" : ";");
262264

263265
if (opModel.hasStreamingOutput()) {
264266
builder.addStatement("$T<$T, ReturnT> finalAsyncResponseTransformer = asyncResponseTransformer",

codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-endpoint-discovery-async.java

Lines changed: 50 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import software.amazon.awssdk.core.endpointdiscovery.EndpointDiscoveryRequest;
2424
import software.amazon.awssdk.core.http.HttpResponseHandler;
2525
import software.amazon.awssdk.core.metrics.CoreMetric;
26+
import software.amazon.awssdk.identity.spi.AwsCredentialsIdentity;
2627
import software.amazon.awssdk.metrics.MetricCollector;
2728
import software.amazon.awssdk.metrics.MetricPublisher;
2829
import software.amazon.awssdk.metrics.NoOpMetricCollector;
@@ -175,26 +176,30 @@ public CompletableFuture<TestDiscoveryIdentifiersRequiredResponse> testDiscovery
175176
throw new IllegalStateException(
176177
"This operation requires endpoint discovery, but endpoint discovery was disabled on the client.");
177178
}
178-
URI cachedEndpoint = null;
179+
CompletableFuture<URI> endpointFuture = CompletableFuture.completedFuture(null);
179180
if (endpointDiscoveryEnabled) {
180-
String key = testDiscoveryIdentifiersRequiredRequest.overrideConfiguration()
181-
.flatMap(AwsRequestOverrideConfiguration::credentialsIdentityProvider)
182-
.orElseGet(() -> clientConfiguration.option(AwsClientOption.CREDENTIALS_IDENTITY_PROVIDER))
183-
.resolveIdentity().join().accessKeyId();
184-
EndpointDiscoveryRequest endpointDiscoveryRequest = EndpointDiscoveryRequest.builder().required(true)
185-
.defaultEndpoint(clientConfiguration.option(SdkClientOption.ENDPOINT))
186-
.overrideConfiguration(testDiscoveryIdentifiersRequiredRequest.overrideConfiguration().orElse(null))
187-
.build();
188-
cachedEndpoint = endpointDiscoveryCache.get(key, endpointDiscoveryRequest);
181+
CompletableFuture<? extends AwsCredentialsIdentity> identityFuture =
182+
testDiscoveryIdentifiersRequiredRequest.overrideConfiguration()
183+
.flatMap(AwsRequestOverrideConfiguration::credentialsIdentityProvider)
184+
.orElseGet(() -> clientConfiguration.option(AwsClientOption.CREDENTIALS_IDENTITY_PROVIDER))
185+
.resolveIdentity();
186+
endpointFuture = identityFuture.thenApply(credentials -> {
187+
EndpointDiscoveryRequest endpointDiscoveryRequest = EndpointDiscoveryRequest.builder().required(true)
188+
.defaultEndpoint(clientConfiguration.option(SdkClientOption.ENDPOINT))
189+
.overrideConfiguration(testDiscoveryIdentifiersRequiredRequest.overrideConfiguration().orElse(null))
190+
.build();
191+
return endpointDiscoveryCache.get(credentials.accessKeyId(), endpointDiscoveryRequest);
192+
});
189193
}
190194

191-
CompletableFuture<TestDiscoveryIdentifiersRequiredResponse> executeFuture = clientHandler
192-
.execute(new ClientExecutionParams<TestDiscoveryIdentifiersRequiredRequest, TestDiscoveryIdentifiersRequiredResponse>()
195+
CompletableFuture<TestDiscoveryIdentifiersRequiredResponse> executeFuture =
196+
endpointFuture.thenCompose(cachedEndpoint ->
197+
clientHandler.execute(new ClientExecutionParams<TestDiscoveryIdentifiersRequiredRequest, TestDiscoveryIdentifiersRequiredResponse>()
193198
.withOperationName("TestDiscoveryIdentifiersRequired")
194199
.withMarshaller(new TestDiscoveryIdentifiersRequiredRequestMarshaller(protocolFactory))
195200
.withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler)
196201
.withMetricCollector(apiCallMetricCollector).discoveredEndpoint(cachedEndpoint)
197-
.withInput(testDiscoveryIdentifiersRequiredRequest));
202+
.withInput(testDiscoveryIdentifiersRequiredRequest)));
198203
CompletableFuture<TestDiscoveryIdentifiersRequiredResponse> whenCompleted = executeFuture.whenComplete((r, e) -> {
199204
metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect()));
200205
});
@@ -243,25 +248,29 @@ public CompletableFuture<TestDiscoveryOptionalResponse> testDiscoveryOptional(
243248
operationMetadata);
244249
boolean endpointDiscoveryEnabled = clientConfiguration.option(SdkClientOption.ENDPOINT_DISCOVERY_ENABLED);
245250
boolean endpointOverridden = clientConfiguration.option(SdkClientOption.ENDPOINT_OVERRIDDEN) == Boolean.TRUE;
246-
URI cachedEndpoint = null;
251+
CompletableFuture<URI> endpointFuture = CompletableFuture.completedFuture(null);
247252
if (endpointDiscoveryEnabled) {
248-
String key = testDiscoveryOptionalRequest.overrideConfiguration()
249-
.flatMap(AwsRequestOverrideConfiguration::credentialsIdentityProvider)
250-
.orElseGet(() -> clientConfiguration.option(AwsClientOption.CREDENTIALS_IDENTITY_PROVIDER))
251-
.resolveIdentity().join().accessKeyId();
252-
EndpointDiscoveryRequest endpointDiscoveryRequest = EndpointDiscoveryRequest.builder().required(false)
253-
.defaultEndpoint(clientConfiguration.option(SdkClientOption.ENDPOINT))
254-
.overrideConfiguration(testDiscoveryOptionalRequest.overrideConfiguration().orElse(null)).build();
255-
cachedEndpoint = endpointDiscoveryCache.get(key, endpointDiscoveryRequest);
253+
CompletableFuture<? extends AwsCredentialsIdentity> identityFuture =
254+
testDiscoveryOptionalRequest.overrideConfiguration()
255+
.flatMap(AwsRequestOverrideConfiguration::credentialsIdentityProvider)
256+
.orElseGet(() -> clientConfiguration.option(AwsClientOption.CREDENTIALS_IDENTITY_PROVIDER))
257+
.resolveIdentity();
258+
endpointFuture = identityFuture.thenApply(credentials -> {
259+
EndpointDiscoveryRequest endpointDiscoveryRequest = EndpointDiscoveryRequest.builder().required(false)
260+
.defaultEndpoint(clientConfiguration.option(SdkClientOption.ENDPOINT))
261+
.overrideConfiguration(testDiscoveryOptionalRequest.overrideConfiguration().orElse(null)).build();
262+
return endpointDiscoveryCache.get(credentials.accessKeyId(), endpointDiscoveryRequest);
263+
});
256264
}
257265

258-
CompletableFuture<TestDiscoveryOptionalResponse> executeFuture = clientHandler
259-
.execute(new ClientExecutionParams<TestDiscoveryOptionalRequest, TestDiscoveryOptionalResponse>()
266+
CompletableFuture<TestDiscoveryOptionalResponse> executeFuture =
267+
endpointFuture.thenCompose(cachedEndpoint ->
268+
clientHandler.execute(new ClientExecutionParams<TestDiscoveryOptionalRequest, TestDiscoveryOptionalResponse>()
260269
.withOperationName("TestDiscoveryOptional")
261270
.withMarshaller(new TestDiscoveryOptionalRequestMarshaller(protocolFactory))
262271
.withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler)
263272
.withMetricCollector(apiCallMetricCollector).discoveredEndpoint(cachedEndpoint)
264-
.withInput(testDiscoveryOptionalRequest));
273+
.withInput(testDiscoveryOptionalRequest)));
265274
CompletableFuture<TestDiscoveryOptionalResponse> whenCompleted = executeFuture.whenComplete((r, e) -> {
266275
metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect()));
267276
});
@@ -318,25 +327,29 @@ public CompletableFuture<TestDiscoveryRequiredResponse> testDiscoveryRequired(
318327
throw new IllegalStateException(
319328
"This operation requires endpoint discovery, but endpoint discovery was disabled on the client.");
320329
}
321-
URI cachedEndpoint = null;
330+
CompletableFuture<URI> endpointFuture = CompletableFuture.completedFuture(null);
322331
if (endpointDiscoveryEnabled) {
323-
String key = testDiscoveryRequiredRequest.overrideConfiguration()
324-
.flatMap(AwsRequestOverrideConfiguration::credentialsIdentityProvider)
325-
.orElseGet(() -> clientConfiguration.option(AwsClientOption.CREDENTIALS_IDENTITY_PROVIDER))
326-
.resolveIdentity().join().accessKeyId();
327-
EndpointDiscoveryRequest endpointDiscoveryRequest = EndpointDiscoveryRequest.builder().required(true)
328-
.defaultEndpoint(clientConfiguration.option(SdkClientOption.ENDPOINT))
329-
.overrideConfiguration(testDiscoveryRequiredRequest.overrideConfiguration().orElse(null)).build();
330-
cachedEndpoint = endpointDiscoveryCache.get(key, endpointDiscoveryRequest);
332+
CompletableFuture<? extends AwsCredentialsIdentity> identityFuture =
333+
testDiscoveryRequiredRequest.overrideConfiguration()
334+
.flatMap(AwsRequestOverrideConfiguration::credentialsIdentityProvider)
335+
.orElseGet(() -> clientConfiguration.option(AwsClientOption.CREDENTIALS_IDENTITY_PROVIDER))
336+
.resolveIdentity();
337+
endpointFuture = identityFuture.thenApply(credentials -> {
338+
EndpointDiscoveryRequest endpointDiscoveryRequest = EndpointDiscoveryRequest.builder().required(true)
339+
.defaultEndpoint(clientConfiguration.option(SdkClientOption.ENDPOINT))
340+
.overrideConfiguration(testDiscoveryRequiredRequest.overrideConfiguration().orElse(null)).build();
341+
return endpointDiscoveryCache.get(credentials.accessKeyId(), endpointDiscoveryRequest);
342+
});
331343
}
332344

333-
CompletableFuture<TestDiscoveryRequiredResponse> executeFuture = clientHandler
334-
.execute(new ClientExecutionParams<TestDiscoveryRequiredRequest, TestDiscoveryRequiredResponse>()
345+
CompletableFuture<TestDiscoveryRequiredResponse> executeFuture =
346+
endpointFuture.thenCompose(cachedEndpoint ->
347+
clientHandler.execute(new ClientExecutionParams<TestDiscoveryRequiredRequest, TestDiscoveryRequiredResponse>()
335348
.withOperationName("TestDiscoveryRequired")
336349
.withMarshaller(new TestDiscoveryRequiredRequestMarshaller(protocolFactory))
337350
.withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler)
338351
.withMetricCollector(apiCallMetricCollector).discoveredEndpoint(cachedEndpoint)
339-
.withInput(testDiscoveryRequiredRequest));
352+
.withInput(testDiscoveryRequiredRequest)));
340353
CompletableFuture<TestDiscoveryRequiredResponse> whenCompleted = executeFuture.whenComplete((r, e) -> {
341354
metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect()));
342355
});

0 commit comments

Comments
 (0)