17
17
18
18
import static java .lang .String .format ;
19
19
20
+ import java .io .UnsupportedEncodingException ;
21
+ import java .net .URLDecoder ;
22
+ import java .nio .charset .StandardCharsets ;
20
23
import java .util .List ;
21
24
import java .util .Locale ;
22
25
import java .util .Map ;
23
26
import java .util .Optional ;
24
27
import java .util .Set ;
25
28
import java .util .TreeSet ;
26
29
import java .util .logging .Logger ;
30
+ import java .util .stream .Collectors ;
27
31
import software .amazon .smithy .codegen .core .CodegenException ;
28
32
import software .amazon .smithy .codegen .core .Symbol ;
29
33
import software .amazon .smithy .codegen .core .SymbolProvider ;
58
62
import software .amazon .smithy .protocoltests .traits .HttpRequestTestsTrait ;
59
63
import software .amazon .smithy .protocoltests .traits .HttpResponseTestCase ;
60
64
import software .amazon .smithy .protocoltests .traits .HttpResponseTestsTrait ;
65
+ import software .amazon .smithy .typescript .codegen .integration .ProtocolGenerator ;
61
66
import software .amazon .smithy .utils .IoUtils ;
62
67
import software .amazon .smithy .utils .MapUtils ;
63
68
import software .amazon .smithy .utils .Pair ;
69
+ import software .amazon .smithy .utils .StringUtils ;
64
70
65
71
/**
66
72
* Generates HTTP protocol test cases to be run using Jest.
@@ -86,6 +92,7 @@ final class HttpProtocolTestGenerator implements Runnable {
86
92
private final SymbolProvider symbolProvider ;
87
93
private final Symbol serviceSymbol ;
88
94
private final Set <String > additionalStubs = new TreeSet <>();
95
+ private final ProtocolGenerator protocolGenerator ;
89
96
90
97
/** Vends a TypeScript IFF it's needed. */
91
98
private final TypeScriptDelegator delegator ;
@@ -98,14 +105,16 @@ final class HttpProtocolTestGenerator implements Runnable {
98
105
Model model ,
99
106
ShapeId protocol ,
100
107
SymbolProvider symbolProvider ,
101
- TypeScriptDelegator delegator
108
+ TypeScriptDelegator delegator ,
109
+ ProtocolGenerator protocolGenerator
102
110
) {
103
111
this .settings = settings ;
104
112
this .model = model ;
105
113
this .protocol = protocol ;
106
114
this .service = settings .getService (model );
107
115
this .symbolProvider = symbolProvider ;
108
116
this .delegator = delegator ;
117
+ this .protocolGenerator = protocolGenerator ;
109
118
serviceSymbol = symbolProvider .toSymbol (service );
110
119
}
111
120
@@ -116,30 +125,11 @@ public void run() {
116
125
117
126
// Use a TreeSet to have a fixed ordering of tests.
118
127
for (OperationShape operation : new TreeSet <>(topDownIndex .getContainedOperations (service ))) {
119
- if (!operation .hasTag ("server-only" )) {
120
- // 1. Generate test cases for each request.
121
- operation .getTrait (HttpRequestTestsTrait .class ).ifPresent (trait -> {
122
- for (HttpRequestTestCase testCase : trait .getTestCasesFor (AppliesTo .CLIENT )) {
123
- onlyIfProtocolMatches (testCase , () -> generateRequestTest (operation , testCase ));
124
- }
125
- });
126
- // 2. Generate test cases for each response.
127
- operation .getTrait (HttpResponseTestsTrait .class ).ifPresent (trait -> {
128
- for (HttpResponseTestCase testCase : trait .getTestCasesFor (AppliesTo .CLIENT )) {
129
- onlyIfProtocolMatches (testCase , () -> generateResponseTest (operation , testCase ));
130
- }
131
- });
132
- // 3. Generate test cases for each error on each operation.
133
- for (StructureShape error : operationIndex .getErrors (operation )) {
134
- if (!error .hasTag ("server-only" )) {
135
- error .getTrait (HttpResponseTestsTrait .class ).ifPresent (trait -> {
136
- for (HttpResponseTestCase testCase : trait .getTestCasesFor (AppliesTo .CLIENT )) {
137
- onlyIfProtocolMatches (testCase ,
138
- () -> generateErrorResponseTest (operation , error , testCase ));
139
- }
140
- });
141
- }
142
- }
128
+ if (settings .generateClient ()) {
129
+ generateClientOperationTests (operation , operationIndex );
130
+ }
131
+ if (settings .generateServerSdk ()) {
132
+ generateServerOperationTests (operation , operationIndex );
143
133
}
144
134
}
145
135
@@ -149,6 +139,45 @@ public void run() {
149
139
}
150
140
}
151
141
142
+ private void generateClientOperationTests (OperationShape operation , OperationIndex operationIndex ) {
143
+ if (!operation .hasTag ("server-only" )) {
144
+ // 1. Generate test cases for each request.
145
+ operation .getTrait (HttpRequestTestsTrait .class ).ifPresent (trait -> {
146
+ for (HttpRequestTestCase testCase : trait .getTestCasesFor (AppliesTo .CLIENT )) {
147
+ onlyIfProtocolMatches (testCase , () -> generateClientRequestTest (operation , testCase ));
148
+ }
149
+ });
150
+ // 2. Generate test cases for each response.
151
+ operation .getTrait (HttpResponseTestsTrait .class ).ifPresent (trait -> {
152
+ for (HttpResponseTestCase testCase : trait .getTestCasesFor (AppliesTo .CLIENT )) {
153
+ onlyIfProtocolMatches (testCase , () -> generateResponseTest (operation , testCase ));
154
+ }
155
+ });
156
+ // 3. Generate test cases for each error on each operation.
157
+ for (StructureShape error : operationIndex .getErrors (operation )) {
158
+ if (!error .hasTag ("server-only" )) {
159
+ error .getTrait (HttpResponseTestsTrait .class ).ifPresent (trait -> {
160
+ for (HttpResponseTestCase testCase : trait .getTestCasesFor (AppliesTo .CLIENT )) {
161
+ onlyIfProtocolMatches (testCase ,
162
+ () -> generateErrorResponseTest (operation , error , testCase ));
163
+ }
164
+ });
165
+ }
166
+ }
167
+ }
168
+ }
169
+
170
+ private void generateServerOperationTests (OperationShape operation , OperationIndex operationIndex ) {
171
+ if (!operation .hasTag ("client-only" )) {
172
+ // 1. Generate test cases for each request.
173
+ operation .getTrait (HttpRequestTestsTrait .class ).ifPresent (trait -> {
174
+ for (HttpRequestTestCase testCase : trait .getTestCasesFor (AppliesTo .SERVER )) {
175
+ onlyIfProtocolMatches (testCase , () -> generateServerRequestTest (operation , testCase ));
176
+ }
177
+ });
178
+ }
179
+ }
180
+
152
181
// Only generate test cases when its protocol matches the target protocol.
153
182
private <T extends HttpMessageTestCase > void onlyIfProtocolMatches (T testCase , Runnable runnable ) {
154
183
if (testCase .getProtocol ().equals (protocol )) {
@@ -175,7 +204,7 @@ private String createTestCaseFilename() {
175
204
return TEST_CASE_FILE_TEMPLATE .replace ("%s" , baseName );
176
205
}
177
206
178
- private void generateRequestTest (OperationShape operation , HttpRequestTestCase testCase ) {
207
+ private void generateClientRequestTest (OperationShape operation , HttpRequestTestCase testCase ) {
179
208
Symbol operationSymbol = symbolProvider .toSymbol (operation );
180
209
181
210
String testName = testCase .getId () + ":Request" ;
@@ -217,6 +246,107 @@ private void generateRequestTest(OperationShape operation, HttpRequestTestCase t
217
246
});
218
247
}
219
248
249
+ private void generateServerRequestTest (OperationShape operation , HttpRequestTestCase testCase ) {
250
+ Symbol operationSymbol = symbolProvider .toSymbol (operation );
251
+
252
+ // Lowercase all the headers we're expecting as this is what we'll get.
253
+ Map <String , String > headers = testCase .getHeaders ().entrySet ().stream ()
254
+ .map (entry -> new Pair <>(entry .getKey ().toLowerCase (Locale .US ), entry .getValue ()))
255
+ .collect (MapUtils .toUnmodifiableMap (Pair ::getLeft , Pair ::getRight ));
256
+ String queryParameters = Node .prettyPrintJson (buildQueryBag (testCase ));
257
+ String headerParameters = Node .prettyPrintJson (ObjectNode .fromStringMap (headers ));
258
+ String body = testCase .getBody ().orElse (null );
259
+
260
+ String testName = testCase .getId () + ":ServerRequest" ;
261
+ testCase .getDocumentation ().ifPresent (writer ::writeDocs );
262
+ writer .openBlock ("it($S, async () => {" , "});\n " , testName , () -> {
263
+ // TODO: use the symbol provider when it's ready
264
+ String serviceName = StringUtils .capitalize (service .getId ().getName ());
265
+ String operationName = StringUtils .capitalize (operation .getId ().getName ());
266
+ Symbol inputType = operationSymbol .expectProperty ("inputType" , Symbol .class );
267
+ Symbol outputType = operationSymbol .expectProperty ("outputType" , Symbol .class );
268
+
269
+ // Create a mock function to set in place of the server operation function so we can capture
270
+ // input and other information.
271
+ writer .write ("let testFunction = jest.fn();" );
272
+ writer .write ("testFunction.mockReturnValue({});" );
273
+
274
+ // We use a partial here so that we don't have to define the entire service, but still get the advantages
275
+ // the type checker, including excess property checking. Later on we'll use `as` to cast this to the
276
+ // full service so that we can actually use it.
277
+ writer .addImport (serviceName + "Service" , null , "./server" );
278
+ writer .openBlock ("const testService: Partial<$LService> = {" , "};" , serviceName , () -> {
279
+ writer .addImport ("Operation" , "__Operation" , "@aws-smithy/server-common" );
280
+ writer .write ("$L: testFunction as __Operation<$T, $T>," , operationName , inputType , outputType );
281
+ });
282
+
283
+ String getHandlerName = String .format ("get%sServiceHandler" , serviceName );
284
+ writer .addImport (getHandlerName , getHandlerName ,
285
+ "./protocols/" + ProtocolGenerator .getSanitizedName (protocolGenerator .getName ()));
286
+
287
+ // Cast the service as any so TS will ignore the fact that the type being passed in is incomplete.
288
+ writer .write ("const handler = $L(testService as $LService);" , getHandlerName , serviceName );
289
+
290
+ // Construct a new http request according to the test case definition.
291
+ writer .openBlock ("const request = new HttpRequest({" , "});" , () -> {
292
+ writer .write ("method: $S," , testCase .getMethod ());
293
+ writer .write ("hostname: $S," , testCase .getHost ().orElse ("foo.example.com" ));
294
+ writer .write ("path: $S," , testCase .getUri ());
295
+ writer .write ("query: $L," , queryParameters );
296
+ writer .write ("headers: $L," , headerParameters );
297
+ if (body != null ) {
298
+ writer .write ("body: Readable.from([$S])," , body );
299
+ }
300
+ });
301
+ writer .write ("await handler.handle(request);" ).write ("" );
302
+
303
+ // Assert that the function has been called exactly once.
304
+ writer .write ("expect(testFunction.mock.calls.length).toBe(1);" );
305
+
306
+ // Capture the input. We need to cast this to any so we can index into it.
307
+ writer .write ("let r: any = testFunction.mock.calls[0][0];" ).write ("" );
308
+
309
+ writeRequestParamAssertions (operation , testCase );
310
+ });
311
+ }
312
+
313
+ private ObjectNode buildQueryBag (HttpRequestTestCase testCase ) {
314
+ // The query params in the test definition is a list of strings that looks like
315
+ // "Foo=Bar", so we need to split the keys from the values.
316
+ Map <String , List <String >> query = testCase .getQueryParams ().stream ()
317
+ .map (pair -> {
318
+ String [] split = pair .split ("=" );
319
+ String key ;
320
+ String value = "" ;
321
+ try {
322
+ // The strings we're given are url encoded, so we need to decode them. In an actual implementation
323
+ // the request we're given will have already decoded these.
324
+ key = URLDecoder .decode (split [0 ], StandardCharsets .UTF_8 .toString ());
325
+ if (split .length > 1 ) {
326
+ value = URLDecoder .decode (split [1 ], StandardCharsets .UTF_8 .toString ());
327
+ }
328
+ } catch (UnsupportedEncodingException e ) {
329
+ throw new RuntimeException (e );
330
+ }
331
+ return Pair .of (key , value );
332
+ })
333
+ // Query lists/sets will just use the same key repeatedly, so here we collect all the values that
334
+ // share a key.
335
+ .collect (Collectors .groupingBy (Pair ::getKey , Collectors .mapping (Pair ::getValue , Collectors .toList ())));
336
+
337
+ ObjectNode .Builder nodeBuilder = ObjectNode .objectNodeBuilder ();
338
+ for (Map .Entry <String , List <String >> entry : query .entrySet ()) {
339
+ // The value of the query bag can either be a single string or a list, so we need to ensure individual
340
+ // values are written out as individual strings.
341
+ if (entry .getValue ().size () == 1 ) {
342
+ nodeBuilder .withMember (entry .getKey (), StringNode .from (entry .getValue ().get (0 )));
343
+ } else {
344
+ nodeBuilder .withMember (entry .getKey (), ArrayNode .fromStrings (entry .getValue ()));
345
+ }
346
+ }
347
+ return nodeBuilder .build ();
348
+ }
349
+
220
350
// Ensure that the serialized request matches the expected request.
221
351
private void writeRequestAssertions (OperationShape operation , HttpRequestTestCase testCase ) {
222
352
writer .write ("expect(r.method).toBe($S);" , testCase .getMethod ());
@@ -405,10 +535,40 @@ private void writeResponseTestSetup(OperationShape operation, HttpResponseTestCa
405
535
private void writeResponseAssertions (Shape operationOrError , HttpResponseTestCase testCase ) {
406
536
writer .write ("expect(r['$$metadata'].httpStatusCode).toBe($L);" , testCase .getCode ());
407
537
408
- writeParamAssertions (operationOrError , testCase );
538
+ writeReponseParamAssertions (operationOrError , testCase );
539
+ }
540
+
541
+ private void writeRequestParamAssertions (OperationShape operation , HttpRequestTestCase testCase ) {
542
+ ObjectNode params = testCase .getParams ();
543
+ if (!params .isEmpty ()) {
544
+ StructureShape testInputShape = model .expectShape (
545
+ operation .getInput ().orElseThrow (() -> new CodegenException ("Foo" )),
546
+ StructureShape .class );
547
+
548
+ // Use this trick wrapper to not need more complex trailing comma handling.
549
+ writer .write ("const paramsToValidate: any = [" )
550
+ .call (() -> params .accept (new CommandOutputNodeVisitor (testInputShape )))
551
+ .write ("][0];" );
552
+
553
+ // Extract a payload binding if present.
554
+ Optional <HttpBinding > pb = Optional .empty ();
555
+ HttpBindingIndex index = HttpBindingIndex .of (model );
556
+ List <HttpBinding > payloadBindings = index .getRequestBindings (operation , Location .PAYLOAD );
557
+ if (!payloadBindings .isEmpty ()) {
558
+ pb = Optional .of (payloadBindings .get (0 ));
559
+ }
560
+ final Optional <HttpBinding > payloadBinding = pb ;
561
+
562
+ writeParamAssertions (writer , payloadBinding , () -> {
563
+ // TODO: replace this with a collector from the server config once it's available
564
+ writer .addImport ("streamCollector" , "__streamCollector" , "@aws-sdk/node-http-handler" );
565
+ writer .write ("const comparableBlob = await __streamCollector(r[$S]);" ,
566
+ payloadBinding .get ().getMemberName ());
567
+ });
568
+ }
409
569
}
410
570
411
- private void writeParamAssertions (Shape operationOrError , HttpResponseTestCase testCase ) {
571
+ private void writeReponseParamAssertions (Shape operationOrError , HttpResponseTestCase testCase ) {
412
572
ObjectNode params = testCase .getParams ();
413
573
if (!params .isEmpty ()) {
414
574
StructureShape testOutputShape ;
@@ -438,40 +598,49 @@ private void writeParamAssertions(Shape operationOrError, HttpResponseTestCase t
438
598
return null ;
439
599
});
440
600
441
- // If we have a streaming payload blob, we need to collect it to something that
442
- // can be compared with the test contents. This emulates the customer experience.
443
- boolean hasStreamingPayloadBlob = payloadBinding
444
- .map (binding ->
601
+ writeParamAssertions (writer , payloadBinding , () -> {
602
+ writer .write ("const comparableBlob = await client.config.streamCollector(r[$S]);" ,
603
+ payloadBinding .get ().getMemberName ());
604
+ });
605
+ }
606
+ }
607
+
608
+ private void writeParamAssertions (
609
+ TypeScriptWriter writer ,
610
+ Optional <HttpBinding > payloadBinding ,
611
+ Runnable writeComparableBlob
612
+ ) {
613
+ // If we have a streaming payload blob, we need to collect it to something that
614
+ // can be compared with the test contents. This emulates the customer experience.
615
+ boolean hasStreamingPayloadBlob = payloadBinding
616
+ .map (binding ->
445
617
model .getShape (binding .getMember ().getTarget ())
446
618
.filter (Shape ::isBlobShape )
447
619
.filter (s -> s .hasTrait (StreamingTrait .ID ))
448
620
.isPresent ())
449
- .orElse (false );
621
+ .orElse (false );
622
+
623
+ if (hasStreamingPayloadBlob ) {
624
+ writeComparableBlob .run ();
625
+ }
450
626
451
- // Do the collection for payload blobs.
627
+ // Perform parameter comparisons.
628
+ writer .openBlock ("Object.keys(paramsToValidate).forEach(param => {" , "});" , () -> {
629
+ writer .write ("expect(r[param]).toBeDefined();" );
452
630
if (hasStreamingPayloadBlob ) {
453
- writer .write ("const comparableBlob = await client.config.streamCollector(r[$S]);" ,
454
- payloadBinding .get ().getMemberName ());
631
+ writer .openBlock ("if (param === $S) {" , "} else {" , payloadBinding .get ().getMemberName (), () ->
632
+ writer .write ("expect(equivalentContents(comparableBlob, "
633
+ + "paramsToValidate[param])).toBe(true);" ));
634
+ writer .indent ();
455
635
}
456
636
457
- // Perform parameter comparisons.
458
- writer .openBlock ("Object.keys(paramsToValidate).forEach(param => {" , "});" , () -> {
459
- writer .write ("expect(r[param]).toBeDefined();" );
460
- if (hasStreamingPayloadBlob ) {
461
- writer .openBlock ("if (param === $S) {" , "} else {" , payloadBinding .get ().getMemberName (), () ->
462
- writer .write ("expect(equivalentContents(comparableBlob, "
463
- + "paramsToValidate[param])).toBe(true);" ));
464
- writer .indent ();
465
- }
637
+ writer .write ("expect(equivalentContents(r[param], paramsToValidate[param])).toBe(true);" );
466
638
467
- writer .write ("expect(equivalentContents(r[param], paramsToValidate[param])).toBe(true);" );
468
-
469
- if (hasStreamingPayloadBlob ) {
470
- writer .dedent ();
471
- writer .write ("}" );
472
- }
473
- });
474
- }
639
+ if (hasStreamingPayloadBlob ) {
640
+ writer .dedent ();
641
+ writer .write ("}" );
642
+ }
643
+ });
475
644
}
476
645
477
646
/**
0 commit comments