Skip to content

Commit 40f6761

Browse files
committed
chore: support delegation of determining errors for an operation (smithy-lang#489)
1 parent 3e56ab6 commit 40f6761

File tree

4 files changed

+46
-13
lines changed

4 files changed

+46
-13
lines changed

smithy-typescript-codegen/src/main/java/software/amazon/smithy/typescript/codegen/integration/HttpBindingProtocolGenerator.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2060,7 +2060,7 @@ private void generateOperationResponseDeserializer(
20602060
// Write out the error deserialization dispatcher.
20612061
Set<StructureShape> errorShapes = HttpProtocolGeneratorUtils.generateErrorDispatcher(
20622062
context, operation, responseType, this::writeErrorCodeParser,
2063-
isErrorCodeInBody, this::getErrorBodyLocation);
2063+
isErrorCodeInBody, this::getErrorBodyLocation, this::getOperationErrors);
20642064
deserializingErrorShapes.addAll(errorShapes);
20652065
}
20662066

smithy-typescript-codegen/src/main/java/software/amazon/smithy/typescript/codegen/integration/HttpProtocolGeneratorUtils.java

Lines changed: 32 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,18 +17,21 @@
1717

1818
import java.nio.file.Paths;
1919
import java.util.List;
20+
import java.util.Map;
2021
import java.util.Optional;
2122
import java.util.Set;
23+
import java.util.TreeMap;
2224
import java.util.TreeSet;
2325
import java.util.function.BiFunction;
2426
import java.util.function.Consumer;
27+
import java.util.function.Function;
2528
import java.util.logging.Logger;
29+
import java.util.stream.Collectors;
2630
import software.amazon.smithy.codegen.core.CodegenException;
2731
import software.amazon.smithy.codegen.core.Symbol;
2832
import software.amazon.smithy.codegen.core.SymbolProvider;
2933
import software.amazon.smithy.codegen.core.SymbolReference;
3034
import software.amazon.smithy.model.knowledge.HttpBinding.Location;
31-
import software.amazon.smithy.model.knowledge.OperationIndex;
3235
import software.amazon.smithy.model.pattern.SmithyPattern;
3336
import software.amazon.smithy.model.shapes.MemberShape;
3437
import software.amazon.smithy.model.shapes.OperationShape;
@@ -313,6 +316,7 @@ public static void writeRetryableTrait(TypeScriptWriter writer, StructureShape e
313316
* @param errorCodeGenerator A consumer
314317
* @param shouldParseErrorBody Flag indicating whether need to parse response body in this dispatcher function
315318
* @param bodyErrorLocationModifier A function that returns the location of an error in a body given a data source.
319+
* @param operationErrorsToShapes A map of error names to their {@link ShapeId}.
316320
* @return A set of all error structure shapes for the operation that were dispatched to.
317321
*/
318322
static Set<StructureShape> generateErrorDispatcher(
@@ -321,11 +325,11 @@ static Set<StructureShape> generateErrorDispatcher(
321325
SymbolReference responseType,
322326
Consumer<GenerationContext> errorCodeGenerator,
323327
boolean shouldParseErrorBody,
324-
BiFunction<GenerationContext, String, String> bodyErrorLocationModifier
328+
BiFunction<GenerationContext, String, String> bodyErrorLocationModifier,
329+
BiFunction<GenerationContext, OperationShape, Map<String, ShapeId>> operationErrorsToShapes
325330
) {
326331
TypeScriptWriter writer = context.getWriter();
327332
SymbolProvider symbolProvider = context.getSymbolProvider();
328-
OperationIndex operationIndex = OperationIndex.of(context.getModel());
329333
Set<StructureShape> errorShapes = new TreeSet<>();
330334

331335
Symbol symbol = symbolProvider.toSymbol(operation);
@@ -350,10 +354,6 @@ static Set<StructureShape> generateErrorDispatcher(
350354
SymbolReference baseExceptionReference = getClientBaseException(context);
351355
errorCodeGenerator.accept(context);
352356

353-
TreeSet<StructureShape> structureShapes = new TreeSet<>(
354-
operationIndex.getErrors(operation, context.getService())
355-
);
356-
357357
Runnable defaultErrorHandler = () -> {
358358
if (shouldParseErrorBody) {
359359
// Body is already parsed above
@@ -379,20 +379,21 @@ static Set<StructureShape> generateErrorDispatcher(
379379
});
380380
};
381381

382-
if (!structureShapes.isEmpty()) {
382+
Map<String, ShapeId> operationNamesToShapes = operationErrorsToShapes.apply(context, operation);
383+
if (!operationNamesToShapes.isEmpty()) {
383384
writer.openBlock("switch (errorCode) {", "}", () -> {
384385
// Generate the case statement for each error, invoking the specific deserializer.
385386

386-
structureShapes.forEach(error -> {
387-
final ShapeId errorId = error.getId();
387+
operationNamesToShapes.forEach((name, errorId) -> {
388+
StructureShape error = context.getModel().expectShape(errorId).asStructureShape().get();
388389
// Track errors bound to the operation so their deserializers may be generated.
389390
errorShapes.add(error);
390391
Symbol errorSymbol = symbolProvider.toSymbol(error);
391392
String errorDeserMethodName = ProtocolGenerator.getDeserFunctionName(errorSymbol,
392393
context.getProtocolName()) + "Response";
393394
// Dispatch to the error deserialization function.
394395
String outputParam = shouldParseErrorBody ? "parsedOutput" : "output";
395-
writer.write("case $S:", errorId.getName());
396+
writer.write("case $S:", name);
396397
writer.write("case $S:", errorId.toString());
397398
writer.indent()
398399
.write("throw await $L($L, context);", errorDeserMethodName, outputParam)
@@ -468,4 +469,24 @@ private static SymbolReference getClientBaseException(GenerationContext context)
468469
.symbol(serviceExceptionSymbol)
469470
.build();
470471
}
472+
473+
/**
474+
* Returns a map of error names to their {@link ShapeId}.
475+
*
476+
* @param context the generation context
477+
* @param operation the operation shape to retrieve errors for
478+
* @return map of error names to {@link ShapeId}
479+
*/
480+
public static Map<String, ShapeId> getOperationErrors(GenerationContext context, OperationShape operation) {
481+
return operation.getErrors().stream()
482+
.collect(Collectors.toMap(
483+
shapeId -> shapeId.getName(context.getService()),
484+
Function.identity(),
485+
(x, y) -> {
486+
if (!x.equals(y)) {
487+
throw new CodegenException(String.format("conflicting error shape ids: %s, %s", x, y));
488+
}
489+
return x;
490+
}, TreeMap::new));
491+
}
471492
}

smithy-typescript-codegen/src/main/java/software/amazon/smithy/typescript/codegen/integration/HttpRpcProtocolGenerator.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -412,7 +412,7 @@ private void generateOperationDeserializer(GenerationContext context, OperationS
412412
// Write out the error deserialization dispatcher.
413413
Set<StructureShape> errorShapes = HttpProtocolGeneratorUtils.generateErrorDispatcher(
414414
context, operation, responseType, this::writeErrorCodeParser,
415-
isErrorCodeInBody, this::getErrorBodyLocation);
415+
isErrorCodeInBody, this::getErrorBodyLocation, this::getOperationErrors);
416416
deserializingErrorShapes.addAll(errorShapes);
417417
}
418418

smithy-typescript-codegen/src/main/java/software/amazon/smithy/typescript/codegen/integration/ProtocolGenerator.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
package software.amazon.smithy.typescript.codegen.integration;
1717

1818
import java.util.Collection;
19+
import java.util.Map;
1920
import java.util.stream.Collectors;
2021
import software.amazon.smithy.codegen.core.CodegenException;
2122
import software.amazon.smithy.codegen.core.Symbol;
@@ -263,6 +264,17 @@ static String getSerdeFunctionSymbolComponent(Symbol symbol, Shape shape) {
263264
}
264265
}
265266

267+
/**
268+
* Returns a map of error names to their {@link ShapeId}.
269+
*
270+
* @param context the generation context
271+
* @param operation the operation shape to retrieve errors for
272+
* @return map of error names to {@link ShapeId}
273+
*/
274+
default Map<String, ShapeId> getOperationErrors(GenerationContext context, OperationShape operation) {
275+
return HttpProtocolGeneratorUtils.getOperationErrors(context, operation);
276+
}
277+
266278
/**
267279
* Context object used for service serialization and deserialization.
268280
*/

0 commit comments

Comments
 (0)