17
17
18
18
import static java .lang .String .format ;
19
19
20
+ import java .io .UnsupportedEncodingException ;
21
+ import java .net .URLDecoder ;
22
+ import java .nio .charset .StandardCharsets ;
20
23
import java .util .List ;
21
24
import java .util .Locale ;
22
25
import java .util .Map ;
23
26
import java .util .Optional ;
24
27
import java .util .Set ;
25
28
import java .util .TreeSet ;
26
29
import java .util .logging .Logger ;
30
+ import java .util .stream .Collectors ;
27
31
import software .amazon .smithy .codegen .core .CodegenException ;
28
32
import software .amazon .smithy .codegen .core .Symbol ;
29
33
import software .amazon .smithy .codegen .core .SymbolProvider ;
58
62
import software .amazon .smithy .protocoltests .traits .HttpRequestTestsTrait ;
59
63
import software .amazon .smithy .protocoltests .traits .HttpResponseTestCase ;
60
64
import software .amazon .smithy .protocoltests .traits .HttpResponseTestsTrait ;
65
+ import software .amazon .smithy .typescript .codegen .integration .ProtocolGenerator ;
61
66
import software .amazon .smithy .utils .IoUtils ;
62
67
import software .amazon .smithy .utils .MapUtils ;
63
68
import software .amazon .smithy .utils .Pair ;
69
+ import software .amazon .smithy .utils .StringUtils ;
64
70
65
71
/**
66
72
* Generates HTTP protocol test cases to be run using Jest.
@@ -86,6 +92,7 @@ final class HttpProtocolTestGenerator implements Runnable {
86
92
private final SymbolProvider symbolProvider ;
87
93
private final Symbol serviceSymbol ;
88
94
private final Set <String > additionalStubs = new TreeSet <>();
95
+ private final ProtocolGenerator protocolGenerator ;
89
96
90
97
/** Vends a TypeScript IFF it's needed. */
91
98
private final TypeScriptDelegator delegator ;
@@ -98,14 +105,16 @@ final class HttpProtocolTestGenerator implements Runnable {
98
105
Model model ,
99
106
ShapeId protocol ,
100
107
SymbolProvider symbolProvider ,
101
- TypeScriptDelegator delegator
108
+ TypeScriptDelegator delegator ,
109
+ ProtocolGenerator protocolGenerator
102
110
) {
103
111
this .settings = settings ;
104
112
this .model = model ;
105
113
this .protocol = protocol ;
106
114
this .service = settings .getService (model );
107
115
this .symbolProvider = symbolProvider ;
108
116
this .delegator = delegator ;
117
+ this .protocolGenerator = protocolGenerator ;
109
118
serviceSymbol = symbolProvider .toSymbol (service );
110
119
}
111
120
@@ -116,30 +125,11 @@ public void run() {
116
125
117
126
// Use a TreeSet to have a fixed ordering of tests.
118
127
for (OperationShape operation : new TreeSet <>(topDownIndex .getContainedOperations (service ))) {
119
- if (!operation .hasTag ("server-only" )) {
120
- // 1. Generate test cases for each request.
121
- operation .getTrait (HttpRequestTestsTrait .class ).ifPresent (trait -> {
122
- for (HttpRequestTestCase testCase : trait .getTestCasesFor (AppliesTo .CLIENT )) {
123
- onlyIfProtocolMatches (testCase , () -> generateRequestTest (operation , testCase ));
124
- }
125
- });
126
- // 2. Generate test cases for each response.
127
- operation .getTrait (HttpResponseTestsTrait .class ).ifPresent (trait -> {
128
- for (HttpResponseTestCase testCase : trait .getTestCasesFor (AppliesTo .CLIENT )) {
129
- onlyIfProtocolMatches (testCase , () -> generateResponseTest (operation , testCase ));
130
- }
131
- });
132
- // 3. Generate test cases for each error on each operation.
133
- for (StructureShape error : operationIndex .getErrors (operation )) {
134
- if (!error .hasTag ("server-only" )) {
135
- error .getTrait (HttpResponseTestsTrait .class ).ifPresent (trait -> {
136
- for (HttpResponseTestCase testCase : trait .getTestCasesFor (AppliesTo .CLIENT )) {
137
- onlyIfProtocolMatches (testCase ,
138
- () -> generateErrorResponseTest (operation , error , testCase ));
139
- }
140
- });
141
- }
142
- }
128
+ if (settings .generateClient ()) {
129
+ generateClientOperationTests (operation , operationIndex );
130
+ }
131
+ if (settings .generateServerSdk ()) {
132
+ generateServerOperationTests (operation , operationIndex );
143
133
}
144
134
}
145
135
@@ -149,6 +139,45 @@ public void run() {
149
139
}
150
140
}
151
141
142
+ private void generateClientOperationTests (OperationShape operation , OperationIndex operationIndex ) {
143
+ if (!operation .hasTag ("server-only" )) {
144
+ // 1. Generate test cases for each request.
145
+ operation .getTrait (HttpRequestTestsTrait .class ).ifPresent (trait -> {
146
+ for (HttpRequestTestCase testCase : trait .getTestCasesFor (AppliesTo .CLIENT )) {
147
+ onlyIfProtocolMatches (testCase , () -> generateClientRequestTest (operation , testCase ));
148
+ }
149
+ });
150
+ // 2. Generate test cases for each response.
151
+ operation .getTrait (HttpResponseTestsTrait .class ).ifPresent (trait -> {
152
+ for (HttpResponseTestCase testCase : trait .getTestCasesFor (AppliesTo .CLIENT )) {
153
+ onlyIfProtocolMatches (testCase , () -> generateResponseTest (operation , testCase ));
154
+ }
155
+ });
156
+ // 3. Generate test cases for each error on each operation.
157
+ for (StructureShape error : operationIndex .getErrors (operation )) {
158
+ if (!error .hasTag ("server-only" )) {
159
+ error .getTrait (HttpResponseTestsTrait .class ).ifPresent (trait -> {
160
+ for (HttpResponseTestCase testCase : trait .getTestCasesFor (AppliesTo .CLIENT )) {
161
+ onlyIfProtocolMatches (testCase ,
162
+ () -> generateErrorResponseTest (operation , error , testCase ));
163
+ }
164
+ });
165
+ }
166
+ }
167
+ }
168
+ }
169
+
170
+ private void generateServerOperationTests (OperationShape operation , OperationIndex operationIndex ) {
171
+ if (!operation .hasTag ("client-only" )) {
172
+ // 1. Generate test cases for each request.
173
+ operation .getTrait (HttpRequestTestsTrait .class ).ifPresent (trait -> {
174
+ for (HttpRequestTestCase testCase : trait .getTestCasesFor (AppliesTo .SERVER )) {
175
+ onlyIfProtocolMatches (testCase , () -> generateServerRequestTest (operation , testCase ));
176
+ }
177
+ });
178
+ }
179
+ }
180
+
152
181
// Only generate test cases when its protocol matches the target protocol.
153
182
private <T extends HttpMessageTestCase > void onlyIfProtocolMatches (T testCase , Runnable runnable ) {
154
183
if (testCase .getProtocol ().equals (protocol )) {
@@ -175,7 +204,7 @@ private String createTestCaseFilename() {
175
204
return TEST_CASE_FILE_TEMPLATE .replace ("%s" , baseName );
176
205
}
177
206
178
- private void generateRequestTest (OperationShape operation , HttpRequestTestCase testCase ) {
207
+ private void generateClientRequestTest (OperationShape operation , HttpRequestTestCase testCase ) {
179
208
Symbol operationSymbol = symbolProvider .toSymbol (operation );
180
209
181
210
String testName = testCase .getId () + ":Request" ;
@@ -217,6 +246,99 @@ private void generateRequestTest(OperationShape operation, HttpRequestTestCase t
217
246
});
218
247
}
219
248
249
+ private void generateServerRequestTest (OperationShape operation , HttpRequestTestCase testCase ) {
250
+ Symbol operationSymbol = symbolProvider .toSymbol (operation );
251
+
252
+ // Lowercase all the headers we're expecting as this is what we'll get.
253
+ Map <String , String > headers = testCase .getHeaders ().entrySet ().stream ()
254
+ .map (entry -> new Pair <>(entry .getKey ().toLowerCase (Locale .US ), entry .getValue ()))
255
+ .collect (MapUtils .toUnmodifiableMap (Pair ::getLeft , Pair ::getRight ));
256
+ String queryParameters = Node .prettyPrintJson (buildQueryBag (testCase ));
257
+ String headerParameters = Node .prettyPrintJson (ObjectNode .fromStringMap (headers ));
258
+ String body = testCase .getBody ().orElse (null );
259
+
260
+ String testName = testCase .getId () + ":ServerRequest" ;
261
+ testCase .getDocumentation ().ifPresent (writer ::writeDocs );
262
+ writer .openBlock ("it($S, async () => {" , "});\n " , testName , () -> {
263
+ // TODO: use the symbol provider when it's ready
264
+ String serviceName = StringUtils .capitalize (service .getId ().getName ());
265
+ String operationName = StringUtils .capitalize (operation .getId ().getName ());
266
+ Symbol inputType = operationSymbol .expectProperty ("inputType" , Symbol .class );
267
+
268
+ // Declare "r" to capture the deserialized input
269
+ writer .write ("let r: any;" );
270
+
271
+ // This doesn't actually satisfy the service's server interface since it only applies to one operation
272
+ // and even then we're using any for response. This doesn't actually matter though since we should only
273
+ // ever be calling this one operation and we're not using the output.
274
+ writer .openBlock ("class TestService {" , "}" , () -> {
275
+ writer .openBlock ("$L(input: $T, request: HttpRequest): any {" , "}" , operationName , inputType , () -> {
276
+ writer .write ("r = input;" );
277
+ writer .write ("return {};" );
278
+ });
279
+ });
280
+
281
+ String getHandlerName = String .format ("get%sServiceHandler" , serviceName );
282
+ writer .addImport (getHandlerName , getHandlerName ,
283
+ "./protocols/" + ProtocolGenerator .getSanitizedName (protocolGenerator .getName ()));
284
+
285
+ // Cast the service as any so TS will ignore the fact that the type being passed in is incomplete.
286
+ writer .write ("const handler = $L(new TestService() as any);" , getHandlerName );
287
+
288
+ // Construct a new http request according to the test case definition.
289
+ writer .openBlock ("const request = new HttpRequest({" , "});" , () -> {
290
+ writer .write ("method: $S," , testCase .getMethod ());
291
+ writer .write ("hostname: $S," , testCase .getHost ().orElse ("foo.example.com" ));
292
+ writer .write ("path: $S," , testCase .getUri ());
293
+ writer .write ("query: $L," , queryParameters );
294
+ writer .write ("headers: $L," , headerParameters );
295
+ if (body != null ) {
296
+ writer .write ("body: Readable.from([$S])," , body );
297
+ }
298
+ });
299
+ writer .write ("await handler.handle(request);" );
300
+
301
+ writeRequestParamAssertions (operation , testCase );
302
+ });
303
+ }
304
+
305
+ private ObjectNode buildQueryBag (HttpRequestTestCase testCase ) {
306
+ // The query params in the test definition is a list of strings that looks like
307
+ // "Foo=Bar", so we need to split the keys from the values.
308
+ Map <String , List <String >> query = testCase .getQueryParams ().stream ()
309
+ .map (pair -> {
310
+ String [] split = pair .split ("=" );
311
+ String key ;
312
+ String value = "" ;
313
+ try {
314
+ // The strings we're given are url encoded, so we need to decode them. In an actual implementation
315
+ // the request we're given will have already decoded these.
316
+ key = URLDecoder .decode (split [0 ], StandardCharsets .UTF_8 .toString ());
317
+ if (split .length > 1 ) {
318
+ value = URLDecoder .decode (split [1 ], StandardCharsets .UTF_8 .toString ());
319
+ }
320
+ } catch (UnsupportedEncodingException e ) {
321
+ throw new RuntimeException (e );
322
+ }
323
+ return Pair .of (key , value );
324
+ })
325
+ // Query lists/sets will just use the same key repeatedly, so here we collect all the values that
326
+ // share a key.
327
+ .collect (Collectors .groupingBy (Pair ::getKey , Collectors .mapping (Pair ::getValue , Collectors .toList ())));
328
+
329
+ ObjectNode .Builder nodeBuilder = ObjectNode .objectNodeBuilder ();
330
+ for (Map .Entry <String , List <String >> entry : query .entrySet ()) {
331
+ // The value of the query bag can either be a single string or a list, so we need to ensure individual
332
+ // values are written out as individual strings.
333
+ if (entry .getValue ().size () == 1 ) {
334
+ nodeBuilder .withMember (entry .getKey (), StringNode .from (entry .getValue ().get (0 )));
335
+ } else {
336
+ nodeBuilder .withMember (entry .getKey (), ArrayNode .fromStrings (entry .getValue ()));
337
+ }
338
+ }
339
+ return nodeBuilder .build ();
340
+ }
341
+
220
342
// Ensure that the serialized request matches the expected request.
221
343
private void writeRequestAssertions (OperationShape operation , HttpRequestTestCase testCase ) {
222
344
writer .write ("expect(r.method).toBe($S);" , testCase .getMethod ());
@@ -405,10 +527,40 @@ private void writeResponseTestSetup(OperationShape operation, HttpResponseTestCa
405
527
private void writeResponseAssertions (Shape operationOrError , HttpResponseTestCase testCase ) {
406
528
writer .write ("expect(r['$$metadata'].httpStatusCode).toBe($L);" , testCase .getCode ());
407
529
408
- writeParamAssertions (operationOrError , testCase );
530
+ writeReponseParamAssertions (operationOrError , testCase );
409
531
}
410
532
411
- private void writeParamAssertions (Shape operationOrError , HttpResponseTestCase testCase ) {
533
+ private void writeRequestParamAssertions (OperationShape operation , HttpRequestTestCase testCase ) {
534
+ ObjectNode params = testCase .getParams ();
535
+ if (!params .isEmpty ()) {
536
+ StructureShape testInputShape = model .expectShape (
537
+ operation .getInput ().orElseThrow (() -> new CodegenException ("Foo" )),
538
+ StructureShape .class );
539
+
540
+ // Use this trick wrapper to not need more complex trailing comma handling.
541
+ writer .write ("const paramsToValidate: any = [" )
542
+ .call (() -> params .accept (new CommandOutputNodeVisitor (testInputShape )))
543
+ .write ("][0];" );
544
+
545
+ // Extract a payload binding if present.
546
+ Optional <HttpBinding > pb = Optional .empty ();
547
+ HttpBindingIndex index = HttpBindingIndex .of (model );
548
+ List <HttpBinding > payloadBindings = index .getRequestBindings (operation , Location .PAYLOAD );
549
+ if (!payloadBindings .isEmpty ()) {
550
+ pb = Optional .of (payloadBindings .get (0 ));
551
+ }
552
+ final Optional <HttpBinding > payloadBinding = pb ;
553
+
554
+ writeParamAssertions (writer , payloadBinding , () -> {
555
+ // TODO: replace this with a collector from the server config once it's available
556
+ writer .addImport ("streamCollector" , "__streamCollector" , "@aws-sdk/node-http-handler" );
557
+ writer .write ("const comparableBlob = await __streamCollector(r[$S]);" ,
558
+ payloadBinding .get ().getMemberName ());
559
+ });
560
+ }
561
+ }
562
+
563
+ private void writeReponseParamAssertions (Shape operationOrError , HttpResponseTestCase testCase ) {
412
564
ObjectNode params = testCase .getParams ();
413
565
if (!params .isEmpty ()) {
414
566
StructureShape testOutputShape ;
@@ -438,40 +590,49 @@ private void writeParamAssertions(Shape operationOrError, HttpResponseTestCase t
438
590
return null ;
439
591
});
440
592
441
- // If we have a streaming payload blob, we need to collect it to something that
442
- // can be compared with the test contents. This emulates the customer experience.
443
- boolean hasStreamingPayloadBlob = payloadBinding
444
- .map (binding ->
593
+ writeParamAssertions (writer , payloadBinding , () -> {
594
+ writer .write ("const comparableBlob = await client.config.streamCollector(r[$S]);" ,
595
+ payloadBinding .get ().getMemberName ());
596
+ });
597
+ }
598
+ }
599
+
600
+ private void writeParamAssertions (
601
+ TypeScriptWriter writer ,
602
+ Optional <HttpBinding > payloadBinding ,
603
+ Runnable writeComparableBlob
604
+ ) {
605
+ // If we have a streaming payload blob, we need to collect it to something that
606
+ // can be compared with the test contents. This emulates the customer experience.
607
+ boolean hasStreamingPayloadBlob = payloadBinding
608
+ .map (binding ->
445
609
model .getShape (binding .getMember ().getTarget ())
446
610
.filter (Shape ::isBlobShape )
447
611
.filter (s -> s .hasTrait (StreamingTrait .ID ))
448
612
.isPresent ())
449
- .orElse (false );
613
+ .orElse (false );
614
+
615
+ if (hasStreamingPayloadBlob ) {
616
+ writeComparableBlob .run ();
617
+ }
450
618
451
- // Do the collection for payload blobs.
619
+ // Perform parameter comparisons.
620
+ writer .openBlock ("Object.keys(paramsToValidate).forEach(param => {" , "});" , () -> {
621
+ writer .write ("expect(r[param]).toBeDefined();" );
452
622
if (hasStreamingPayloadBlob ) {
453
- writer .write ("const comparableBlob = await client.config.streamCollector(r[$S]);" ,
454
- payloadBinding .get ().getMemberName ());
623
+ writer .openBlock ("if (param === $S) {" , "} else {" , payloadBinding .get ().getMemberName (), () ->
624
+ writer .write ("expect(equivalentContents(comparableBlob, "
625
+ + "paramsToValidate[param])).toBe(true);" ));
626
+ writer .indent ();
455
627
}
456
628
457
- // Perform parameter comparisons.
458
- writer .openBlock ("Object.keys(paramsToValidate).forEach(param => {" , "});" , () -> {
459
- writer .write ("expect(r[param]).toBeDefined();" );
460
- if (hasStreamingPayloadBlob ) {
461
- writer .openBlock ("if (param === $S) {" , "} else {" , payloadBinding .get ().getMemberName (), () ->
462
- writer .write ("expect(equivalentContents(comparableBlob, "
463
- + "paramsToValidate[param])).toBe(true);" ));
464
- writer .indent ();
465
- }
629
+ writer .write ("expect(equivalentContents(r[param], paramsToValidate[param])).toBe(true);" );
466
630
467
- writer .write ("expect(equivalentContents(r[param], paramsToValidate[param])).toBe(true);" );
468
-
469
- if (hasStreamingPayloadBlob ) {
470
- writer .dedent ();
471
- writer .write ("}" );
472
- }
473
- });
474
- }
631
+ if (hasStreamingPayloadBlob ) {
632
+ writer .dedent ();
633
+ writer .write ("}" );
634
+ }
635
+ });
475
636
}
476
637
477
638
/**
0 commit comments