Skip to content

Commit 958ffaa

Browse files
authored
feat: generate unified error dispatcher (#1150)
* feat: generate unified error dispatcher * change Promise<unknown> to Promise<never>
1 parent a985fef commit 958ffaa

File tree

5 files changed

+70
-37
lines changed

5 files changed

+70
-37
lines changed

.changeset/shy-nails-wonder.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
---
2+
---

smithy-typescript-codegen/src/main/java/software/amazon/smithy/typescript/codegen/integration/HttpBindingProtocolGenerator.java

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -517,6 +517,7 @@ public void generateResponseDeserializers(GenerationContext context) {
517517

518518
Set<OperationShape> containedOperations = new TreeSet<>(
519519
topDownIndex.getContainedOperations(context.getService()));
520+
520521
for (OperationShape operation : containedOperations) {
521522
OptionalUtils.ifPresentOrElse(
522523
operation.getTrait(HttpTrait.class),
@@ -525,6 +526,18 @@ public void generateResponseDeserializers(GenerationContext context) {
525526
"Unable to generate %s protocol response bindings for %s because it does not have an "
526527
+ "http binding trait", getName(), operation.getId())));
527528
}
529+
530+
SymbolReference responseType = getApplicationProtocol().getResponseType();
531+
Set<StructureShape> errorShapes = HttpProtocolGeneratorUtils.generateUnifiedErrorDispatcher(
532+
context,
533+
containedOperations.stream().toList(),
534+
responseType,
535+
this::writeErrorCodeParser,
536+
isErrorCodeInBody,
537+
this::getErrorBodyLocation,
538+
this::getOperationErrors
539+
);
540+
deserializingErrorShapes.addAll(errorShapes);
528541
}
529542

530543
private void generateOperationResponseSerializer(
@@ -2091,7 +2104,7 @@ private void generateOperationResponseDeserializer(
20912104
// e.g., deserializeAws_restJson1_1ExecuteStatement
20922105
String methodName = ProtocolGenerator.getDeserFunctionShortName(symbol);
20932106
String methodLongName = ProtocolGenerator.getDeserFunctionName(symbol, getName());
2094-
String errorMethodName = methodName + "Error";
2107+
String errorMethodName = "de_CommandError";
20952108
// Add the normalized output type.
20962109
Symbol outputType = symbol.expectProperty("outputType", Symbol.class);
20972110
String contextType = CodegenUtils.getOperationDeserializerContextType(context.getSettings(), writer,
@@ -2129,11 +2142,6 @@ private void generateOperationResponseDeserializer(
21292142
writer.write("return contents;");
21302143
});
21312144
writer.write("");
2132-
// Write out the error deserialization dispatcher.
2133-
Set<StructureShape> errorShapes = HttpProtocolGeneratorUtils.generateErrorDispatcher(
2134-
context, operation, responseType, this::writeErrorCodeParser,
2135-
isErrorCodeInBody, this::getErrorBodyLocation, this::getOperationErrors);
2136-
deserializingErrorShapes.addAll(errorShapes);
21372145
}
21382146

21392147
private void generateErrorDeserializer(GenerationContext context, StructureShape error) {

smithy-typescript-codegen/src/main/java/software/amazon/smithy/typescript/codegen/integration/HttpProtocolGeneratorUtils.java

Lines changed: 23 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -298,51 +298,49 @@ public static void writeRetryableTrait(TypeScriptWriter writer, StructureShape e
298298

299299
/**
300300
* Writes a function used to dispatch to the proper error deserializer
301-
* for each error that the operation can return. The generated function
301+
* for each error that any operation can return. The generated function
302302
* assumes a deserialization function is generated for the structures
303303
* returned.
304304
*
305305
* @param context The generation context.
306-
* @param operation The operation to generate for.
307306
* @param responseType The response type for the HTTP protocol.
308307
* @param errorCodeGenerator A consumer
309308
* @param shouldParseErrorBody Flag indicating whether need to parse response body in this dispatcher function
310309
* @param bodyErrorLocationModifier A function that returns the location of an error in a body given a data source.
311310
* @param operationErrorsToShapes A map of error names to their {@link ShapeId}.
312311
* @return A set of all error structure shapes for the operation that were dispatched to.
313312
*/
314-
static Set<StructureShape> generateErrorDispatcher(
315-
GenerationContext context,
316-
OperationShape operation,
317-
SymbolReference responseType,
318-
Consumer<GenerationContext> errorCodeGenerator,
319-
boolean shouldParseErrorBody,
320-
BiFunction<GenerationContext, String, String> bodyErrorLocationModifier,
321-
BiFunction<GenerationContext, OperationShape, Map<String, ShapeId>> operationErrorsToShapes
313+
static Set<StructureShape> generateUnifiedErrorDispatcher(
314+
GenerationContext context,
315+
List<OperationShape> operations,
316+
SymbolReference responseType,
317+
Consumer<GenerationContext> errorCodeGenerator,
318+
boolean shouldParseErrorBody,
319+
BiFunction<GenerationContext, String, String> bodyErrorLocationModifier,
320+
BiFunction<GenerationContext, List<OperationShape>, Map<String, ShapeId>> operationErrorsToShapes
322321
) {
323322
TypeScriptWriter writer = context.getWriter();
324323
SymbolProvider symbolProvider = context.getSymbolProvider();
325324
Set<StructureShape> errorShapes = new TreeSet<>();
326325

327-
Symbol symbol = symbolProvider.toSymbol(operation);
328-
Symbol outputType = symbol.expectProperty("outputType", Symbol.class);
329-
String errorMethodName = ProtocolGenerator.getDeserFunctionShortName(symbol) + "Error";
330-
String errorMethodLongName = ProtocolGenerator.getDeserFunctionName(symbol, context.getProtocolName())
331-
+ "Error";
326+
String errorMethodName = "de_CommandError";
327+
String errorMethodLongName = "deserialize_"
328+
+ ProtocolGenerator.getSanitizedName(context.getProtocolName())
329+
+ "CommandError";
332330

333331
writer.writeDocs(errorMethodLongName);
334-
writer.openBlock("const $L = async(\n"
335-
+ " output: $T,\n"
336-
+ " context: __SerdeContext,\n"
337-
+ "): Promise<$T> => {", "}", errorMethodName, responseType, outputType, () -> {
332+
writer.openBlock("const $L = async(\n"
333+
+ " output: $T,\n"
334+
+ " context: __SerdeContext,\n"
335+
+ "): Promise<never> => {", "}", errorMethodName, responseType, () -> {
338336
// Prepare error response for parsing error code. If error code needs to be parsed from response body
339337
// then we collect body and parse it to JS object, otherwise leave the response body as is.
340338
if (shouldParseErrorBody) {
341339
writer.openBlock("const parsedOutput: any = {", "};",
342-
() -> {
343-
writer.write("...output,");
344-
writer.write("body: await parseErrorBody(output.body, context)");
345-
});
340+
() -> {
341+
writer.write("...output,");
342+
writer.write("body: await parseErrorBody(output.body, context)");
343+
});
346344
}
347345

348346
// Error responses must be at least BaseException interface
@@ -370,7 +368,8 @@ static Set<StructureShape> generateErrorDispatcher(
370368
});
371369
};
372370

373-
Map<String, ShapeId> operationNamesToShapes = operationErrorsToShapes.apply(context, operation);
371+
Map<String, ShapeId> operationNamesToShapes = operationErrorsToShapes.apply(context, operations);
372+
374373
if (!operationNamesToShapes.isEmpty()) {
375374
writer.openBlock("switch (errorCode) {", "}", () -> {
376375
// Generate the case statement for each error, invoking the specific deserializer.

smithy-typescript-codegen/src/main/java/software/amazon/smithy/typescript/codegen/integration/HttpRpcProtocolGenerator.java

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,18 @@ public void generateResponseDeserializers(GenerationContext context) {
224224
for (OperationShape operation : containedOperations) {
225225
generateOperationDeserializer(context, operation);
226226
}
227+
228+
SymbolReference responseType = getApplicationProtocol().getResponseType();
229+
Set<StructureShape> errorShapes = HttpProtocolGeneratorUtils.generateUnifiedErrorDispatcher(
230+
context,
231+
containedOperations.stream().toList(),
232+
responseType,
233+
this::writeErrorCodeParser,
234+
isErrorCodeInBody,
235+
this::getErrorBodyLocation,
236+
this::getOperationErrors
237+
);
238+
deserializingErrorShapes.addAll(errorShapes);
227239
}
228240

229241
private void generateOperationSerializer(GenerationContext context, OperationShape operation) {
@@ -435,7 +447,7 @@ private void generateOperationDeserializer(GenerationContext context, OperationS
435447
// e.g., deserializeAws_restJson1_1ExecuteStatement
436448
String methodName = ProtocolGenerator.getDeserFunctionShortName(symbol);
437449
String methodLongName = ProtocolGenerator.getDeserFunctionName(symbol, getName());
438-
String errorMethodName = methodName + "Error";
450+
String errorMethodName = "de_CommandError";
439451
String serdeContextType = CodegenUtils.getOperationDeserializerContextType(context.getSettings(), writer,
440452
context.getModel(), operation);
441453
// Add the normalized output type.
@@ -465,12 +477,6 @@ private void generateOperationDeserializer(GenerationContext context, OperationS
465477
writer.write("return response;");
466478
});
467479
writer.write("");
468-
469-
// Write out the error deserialization dispatcher.
470-
Set<StructureShape> errorShapes = HttpProtocolGeneratorUtils.generateErrorDispatcher(
471-
context, operation, responseType, this::writeErrorCodeParser,
472-
isErrorCodeInBody, this::getErrorBodyLocation, this::getOperationErrors);
473-
deserializingErrorShapes.addAll(errorShapes);
474480
}
475481

476482
private void generateErrorDeserializer(GenerationContext context, StructureShape error) {

smithy-typescript-codegen/src/main/java/software/amazon/smithy/typescript/codegen/integration/ProtocolGenerator.java

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
package software.amazon.smithy.typescript.codegen.integration;
1717

1818
import java.util.Collection;
19+
import java.util.LinkedHashMap;
1920
import java.util.Map;
2021
import java.util.stream.Collectors;
2122
import software.amazon.smithy.codegen.core.CodegenException;
@@ -303,6 +304,23 @@ default Map<String, ShapeId> getOperationErrors(GenerationContext context, Opera
303304
return HttpProtocolGeneratorUtils.getOperationErrors(context, operation);
304305
}
305306

307+
/**
308+
* Returns a map of error names to their {@link ShapeId}.
309+
*
310+
* @param context the generation context
311+
* @param operations the operation shapes to retrieve errors for
312+
* @return map of error names to {@link ShapeId}
313+
*/
314+
default Map<String, ShapeId> getOperationErrors(GenerationContext context, Collection<OperationShape> operations) {
315+
Map<String, ShapeId> errors = new LinkedHashMap<>();
316+
for (OperationShape operation : operations) {
317+
errors.putAll(
318+
getOperationErrors(context, operation)
319+
);
320+
}
321+
return errors;
322+
}
323+
306324
/**
307325
* Context object used for service serialization and deserialization.
308326
*/

0 commit comments

Comments
 (0)