Skip to content

Commit 4bf58c3

Browse files
Merge pull request #279 from JordonPhillips/server-serde-protocol-tests
Generate ssdk request protocol tests
2 parents d520610 + 523cf7e commit 4bf58c3

File tree

2 files changed

+223
-54
lines changed

2 files changed

+223
-54
lines changed

smithy-typescript-codegen/src/main/java/software/amazon/smithy/typescript/codegen/CodegenVisitor.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ void execute() {
202202
// Generate protocol tests IFF found in the model.
203203
if (protocolGenerator != null) {
204204
ShapeId protocol = protocolGenerator.getProtocol();
205-
new HttpProtocolTestGenerator(settings, model, protocol, symbolProvider, writers).run();
205+
new HttpProtocolTestGenerator(settings, model, protocol, symbolProvider, writers, protocolGenerator).run();
206206
}
207207

208208
// Write each pending writer.

smithy-typescript-codegen/src/main/java/software/amazon/smithy/typescript/codegen/HttpProtocolTestGenerator.java

Lines changed: 222 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,17 @@
1717

1818
import static java.lang.String.format;
1919

20+
import java.io.UnsupportedEncodingException;
21+
import java.net.URLDecoder;
22+
import java.nio.charset.StandardCharsets;
2023
import java.util.List;
2124
import java.util.Locale;
2225
import java.util.Map;
2326
import java.util.Optional;
2427
import java.util.Set;
2528
import java.util.TreeSet;
2629
import java.util.logging.Logger;
30+
import java.util.stream.Collectors;
2731
import software.amazon.smithy.codegen.core.CodegenException;
2832
import software.amazon.smithy.codegen.core.Symbol;
2933
import software.amazon.smithy.codegen.core.SymbolProvider;
@@ -58,9 +62,11 @@
5862
import software.amazon.smithy.protocoltests.traits.HttpRequestTestsTrait;
5963
import software.amazon.smithy.protocoltests.traits.HttpResponseTestCase;
6064
import software.amazon.smithy.protocoltests.traits.HttpResponseTestsTrait;
65+
import software.amazon.smithy.typescript.codegen.integration.ProtocolGenerator;
6166
import software.amazon.smithy.utils.IoUtils;
6267
import software.amazon.smithy.utils.MapUtils;
6368
import software.amazon.smithy.utils.Pair;
69+
import software.amazon.smithy.utils.StringUtils;
6470

6571
/**
6672
* Generates HTTP protocol test cases to be run using Jest.
@@ -86,6 +92,7 @@ final class HttpProtocolTestGenerator implements Runnable {
8692
private final SymbolProvider symbolProvider;
8793
private final Symbol serviceSymbol;
8894
private final Set<String> additionalStubs = new TreeSet<>();
95+
private final ProtocolGenerator protocolGenerator;
8996

9097
/** Vends a TypeScript IFF it's needed. */
9198
private final TypeScriptDelegator delegator;
@@ -98,14 +105,16 @@ final class HttpProtocolTestGenerator implements Runnable {
98105
Model model,
99106
ShapeId protocol,
100107
SymbolProvider symbolProvider,
101-
TypeScriptDelegator delegator
108+
TypeScriptDelegator delegator,
109+
ProtocolGenerator protocolGenerator
102110
) {
103111
this.settings = settings;
104112
this.model = model;
105113
this.protocol = protocol;
106114
this.service = settings.getService(model);
107115
this.symbolProvider = symbolProvider;
108116
this.delegator = delegator;
117+
this.protocolGenerator = protocolGenerator;
109118
serviceSymbol = symbolProvider.toSymbol(service);
110119
}
111120

@@ -116,30 +125,11 @@ public void run() {
116125

117126
// Use a TreeSet to have a fixed ordering of tests.
118127
for (OperationShape operation : new TreeSet<>(topDownIndex.getContainedOperations(service))) {
119-
if (!operation.hasTag("server-only")) {
120-
// 1. Generate test cases for each request.
121-
operation.getTrait(HttpRequestTestsTrait.class).ifPresent(trait -> {
122-
for (HttpRequestTestCase testCase : trait.getTestCasesFor(AppliesTo.CLIENT)) {
123-
onlyIfProtocolMatches(testCase, () -> generateRequestTest(operation, testCase));
124-
}
125-
});
126-
// 2. Generate test cases for each response.
127-
operation.getTrait(HttpResponseTestsTrait.class).ifPresent(trait -> {
128-
for (HttpResponseTestCase testCase : trait.getTestCasesFor(AppliesTo.CLIENT)) {
129-
onlyIfProtocolMatches(testCase, () -> generateResponseTest(operation, testCase));
130-
}
131-
});
132-
// 3. Generate test cases for each error on each operation.
133-
for (StructureShape error : operationIndex.getErrors(operation)) {
134-
if (!error.hasTag("server-only")) {
135-
error.getTrait(HttpResponseTestsTrait.class).ifPresent(trait -> {
136-
for (HttpResponseTestCase testCase : trait.getTestCasesFor(AppliesTo.CLIENT)) {
137-
onlyIfProtocolMatches(testCase,
138-
() -> generateErrorResponseTest(operation, error, testCase));
139-
}
140-
});
141-
}
142-
}
128+
if (settings.generateClient()) {
129+
generateClientOperationTests(operation, operationIndex);
130+
}
131+
if (settings.generateServerSdk()) {
132+
generateServerOperationTests(operation, operationIndex);
143133
}
144134
}
145135

@@ -149,6 +139,45 @@ public void run() {
149139
}
150140
}
151141

142+
private void generateClientOperationTests(OperationShape operation, OperationIndex operationIndex) {
143+
if (!operation.hasTag("server-only")) {
144+
// 1. Generate test cases for each request.
145+
operation.getTrait(HttpRequestTestsTrait.class).ifPresent(trait -> {
146+
for (HttpRequestTestCase testCase : trait.getTestCasesFor(AppliesTo.CLIENT)) {
147+
onlyIfProtocolMatches(testCase, () -> generateClientRequestTest(operation, testCase));
148+
}
149+
});
150+
// 2. Generate test cases for each response.
151+
operation.getTrait(HttpResponseTestsTrait.class).ifPresent(trait -> {
152+
for (HttpResponseTestCase testCase : trait.getTestCasesFor(AppliesTo.CLIENT)) {
153+
onlyIfProtocolMatches(testCase, () -> generateResponseTest(operation, testCase));
154+
}
155+
});
156+
// 3. Generate test cases for each error on each operation.
157+
for (StructureShape error : operationIndex.getErrors(operation)) {
158+
if (!error.hasTag("server-only")) {
159+
error.getTrait(HttpResponseTestsTrait.class).ifPresent(trait -> {
160+
for (HttpResponseTestCase testCase : trait.getTestCasesFor(AppliesTo.CLIENT)) {
161+
onlyIfProtocolMatches(testCase,
162+
() -> generateErrorResponseTest(operation, error, testCase));
163+
}
164+
});
165+
}
166+
}
167+
}
168+
}
169+
170+
private void generateServerOperationTests(OperationShape operation, OperationIndex operationIndex) {
171+
if (!operation.hasTag("client-only")) {
172+
// 1. Generate test cases for each request.
173+
operation.getTrait(HttpRequestTestsTrait.class).ifPresent(trait -> {
174+
for (HttpRequestTestCase testCase : trait.getTestCasesFor(AppliesTo.SERVER)) {
175+
onlyIfProtocolMatches(testCase, () -> generateServerRequestTest(operation, testCase));
176+
}
177+
});
178+
}
179+
}
180+
152181
// Only generate test cases when its protocol matches the target protocol.
153182
private <T extends HttpMessageTestCase> void onlyIfProtocolMatches(T testCase, Runnable runnable) {
154183
if (testCase.getProtocol().equals(protocol)) {
@@ -175,7 +204,7 @@ private String createTestCaseFilename() {
175204
return TEST_CASE_FILE_TEMPLATE.replace("%s", baseName);
176205
}
177206

178-
private void generateRequestTest(OperationShape operation, HttpRequestTestCase testCase) {
207+
private void generateClientRequestTest(OperationShape operation, HttpRequestTestCase testCase) {
179208
Symbol operationSymbol = symbolProvider.toSymbol(operation);
180209

181210
String testName = testCase.getId() + ":Request";
@@ -217,6 +246,107 @@ private void generateRequestTest(OperationShape operation, HttpRequestTestCase t
217246
});
218247
}
219248

249+
private void generateServerRequestTest(OperationShape operation, HttpRequestTestCase testCase) {
250+
Symbol operationSymbol = symbolProvider.toSymbol(operation);
251+
252+
// Lowercase all the headers we're expecting as this is what we'll get.
253+
Map<String, String> headers = testCase.getHeaders().entrySet().stream()
254+
.map(entry -> new Pair<>(entry.getKey().toLowerCase(Locale.US), entry.getValue()))
255+
.collect(MapUtils.toUnmodifiableMap(Pair::getLeft, Pair::getRight));
256+
String queryParameters = Node.prettyPrintJson(buildQueryBag(testCase));
257+
String headerParameters = Node.prettyPrintJson(ObjectNode.fromStringMap(headers));
258+
String body = testCase.getBody().orElse(null);
259+
260+
String testName = testCase.getId() + ":ServerRequest";
261+
testCase.getDocumentation().ifPresent(writer::writeDocs);
262+
writer.openBlock("it($S, async () => {", "});\n", testName, () -> {
263+
// TODO: use the symbol provider when it's ready
264+
String serviceName = StringUtils.capitalize(service.getId().getName());
265+
String operationName = StringUtils.capitalize(operation.getId().getName());
266+
Symbol inputType = operationSymbol.expectProperty("inputType", Symbol.class);
267+
Symbol outputType = operationSymbol.expectProperty("outputType", Symbol.class);
268+
269+
// Create a mock function to set in place of the server operation function so we can capture
270+
// input and other information.
271+
writer.write("let testFunction = jest.fn();");
272+
writer.write("testFunction.mockReturnValue({});");
273+
274+
// We use a partial here so that we don't have to define the entire service, but still get the advantages
275+
// the type checker, including excess property checking. Later on we'll use `as` to cast this to the
276+
// full service so that we can actually use it.
277+
writer.addImport(serviceName + "Service", null, "./server");
278+
writer.openBlock("const testService: Partial<$LService> = {", "};", serviceName, () -> {
279+
writer.addImport("Operation", "__Operation", "@aws-smithy/server-common");
280+
writer.write("$L: testFunction as __Operation<$T, $T>,", operationName, inputType, outputType);
281+
});
282+
283+
String getHandlerName = String.format("get%sServiceHandler", serviceName);
284+
writer.addImport(getHandlerName, getHandlerName,
285+
"./protocols/" + ProtocolGenerator.getSanitizedName(protocolGenerator.getName()));
286+
287+
// Cast the service as any so TS will ignore the fact that the type being passed in is incomplete.
288+
writer.write("const handler = $L(testService as $LService);", getHandlerName, serviceName);
289+
290+
// Construct a new http request according to the test case definition.
291+
writer.openBlock("const request = new HttpRequest({", "});", () -> {
292+
writer.write("method: $S,", testCase.getMethod());
293+
writer.write("hostname: $S,", testCase.getHost().orElse("foo.example.com"));
294+
writer.write("path: $S,", testCase.getUri());
295+
writer.write("query: $L,", queryParameters);
296+
writer.write("headers: $L,", headerParameters);
297+
if (body != null) {
298+
writer.write("body: Readable.from([$S]),", body);
299+
}
300+
});
301+
writer.write("await handler.handle(request);").write("");
302+
303+
// Assert that the function has been called exactly once.
304+
writer.write("expect(testFunction.mock.calls.length).toBe(1);");
305+
306+
// Capture the input. We need to cast this to any so we can index into it.
307+
writer.write("let r: any = testFunction.mock.calls[0][0];").write("");
308+
309+
writeRequestParamAssertions(operation, testCase);
310+
});
311+
}
312+
313+
private ObjectNode buildQueryBag(HttpRequestTestCase testCase) {
314+
// The query params in the test definition is a list of strings that looks like
315+
// "Foo=Bar", so we need to split the keys from the values.
316+
Map<String, List<String>> query = testCase.getQueryParams().stream()
317+
.map(pair -> {
318+
String[] split = pair.split("=");
319+
String key;
320+
String value = "";
321+
try {
322+
// The strings we're given are url encoded, so we need to decode them. In an actual implementation
323+
// the request we're given will have already decoded these.
324+
key = URLDecoder.decode(split[0], StandardCharsets.UTF_8.toString());
325+
if (split.length > 1) {
326+
value = URLDecoder.decode(split[1], StandardCharsets.UTF_8.toString());
327+
}
328+
} catch (UnsupportedEncodingException e) {
329+
throw new RuntimeException(e);
330+
}
331+
return Pair.of(key, value);
332+
})
333+
// Query lists/sets will just use the same key repeatedly, so here we collect all the values that
334+
// share a key.
335+
.collect(Collectors.groupingBy(Pair::getKey, Collectors.mapping(Pair::getValue, Collectors.toList())));
336+
337+
ObjectNode.Builder nodeBuilder = ObjectNode.objectNodeBuilder();
338+
for (Map.Entry<String, List<String>> entry : query.entrySet()) {
339+
// The value of the query bag can either be a single string or a list, so we need to ensure individual
340+
// values are written out as individual strings.
341+
if (entry.getValue().size() == 1) {
342+
nodeBuilder.withMember(entry.getKey(), StringNode.from(entry.getValue().get(0)));
343+
} else {
344+
nodeBuilder.withMember(entry.getKey(), ArrayNode.fromStrings(entry.getValue()));
345+
}
346+
}
347+
return nodeBuilder.build();
348+
}
349+
220350
// Ensure that the serialized request matches the expected request.
221351
private void writeRequestAssertions(OperationShape operation, HttpRequestTestCase testCase) {
222352
writer.write("expect(r.method).toBe($S);", testCase.getMethod());
@@ -405,10 +535,40 @@ private void writeResponseTestSetup(OperationShape operation, HttpResponseTestCa
405535
private void writeResponseAssertions(Shape operationOrError, HttpResponseTestCase testCase) {
406536
writer.write("expect(r['$$metadata'].httpStatusCode).toBe($L);", testCase.getCode());
407537

408-
writeParamAssertions(operationOrError, testCase);
538+
writeReponseParamAssertions(operationOrError, testCase);
539+
}
540+
541+
private void writeRequestParamAssertions(OperationShape operation, HttpRequestTestCase testCase) {
542+
ObjectNode params = testCase.getParams();
543+
if (!params.isEmpty()) {
544+
StructureShape testInputShape = model.expectShape(
545+
operation.getInput().orElseThrow(() -> new CodegenException("Foo")),
546+
StructureShape.class);
547+
548+
// Use this trick wrapper to not need more complex trailing comma handling.
549+
writer.write("const paramsToValidate: any = [")
550+
.call(() -> params.accept(new CommandOutputNodeVisitor(testInputShape)))
551+
.write("][0];");
552+
553+
// Extract a payload binding if present.
554+
Optional<HttpBinding> pb = Optional.empty();
555+
HttpBindingIndex index = HttpBindingIndex.of(model);
556+
List<HttpBinding> payloadBindings = index.getRequestBindings(operation, Location.PAYLOAD);
557+
if (!payloadBindings.isEmpty()) {
558+
pb = Optional.of(payloadBindings.get(0));
559+
}
560+
final Optional<HttpBinding> payloadBinding = pb;
561+
562+
writeParamAssertions(writer, payloadBinding, () -> {
563+
// TODO: replace this with a collector from the server config once it's available
564+
writer.addImport("streamCollector", "__streamCollector", "@aws-sdk/node-http-handler");
565+
writer.write("const comparableBlob = await __streamCollector(r[$S]);",
566+
payloadBinding.get().getMemberName());
567+
});
568+
}
409569
}
410570

411-
private void writeParamAssertions(Shape operationOrError, HttpResponseTestCase testCase) {
571+
private void writeReponseParamAssertions(Shape operationOrError, HttpResponseTestCase testCase) {
412572
ObjectNode params = testCase.getParams();
413573
if (!params.isEmpty()) {
414574
StructureShape testOutputShape;
@@ -438,40 +598,49 @@ private void writeParamAssertions(Shape operationOrError, HttpResponseTestCase t
438598
return null;
439599
});
440600

441-
// If we have a streaming payload blob, we need to collect it to something that
442-
// can be compared with the test contents. This emulates the customer experience.
443-
boolean hasStreamingPayloadBlob = payloadBinding
444-
.map(binding ->
601+
writeParamAssertions(writer, payloadBinding, () -> {
602+
writer.write("const comparableBlob = await client.config.streamCollector(r[$S]);",
603+
payloadBinding.get().getMemberName());
604+
});
605+
}
606+
}
607+
608+
private void writeParamAssertions(
609+
TypeScriptWriter writer,
610+
Optional<HttpBinding> payloadBinding,
611+
Runnable writeComparableBlob
612+
) {
613+
// If we have a streaming payload blob, we need to collect it to something that
614+
// can be compared with the test contents. This emulates the customer experience.
615+
boolean hasStreamingPayloadBlob = payloadBinding
616+
.map(binding ->
445617
model.getShape(binding.getMember().getTarget())
446618
.filter(Shape::isBlobShape)
447619
.filter(s -> s.hasTrait(StreamingTrait.ID))
448620
.isPresent())
449-
.orElse(false);
621+
.orElse(false);
622+
623+
if (hasStreamingPayloadBlob) {
624+
writeComparableBlob.run();
625+
}
450626

451-
// Do the collection for payload blobs.
627+
// Perform parameter comparisons.
628+
writer.openBlock("Object.keys(paramsToValidate).forEach(param => {", "});", () -> {
629+
writer.write("expect(r[param]).toBeDefined();");
452630
if (hasStreamingPayloadBlob) {
453-
writer.write("const comparableBlob = await client.config.streamCollector(r[$S]);",
454-
payloadBinding.get().getMemberName());
631+
writer.openBlock("if (param === $S) {", "} else {", payloadBinding.get().getMemberName(), () ->
632+
writer.write("expect(equivalentContents(comparableBlob, "
633+
+ "paramsToValidate[param])).toBe(true);"));
634+
writer.indent();
455635
}
456636

457-
// Perform parameter comparisons.
458-
writer.openBlock("Object.keys(paramsToValidate).forEach(param => {", "});", () -> {
459-
writer.write("expect(r[param]).toBeDefined();");
460-
if (hasStreamingPayloadBlob) {
461-
writer.openBlock("if (param === $S) {", "} else {", payloadBinding.get().getMemberName(), () ->
462-
writer.write("expect(equivalentContents(comparableBlob, "
463-
+ "paramsToValidate[param])).toBe(true);"));
464-
writer.indent();
465-
}
637+
writer.write("expect(equivalentContents(r[param], paramsToValidate[param])).toBe(true);");
466638

467-
writer.write("expect(equivalentContents(r[param], paramsToValidate[param])).toBe(true);");
468-
469-
if (hasStreamingPayloadBlob) {
470-
writer.dedent();
471-
writer.write("}");
472-
}
473-
});
474-
}
639+
if (hasStreamingPayloadBlob) {
640+
writer.dedent();
641+
writer.write("}");
642+
}
643+
});
475644
}
476645

477646
/**

0 commit comments

Comments
 (0)