Skip to content

Commit 7896c78

Browse files
Generate ssdk request protocol tests
1 parent d520610 commit 7896c78

File tree

2 files changed

+215
-54
lines changed

2 files changed

+215
-54
lines changed

smithy-typescript-codegen/src/main/java/software/amazon/smithy/typescript/codegen/CodegenVisitor.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ void execute() {
202202
// Generate protocol tests IFF found in the model.
203203
if (protocolGenerator != null) {
204204
ShapeId protocol = protocolGenerator.getProtocol();
205-
new HttpProtocolTestGenerator(settings, model, protocol, symbolProvider, writers).run();
205+
new HttpProtocolTestGenerator(settings, model, protocol, symbolProvider, writers, protocolGenerator).run();
206206
}
207207

208208
// Write each pending writer.

smithy-typescript-codegen/src/main/java/software/amazon/smithy/typescript/codegen/HttpProtocolTestGenerator.java

Lines changed: 214 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,17 @@
1717

1818
import static java.lang.String.format;
1919

20+
import java.io.UnsupportedEncodingException;
21+
import java.net.URLDecoder;
22+
import java.nio.charset.StandardCharsets;
2023
import java.util.List;
2124
import java.util.Locale;
2225
import java.util.Map;
2326
import java.util.Optional;
2427
import java.util.Set;
2528
import java.util.TreeSet;
2629
import java.util.logging.Logger;
30+
import java.util.stream.Collectors;
2731
import software.amazon.smithy.codegen.core.CodegenException;
2832
import software.amazon.smithy.codegen.core.Symbol;
2933
import software.amazon.smithy.codegen.core.SymbolProvider;
@@ -58,9 +62,11 @@
5862
import software.amazon.smithy.protocoltests.traits.HttpRequestTestsTrait;
5963
import software.amazon.smithy.protocoltests.traits.HttpResponseTestCase;
6064
import software.amazon.smithy.protocoltests.traits.HttpResponseTestsTrait;
65+
import software.amazon.smithy.typescript.codegen.integration.ProtocolGenerator;
6166
import software.amazon.smithy.utils.IoUtils;
6267
import software.amazon.smithy.utils.MapUtils;
6368
import software.amazon.smithy.utils.Pair;
69+
import software.amazon.smithy.utils.StringUtils;
6470

6571
/**
6672
* Generates HTTP protocol test cases to be run using Jest.
@@ -86,6 +92,7 @@ final class HttpProtocolTestGenerator implements Runnable {
8692
private final SymbolProvider symbolProvider;
8793
private final Symbol serviceSymbol;
8894
private final Set<String> additionalStubs = new TreeSet<>();
95+
private final ProtocolGenerator protocolGenerator;
8996

9097
/** Vends a TypeScript IFF it's needed. */
9198
private final TypeScriptDelegator delegator;
@@ -98,14 +105,16 @@ final class HttpProtocolTestGenerator implements Runnable {
98105
Model model,
99106
ShapeId protocol,
100107
SymbolProvider symbolProvider,
101-
TypeScriptDelegator delegator
108+
TypeScriptDelegator delegator,
109+
ProtocolGenerator protocolGenerator
102110
) {
103111
this.settings = settings;
104112
this.model = model;
105113
this.protocol = protocol;
106114
this.service = settings.getService(model);
107115
this.symbolProvider = symbolProvider;
108116
this.delegator = delegator;
117+
this.protocolGenerator = protocolGenerator;
109118
serviceSymbol = symbolProvider.toSymbol(service);
110119
}
111120

@@ -116,30 +125,11 @@ public void run() {
116125

117126
// Use a TreeSet to have a fixed ordering of tests.
118127
for (OperationShape operation : new TreeSet<>(topDownIndex.getContainedOperations(service))) {
119-
if (!operation.hasTag("server-only")) {
120-
// 1. Generate test cases for each request.
121-
operation.getTrait(HttpRequestTestsTrait.class).ifPresent(trait -> {
122-
for (HttpRequestTestCase testCase : trait.getTestCasesFor(AppliesTo.CLIENT)) {
123-
onlyIfProtocolMatches(testCase, () -> generateRequestTest(operation, testCase));
124-
}
125-
});
126-
// 2. Generate test cases for each response.
127-
operation.getTrait(HttpResponseTestsTrait.class).ifPresent(trait -> {
128-
for (HttpResponseTestCase testCase : trait.getTestCasesFor(AppliesTo.CLIENT)) {
129-
onlyIfProtocolMatches(testCase, () -> generateResponseTest(operation, testCase));
130-
}
131-
});
132-
// 3. Generate test cases for each error on each operation.
133-
for (StructureShape error : operationIndex.getErrors(operation)) {
134-
if (!error.hasTag("server-only")) {
135-
error.getTrait(HttpResponseTestsTrait.class).ifPresent(trait -> {
136-
for (HttpResponseTestCase testCase : trait.getTestCasesFor(AppliesTo.CLIENT)) {
137-
onlyIfProtocolMatches(testCase,
138-
() -> generateErrorResponseTest(operation, error, testCase));
139-
}
140-
});
141-
}
142-
}
128+
if (settings.generateClient()) {
129+
generateClientOperationTests(operation, operationIndex);
130+
}
131+
if (settings.generateServerSdk()) {
132+
generateServerOperationTests(operation, operationIndex);
143133
}
144134
}
145135

@@ -149,6 +139,45 @@ public void run() {
149139
}
150140
}
151141

142+
private void generateClientOperationTests(OperationShape operation, OperationIndex operationIndex) {
143+
if (!operation.hasTag("server-only")) {
144+
// 1. Generate test cases for each request.
145+
operation.getTrait(HttpRequestTestsTrait.class).ifPresent(trait -> {
146+
for (HttpRequestTestCase testCase : trait.getTestCasesFor(AppliesTo.CLIENT)) {
147+
onlyIfProtocolMatches(testCase, () -> generateClientRequestTest(operation, testCase));
148+
}
149+
});
150+
// 2. Generate test cases for each response.
151+
operation.getTrait(HttpResponseTestsTrait.class).ifPresent(trait -> {
152+
for (HttpResponseTestCase testCase : trait.getTestCasesFor(AppliesTo.CLIENT)) {
153+
onlyIfProtocolMatches(testCase, () -> generateResponseTest(operation, testCase));
154+
}
155+
});
156+
// 3. Generate test cases for each error on each operation.
157+
for (StructureShape error : operationIndex.getErrors(operation)) {
158+
if (!error.hasTag("server-only")) {
159+
error.getTrait(HttpResponseTestsTrait.class).ifPresent(trait -> {
160+
for (HttpResponseTestCase testCase : trait.getTestCasesFor(AppliesTo.CLIENT)) {
161+
onlyIfProtocolMatches(testCase,
162+
() -> generateErrorResponseTest(operation, error, testCase));
163+
}
164+
});
165+
}
166+
}
167+
}
168+
}
169+
170+
private void generateServerOperationTests(OperationShape operation, OperationIndex operationIndex) {
171+
if (!operation.hasTag("client-only")) {
172+
// 1. Generate test cases for each request.
173+
operation.getTrait(HttpRequestTestsTrait.class).ifPresent(trait -> {
174+
for (HttpRequestTestCase testCase : trait.getTestCasesFor(AppliesTo.SERVER)) {
175+
onlyIfProtocolMatches(testCase, () -> generateServerRequestTest(operation, testCase));
176+
}
177+
});
178+
}
179+
}
180+
152181
// Only generate test cases when its protocol matches the target protocol.
153182
private <T extends HttpMessageTestCase> void onlyIfProtocolMatches(T testCase, Runnable runnable) {
154183
if (testCase.getProtocol().equals(protocol)) {
@@ -175,7 +204,7 @@ private String createTestCaseFilename() {
175204
return TEST_CASE_FILE_TEMPLATE.replace("%s", baseName);
176205
}
177206

178-
private void generateRequestTest(OperationShape operation, HttpRequestTestCase testCase) {
207+
private void generateClientRequestTest(OperationShape operation, HttpRequestTestCase testCase) {
179208
Symbol operationSymbol = symbolProvider.toSymbol(operation);
180209

181210
String testName = testCase.getId() + ":Request";
@@ -217,6 +246,99 @@ private void generateRequestTest(OperationShape operation, HttpRequestTestCase t
217246
});
218247
}
219248

249+
private void generateServerRequestTest(OperationShape operation, HttpRequestTestCase testCase) {
250+
Symbol operationSymbol = symbolProvider.toSymbol(operation);
251+
252+
// Lowercase all the headers we're expecting as this is what we'll get.
253+
Map<String, String> headers = testCase.getHeaders().entrySet().stream()
254+
.map(entry -> new Pair<>(entry.getKey().toLowerCase(Locale.US), entry.getValue()))
255+
.collect(MapUtils.toUnmodifiableMap(Pair::getLeft, Pair::getRight));
256+
String queryParameters = Node.prettyPrintJson(buildQueryBag(testCase));
257+
String headerParameters = Node.prettyPrintJson(ObjectNode.fromStringMap(headers));
258+
String body = testCase.getBody().orElse(null);
259+
260+
String testName = testCase.getId() + ":ServerRequest";
261+
testCase.getDocumentation().ifPresent(writer::writeDocs);
262+
writer.openBlock("it($S, async () => {", "});\n", testName, () -> {
263+
// TODO: use the symbol provider when it's ready
264+
String serviceName = StringUtils.capitalize(service.getId().getName());
265+
String operationName = StringUtils.capitalize(operation.getId().getName());
266+
Symbol inputType = operationSymbol.expectProperty("inputType", Symbol.class);
267+
268+
// Declare "r" to capture the deserialized input
269+
writer.write("let r: any;");
270+
271+
// This doesn't actually satisfy the service's server interface since it only applies to one operation
272+
// and even then we're using any for response. This doesn't actually matter though since we should only
273+
// ever be calling this one operation and we're not using the output.
274+
writer.openBlock("class TestService {", "}", () -> {
275+
writer.openBlock("$L(input: $T, request: HttpRequest): any {", "}", operationName, inputType, () -> {
276+
writer.write("r = input;");
277+
writer.write("return {};");
278+
});
279+
});
280+
281+
String getHandlerName = String.format("get%sServiceHandler", serviceName);
282+
writer.addImport(getHandlerName, getHandlerName,
283+
"./protocols/" + ProtocolGenerator.getSanitizedName(protocolGenerator.getName()));
284+
285+
// Cast the service as any so TS will ignore the fact that the type being passed in is incomplete.
286+
writer.write("const handler = $L(new TestService() as any);", getHandlerName);
287+
288+
// Construct a new http request according to the test case definition.
289+
writer.openBlock("const request = new HttpRequest({", "});", () -> {
290+
writer.write("method: $S,", testCase.getMethod());
291+
writer.write("hostname: $S,", testCase.getHost().orElse("foo.example.com"));
292+
writer.write("path: $S,", testCase.getUri());
293+
writer.write("query: $L,", queryParameters);
294+
writer.write("headers: $L,", headerParameters);
295+
if (body != null) {
296+
writer.write("body: Readable.from([$S]),", body);
297+
}
298+
});
299+
writer.write("await handler.handle(request);");
300+
301+
writeRequestParamAssertions(operation, testCase);
302+
});
303+
}
304+
305+
private ObjectNode buildQueryBag(HttpRequestTestCase testCase) {
306+
// The query params in the test definition is a list of strings that looks like
307+
// "Foo=Bar", so we need to split the keys from the values.
308+
Map<String, List<String>> query = testCase.getQueryParams().stream()
309+
.map(pair -> {
310+
String[] split = pair.split("=");
311+
String key;
312+
String value = "";
313+
try {
314+
// The strings we're given are url encoded, so we need to decode them. In an actual implementation
315+
// the request we're given will have already decoded these.
316+
key = URLDecoder.decode(split[0], StandardCharsets.UTF_8.toString());
317+
if (split.length > 1) {
318+
value = URLDecoder.decode(split[1], StandardCharsets.UTF_8.toString());
319+
}
320+
} catch (UnsupportedEncodingException e) {
321+
throw new RuntimeException(e);
322+
}
323+
return Pair.of(key, value);
324+
})
325+
// Query lists/sets will just use the same key repeatedly, so here we collect all the values that
326+
// share a key.
327+
.collect(Collectors.groupingBy(Pair::getKey, Collectors.mapping(Pair::getValue, Collectors.toList())));
328+
329+
ObjectNode.Builder nodeBuilder = ObjectNode.objectNodeBuilder();
330+
for (Map.Entry<String, List<String>> entry : query.entrySet()) {
331+
// The value of the query bag can either be a single string or a list, so we need to ensure individual
332+
// values are written out as individual strings.
333+
if (entry.getValue().size() == 1) {
334+
nodeBuilder.withMember(entry.getKey(), StringNode.from(entry.getValue().get(0)));
335+
} else {
336+
nodeBuilder.withMember(entry.getKey(), ArrayNode.fromStrings(entry.getValue()));
337+
}
338+
}
339+
return nodeBuilder.build();
340+
}
341+
220342
// Ensure that the serialized request matches the expected request.
221343
private void writeRequestAssertions(OperationShape operation, HttpRequestTestCase testCase) {
222344
writer.write("expect(r.method).toBe($S);", testCase.getMethod());
@@ -405,10 +527,40 @@ private void writeResponseTestSetup(OperationShape operation, HttpResponseTestCa
405527
private void writeResponseAssertions(Shape operationOrError, HttpResponseTestCase testCase) {
406528
writer.write("expect(r['$$metadata'].httpStatusCode).toBe($L);", testCase.getCode());
407529

408-
writeParamAssertions(operationOrError, testCase);
530+
writeReponseParamAssertions(operationOrError, testCase);
409531
}
410532

411-
private void writeParamAssertions(Shape operationOrError, HttpResponseTestCase testCase) {
533+
private void writeRequestParamAssertions(OperationShape operation, HttpRequestTestCase testCase) {
534+
ObjectNode params = testCase.getParams();
535+
if (!params.isEmpty()) {
536+
StructureShape testInputShape = model.expectShape(
537+
operation.getInput().orElseThrow(() -> new CodegenException("Foo")),
538+
StructureShape.class);
539+
540+
// Use this trick wrapper to not need more complex trailing comma handling.
541+
writer.write("const paramsToValidate: any = [")
542+
.call(() -> params.accept(new CommandOutputNodeVisitor(testInputShape)))
543+
.write("][0];");
544+
545+
// Extract a payload binding if present.
546+
Optional<HttpBinding> pb = Optional.empty();
547+
HttpBindingIndex index = HttpBindingIndex.of(model);
548+
List<HttpBinding> payloadBindings = index.getRequestBindings(operation, Location.PAYLOAD);
549+
if (!payloadBindings.isEmpty()) {
550+
pb = Optional.of(payloadBindings.get(0));
551+
}
552+
final Optional<HttpBinding> payloadBinding = pb;
553+
554+
writeParamAssertions(writer, payloadBinding, () -> {
555+
// TODO: replace this with a collector from the server config once it's available
556+
writer.addImport("streamCollector", "__streamCollector", "@aws-sdk/node-http-handler");
557+
writer.write("const comparableBlob = await __streamCollector(r[$S]);",
558+
payloadBinding.get().getMemberName());
559+
});
560+
}
561+
}
562+
563+
private void writeReponseParamAssertions(Shape operationOrError, HttpResponseTestCase testCase) {
412564
ObjectNode params = testCase.getParams();
413565
if (!params.isEmpty()) {
414566
StructureShape testOutputShape;
@@ -438,40 +590,49 @@ private void writeParamAssertions(Shape operationOrError, HttpResponseTestCase t
438590
return null;
439591
});
440592

441-
// If we have a streaming payload blob, we need to collect it to something that
442-
// can be compared with the test contents. This emulates the customer experience.
443-
boolean hasStreamingPayloadBlob = payloadBinding
444-
.map(binding ->
593+
writeParamAssertions(writer, payloadBinding, () -> {
594+
writer.write("const comparableBlob = await client.config.streamCollector(r[$S]);",
595+
payloadBinding.get().getMemberName());
596+
});
597+
}
598+
}
599+
600+
private void writeParamAssertions(
601+
TypeScriptWriter writer,
602+
Optional<HttpBinding> payloadBinding,
603+
Runnable writeComparableBlob
604+
) {
605+
// If we have a streaming payload blob, we need to collect it to something that
606+
// can be compared with the test contents. This emulates the customer experience.
607+
boolean hasStreamingPayloadBlob = payloadBinding
608+
.map(binding ->
445609
model.getShape(binding.getMember().getTarget())
446610
.filter(Shape::isBlobShape)
447611
.filter(s -> s.hasTrait(StreamingTrait.ID))
448612
.isPresent())
449-
.orElse(false);
613+
.orElse(false);
614+
615+
if (hasStreamingPayloadBlob) {
616+
writeComparableBlob.run();
617+
}
450618

451-
// Do the collection for payload blobs.
619+
// Perform parameter comparisons.
620+
writer.openBlock("Object.keys(paramsToValidate).forEach(param => {", "});", () -> {
621+
writer.write("expect(r[param]).toBeDefined();");
452622
if (hasStreamingPayloadBlob) {
453-
writer.write("const comparableBlob = await client.config.streamCollector(r[$S]);",
454-
payloadBinding.get().getMemberName());
623+
writer.openBlock("if (param === $S) {", "} else {", payloadBinding.get().getMemberName(), () ->
624+
writer.write("expect(equivalentContents(comparableBlob, "
625+
+ "paramsToValidate[param])).toBe(true);"));
626+
writer.indent();
455627
}
456628

457-
// Perform parameter comparisons.
458-
writer.openBlock("Object.keys(paramsToValidate).forEach(param => {", "});", () -> {
459-
writer.write("expect(r[param]).toBeDefined();");
460-
if (hasStreamingPayloadBlob) {
461-
writer.openBlock("if (param === $S) {", "} else {", payloadBinding.get().getMemberName(), () ->
462-
writer.write("expect(equivalentContents(comparableBlob, "
463-
+ "paramsToValidate[param])).toBe(true);"));
464-
writer.indent();
465-
}
629+
writer.write("expect(equivalentContents(r[param], paramsToValidate[param])).toBe(true);");
466630

467-
writer.write("expect(equivalentContents(r[param], paramsToValidate[param])).toBe(true);");
468-
469-
if (hasStreamingPayloadBlob) {
470-
writer.dedent();
471-
writer.write("}");
472-
}
473-
});
474-
}
631+
if (hasStreamingPayloadBlob) {
632+
writer.dedent();
633+
writer.write("}");
634+
}
635+
});
475636
}
476637

477638
/**

0 commit comments

Comments
 (0)