Skip to content

Generate unique error response deserializers #44

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Dec 5, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package software.amazon.smithy.typescript.codegen;

import java.util.stream.Collectors;
import java.util.stream.Stream;
import software.amazon.smithy.codegen.core.Symbol;
import software.amazon.smithy.codegen.core.SymbolProvider;
import software.amazon.smithy.codegen.core.SymbolReference;
Expand Down Expand Up @@ -129,7 +130,7 @@ private void renderNonErrorStructure() {
* <pre>{@code
* import * as _smithy from "../lib/smithy";
*
* export interface NoSuchResource extends _smithy.SmithyException {
* export interface NoSuchResource extends _smithy.SmithyException, $MetadataBearer {
* __type: "NoSuchResource";
* $fault: "client";
* resourceType: string | undefined;
Expand All @@ -146,7 +147,16 @@ private void renderErrorStructure() {
ErrorTrait errorTrait = shape.getTrait(ErrorTrait.class).orElseThrow(IllegalStateException::new);
Symbol symbol = symbolProvider.toSymbol(shape);
writer.writeShapeDocs(shape);
writer.openBlock("export interface $L extends _smithy.SmithyException {", symbol.getName());

// Find symbol references with the "extends" property, and add SmithyException.
String extendsFrom = Stream.concat(
Stream.of("_smithy.SmithyException"),
symbol.getReferences().stream()
.filter(ref -> ref.getProperty(SymbolVisitor.IMPLEMENTS_INTERFACE_PROPERTY).isPresent())
.map(SymbolReference::getAlias)
).collect(Collectors.joining(", "));

writer.openBlock("export interface $L extends $L {", symbol.getName(), extendsFrom);
writer.write("__type: $S;", shape.getId().getName());
writer.write("$$fault: $S;", errorTrait.getValue());
StructuredMemberWriter config = new StructuredMemberWriter(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
import software.amazon.smithy.model.shapes.StructureShape;
import software.amazon.smithy.model.shapes.TimestampShape;
import software.amazon.smithy.model.shapes.UnionShape;
import software.amazon.smithy.model.traits.ErrorTrait;
import software.amazon.smithy.model.traits.HttpTrait;
import software.amazon.smithy.model.traits.TimestampFormatTrait.Format;
import software.amazon.smithy.typescript.codegen.ApplicationProtocol;
Expand All @@ -59,8 +60,9 @@ public abstract class HttpBindingProtocolGenerator implements ProtocolGenerator

private static final Logger LOGGER = Logger.getLogger(HttpBindingProtocolGenerator.class.getName());

private final Set<Shape> documentSerializingShapes = new TreeSet<>();
private final Set<Shape> documentDeserializingShapes = new TreeSet<>();
private final Set<Shape> serializingDocumentShapes = new TreeSet<>();
private final Set<Shape> deserializingDocumentShapes = new TreeSet<>();
private final Set<StructureShape> deserializingErrorShapes = new TreeSet<>();

@Override
public ApplicationProtocol getApplicationProtocol() {
Expand Down Expand Up @@ -105,8 +107,9 @@ public ApplicationProtocol getApplicationProtocol() {

@Override
public void generateSharedComponents(GenerationContext context) {
generateDocumentShapeSerializers(context, documentSerializingShapes);
generateDocumentShapeDeserializers(context, documentDeserializingShapes);
deserializingErrorShapes.forEach(error -> generateErrorDeserializer(context, error));
generateDocumentShapeSerializers(context, serializingDocumentShapes);
generateDocumentShapeDeserializers(context, deserializingDocumentShapes);
HttpProtocolGeneratorUtils.generateMetadataDeserializer(context, getApplicationProtocol().getResponseType());
}

Expand Down Expand Up @@ -197,7 +200,7 @@ private void generateOperationSerializer(
documentBindings.stream()
.map(HttpBinding::getMember)
.map(member -> context.getModel().getShapeIndex().getShape(member.getTarget()).get())
.forEach(documentSerializingShapes::add);
.forEach(serializingDocumentShapes::add);
writer.write("body: body,");
}
if (!queryBindings.isEmpty()) {
Expand All @@ -221,7 +224,7 @@ private List<HttpBinding> writeRequestLabels(

if (!labelBindings.isEmpty()) {
ShapeIndex index = context.getModel().getShapeIndex();
writer.write("let resolvedPath = $S;", trait.getUri());
writer.write("const resolvedPath = $S;", trait.getUri());
for (HttpBinding binding : labelBindings) {
String memberName = symbolProvider.toMemberName(binding.getMember());
writer.openBlock("if (input.$L !== undefined) {", "}", memberName, () -> {
Expand All @@ -247,7 +250,7 @@ private List<HttpBinding> writeRequestQueryString(

if (!queryBindings.isEmpty()) {
ShapeIndex index = context.getModel().getShapeIndex();
writer.write("let query: any = {};");
writer.write("const query: any = {};");
for (HttpBinding binding : queryBindings) {
String memberName = symbolProvider.toMemberName(binding.getMember());
writer.openBlock("if (input.$L !== undefined) {", "}", memberName, () -> {
Expand All @@ -271,7 +274,7 @@ private void writeHeaders(
SymbolProvider symbolProvider = context.getSymbolProvider();

// Headers are always present either from the default document or the payload.
writer.write("let headers: any = {};");
writer.write("const headers: any = {};");
writer.write("headers['Content-Type'] = $S;", bindingIndex.determineRequestContentType(
operation, getDocumentContentType()));

Expand Down Expand Up @@ -316,7 +319,7 @@ private List<HttpBinding> writeRequestBody(

if (!documentBindings.isEmpty()) {
// Write the default `body` property.
context.getWriter().write("let body: any = undefined;");
context.getWriter().write("let body: any = {};");
serializeInputDocument(context, operation, documentBindings);
return documentBindings;
}
Expand Down Expand Up @@ -466,7 +469,7 @@ private String getNamedMembersInputParam(
* <p>For example:
*
* <pre>{@code
* let bodyParams: any = {};
* const bodyParams: any = {};
* if (input.barValue !== undefined) {
* bodyParams['barValue'] = input.barValue;
* }
Expand Down Expand Up @@ -515,8 +518,8 @@ private void generateOperationDeserializer(
});

// Start deserializing the response.
writer.write("let data: any = await parseBody(output.body, context)");
writer.openBlock("let contents: $T = {", "};", outputType, () -> {
writer.write("const data: any = await parseBody(output.body, context)");
writer.openBlock("const contents: $T = {", "};", outputType, () -> {
writer.write("$$metadata: deserializeMetadata(output),");

// Only set a type and the members if we have output.
Expand All @@ -533,27 +536,65 @@ private void generateOperationDeserializer(
// Track all shapes bound to the document so their deserializers may be generated.
documentBindings.forEach(binding -> {
Shape target = shapeIndex.getShape(binding.getMember().getTarget()).get();
documentDeserializingShapes.add(target);
deserializingDocumentShapes.add(target);
});
writer.write("return Promise.resolve(contents);");
});
writer.write("");

// Write out the error deserialization dispatcher.
documentDeserializingShapes.addAll(HttpProtocolGeneratorUtils.generateErrorDispatcher(
context, operation, responseType, this::writeErrorCodeParser));
Set<StructureShape> errorShapes = HttpProtocolGeneratorUtils.generateErrorDispatcher(
context, operation, responseType, this::writeErrorCodeParser);
deserializingErrorShapes.addAll(errorShapes);
}

private void generateErrorDeserializer(GenerationContext context, StructureShape error) {
TypeScriptWriter writer = context.getWriter();
SymbolProvider symbolProvider = context.getSymbolProvider();
HttpBindingIndex bindingIndex = context.getModel().getKnowledge(HttpBindingIndex.class);
ShapeIndex shapeIndex = context.getModel().getShapeIndex();
Symbol errorSymbol = symbolProvider.toSymbol(error);
String errorDeserMethodName = ProtocolGenerator.getDeserFunctionName(errorSymbol,
context.getProtocolName()) + "Response";

writer.openBlock("const $L = (\n"
+ " output: any,\n"
+ " context: SerdeContext\n"
+ "): $T => {", "};", errorDeserMethodName, errorSymbol, () -> {
writer.write("const data: any = output.body;");

writer.openBlock("const contents: $T = {", "};", errorSymbol, () -> {
writer.write("__type: $S,", error.getId().getName());
writer.write("$$fault: $S,", error.getTrait(ErrorTrait.class).get().getValue());
writer.write("$$metadata: deserializeMetadata(output),");
// Set all the members to undefined to meet type constraints.
new TreeMap<>(error.getAllMembers())
.forEach((memberName, memberShape) -> writer.write("$L: undefined,", memberName));
});

readHeaders(context, error, bindingIndex);
List<HttpBinding> documentBindings = readResponseBody(context, error, bindingIndex);
// Track all shapes bound to the document so their deserializers may be generated.
documentBindings.forEach(binding -> {
Shape target = shapeIndex.getShape(binding.getMember().getTarget()).get();
deserializingDocumentShapes.add(target);
});
writer.write("return contents;");
});

writer.write("");
}

private void readHeaders(
GenerationContext context,
OperationShape operation,
Shape operationOrError,
HttpBindingIndex bindingIndex
) {
TypeScriptWriter writer = context.getWriter();
SymbolProvider symbolProvider = context.getSymbolProvider();

ShapeIndex index = context.getModel().getShapeIndex();
for (HttpBinding binding : bindingIndex.getRequestBindings(operation, Location.HEADER)) {
for (HttpBinding binding : bindingIndex.getResponseBindings(operationOrError, Location.HEADER)) {
String memberName = symbolProvider.toMemberName(binding.getMember());
writer.openBlock("if (output.headers[$S] !== undefined) {", "}", binding.getLocationName(), () -> {
Shape target = index.getShape(binding.getMember().getTarget()).get();
Expand All @@ -564,7 +605,8 @@ private void readHeaders(
}

// Handle loading up prefix headers.
List<HttpBinding> prefixHeaderBindings = bindingIndex.getResponseBindings(operation, Location.PREFIX_HEADERS);
List<HttpBinding> prefixHeaderBindings =
bindingIndex.getResponseBindings(operationOrError, Location.PREFIX_HEADERS);
if (!prefixHeaderBindings.isEmpty()) {
// Run through the headers one time, matching any prefix groups.
writer.openBlock("Object.keys(output.headers).forEach(header -> {", "});", () -> {
Expand Down Expand Up @@ -592,16 +634,16 @@ private void readHeaders(

private List<HttpBinding> readResponseBody(
GenerationContext context,
OperationShape operation,
Shape operationOrError,
HttpBindingIndex bindingIndex
) {
TypeScriptWriter writer = context.getWriter();
List<HttpBinding> documentBindings = bindingIndex.getResponseBindings(operation, Location.DOCUMENT);
List<HttpBinding> documentBindings = bindingIndex.getResponseBindings(operationOrError, Location.DOCUMENT);
documentBindings.sort(Comparator.comparing(HttpBinding::getMemberName));
List<HttpBinding> payloadBindings = bindingIndex.getResponseBindings(operation, Location.PAYLOAD);
List<HttpBinding> payloadBindings = bindingIndex.getResponseBindings(operationOrError, Location.PAYLOAD);

if (!documentBindings.isEmpty()) {
deserializeOutputDocument(context, operation, documentBindings);
deserializeOutputDocument(context, operationOrError, documentBindings);
return documentBindings;
}
if (!payloadBindings.isEmpty()) {
Expand Down Expand Up @@ -813,12 +855,12 @@ private String getNumberOutputParam(Location bindingType, String dataSource, Sha
* }</pre>
*
* @param context The generation context.
* @param operation The operation being generated.
* @param operationOrError The operation or error with a document being deserialized.
* @param documentBindings The bindings to read from the document.
*/
protected abstract void deserializeOutputDocument(
GenerationContext context,
OperationShape operation,
Shape operationOrError,
List<HttpBinding> documentBindings
);
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import software.amazon.smithy.codegen.core.SymbolReference;
import software.amazon.smithy.model.shapes.OperationShape;
import software.amazon.smithy.model.shapes.Shape;
import software.amazon.smithy.model.shapes.StructureShape;
import software.amazon.smithy.model.traits.TimestampFormatTrait.Format;
import software.amazon.smithy.typescript.codegen.TypeScriptWriter;
import software.amazon.smithy.typescript.codegen.integration.ProtocolGenerator.GenerationContext;
Expand Down Expand Up @@ -116,17 +117,17 @@ static void generateMetadataDeserializer(GenerationContext context, SymbolRefere
* @param operation The operation to generate for.
* @param responseType The response type for the HTTP protocol.
* @param errorCodeGenerator A consumer
* @return A set of all error shapes for the operation that were dispatched to.
* @return A set of all error structure shapes for the operation that were dispatched to.
*/
static Set<Shape> generateErrorDispatcher(
static Set<StructureShape> generateErrorDispatcher(
GenerationContext context,
OperationShape operation,
SymbolReference responseType,
Consumer<GenerationContext> errorCodeGenerator
) {
TypeScriptWriter writer = context.getWriter();
SymbolProvider symbolProvider = context.getSymbolProvider();
Set<Shape> errorShapes = new TreeSet<>();
Set<StructureShape> errorShapes = new TreeSet<>();

Symbol symbol = symbolProvider.toSymbol(operation);
Symbol outputType = symbol.expectProperty("outputType", Symbol.class);
Expand All @@ -136,22 +137,28 @@ static Set<Shape> generateErrorDispatcher(
+ " output: $T,\n"
+ " context: SerdeContext,\n"
+ "): Promise<$T> {", "}", errorMethodName, responseType, outputType, () -> {
writer.write("let data: any = await parseBody(output.body, context);");
writer.write("const data: any = await parseBody(output.body, context);");
// Create a holding object since we have already parsed the body, but retain the rest of the output.
writer.openBlock("const parsedOutput: any = {", "};", () -> {
writer.write("...output,");
writer.write("body: data,");
});
writer.write("let response: any;");
writer.write("let errorCode: String;");
errorCodeGenerator.accept(context);
writer.openBlock("switch (errorCode) {", "}", () -> {
// Generate the case statement for each error, invoking the specific deserializer.
new TreeSet<>(operation.getErrors()).forEach(errorId -> {
Shape error = context.getModel().getShapeIndex().getShape(errorId).get();
StructureShape error = context.getModel().getShapeIndex().getShape(errorId)
.get().asStructureShape().get();
// Track errors bound to the operation so their deserializers may be generated.
errorShapes.add(error);
Symbol errorSymbol = symbolProvider.toSymbol(error);
String errorDeserMethodName = ProtocolGenerator.getDeserFunctionName(errorSymbol,
context.getProtocolName());
context.getProtocolName()) + "Response";
writer.openBlock("case $S:\ncase $S:", " break;", errorId.getName(), errorId.toString(), () -> {
// Dispatch to the error deserialization function.
writer.write("response = $L(data, context);", errorDeserMethodName);
writer.write("response = $L(parsedOutput, context);", errorDeserMethodName);
});
});

Expand All @@ -164,7 +171,7 @@ static Set<Shape> generateErrorDispatcher(
writer.write("$$fault: \"client\",");
}).dedent();
});
writer.write("return Promise.reject(response);");
writer.write("return Promise.reject(Object.assign(new Error(response.__type), response));");
});
writer.write("");

Expand Down
Loading