17
17
18
18
import java .nio .file .Paths ;
19
19
import java .util .List ;
20
+ import java .util .Map ;
20
21
import java .util .Optional ;
21
22
import java .util .Set ;
23
+ import java .util .TreeMap ;
22
24
import java .util .TreeSet ;
23
25
import java .util .function .BiFunction ;
24
26
import java .util .function .Consumer ;
27
+ import java .util .function .Function ;
25
28
import java .util .logging .Logger ;
29
+ import java .util .stream .Collectors ;
26
30
import software .amazon .smithy .codegen .core .CodegenException ;
27
31
import software .amazon .smithy .codegen .core .Symbol ;
28
32
import software .amazon .smithy .codegen .core .SymbolProvider ;
29
33
import software .amazon .smithy .codegen .core .SymbolReference ;
30
34
import software .amazon .smithy .model .knowledge .HttpBinding .Location ;
31
- import software .amazon .smithy .model .knowledge .OperationIndex ;
32
35
import software .amazon .smithy .model .pattern .SmithyPattern ;
33
36
import software .amazon .smithy .model .shapes .MemberShape ;
34
37
import software .amazon .smithy .model .shapes .OperationShape ;
@@ -313,6 +316,7 @@ public static void writeRetryableTrait(TypeScriptWriter writer, StructureShape e
313
316
* @param errorCodeGenerator A consumer
314
317
* @param shouldParseErrorBody Flag indicating whether need to parse response body in this dispatcher function
315
318
* @param bodyErrorLocationModifier A function that returns the location of an error in a body given a data source.
319
+ * @param operationErrorsToShapes A map of error names to their {@link ShapeId}.
316
320
* @return A set of all error structure shapes for the operation that were dispatched to.
317
321
*/
318
322
static Set <StructureShape > generateErrorDispatcher (
@@ -321,11 +325,11 @@ static Set<StructureShape> generateErrorDispatcher(
321
325
SymbolReference responseType ,
322
326
Consumer <GenerationContext > errorCodeGenerator ,
323
327
boolean shouldParseErrorBody ,
324
- BiFunction <GenerationContext , String , String > bodyErrorLocationModifier
328
+ BiFunction <GenerationContext , String , String > bodyErrorLocationModifier ,
329
+ BiFunction <GenerationContext , OperationShape , Map <String , ShapeId >> operationErrorsToShapes
325
330
) {
326
331
TypeScriptWriter writer = context .getWriter ();
327
332
SymbolProvider symbolProvider = context .getSymbolProvider ();
328
- OperationIndex operationIndex = OperationIndex .of (context .getModel ());
329
333
Set <StructureShape > errorShapes = new TreeSet <>();
330
334
331
335
Symbol symbol = symbolProvider .toSymbol (operation );
@@ -350,10 +354,6 @@ static Set<StructureShape> generateErrorDispatcher(
350
354
SymbolReference baseExceptionReference = getClientBaseException (context );
351
355
errorCodeGenerator .accept (context );
352
356
353
- TreeSet <StructureShape > structureShapes = new TreeSet <>(
354
- operationIndex .getErrors (operation , context .getService ())
355
- );
356
-
357
357
Runnable defaultErrorHandler = () -> {
358
358
if (shouldParseErrorBody ) {
359
359
// Body is already parsed above
@@ -379,20 +379,21 @@ static Set<StructureShape> generateErrorDispatcher(
379
379
});
380
380
};
381
381
382
- if (!structureShapes .isEmpty ()) {
382
+ Map <String , ShapeId > operationNamesToShapes = operationErrorsToShapes .apply (context , operation );
383
+ if (!operationNamesToShapes .isEmpty ()) {
383
384
writer .openBlock ("switch (errorCode) {" , "}" , () -> {
384
385
// Generate the case statement for each error, invoking the specific deserializer.
385
386
386
- structureShapes .forEach (error -> {
387
- final ShapeId errorId = error . getId ();
387
+ operationNamesToShapes .forEach (( name , errorId ) -> {
388
+ StructureShape error = context . getModel (). expectShape ( errorId ). asStructureShape (). get ();
388
389
// Track errors bound to the operation so their deserializers may be generated.
389
390
errorShapes .add (error );
390
391
Symbol errorSymbol = symbolProvider .toSymbol (error );
391
392
String errorDeserMethodName = ProtocolGenerator .getDeserFunctionName (errorSymbol ,
392
393
context .getProtocolName ()) + "Response" ;
393
394
// Dispatch to the error deserialization function.
394
395
String outputParam = shouldParseErrorBody ? "parsedOutput" : "output" ;
395
- writer .write ("case $S:" , errorId . getName () );
396
+ writer .write ("case $S:" , name );
396
397
writer .write ("case $S:" , errorId .toString ());
397
398
writer .indent ()
398
399
.write ("throw await $L($L, context);" , errorDeserMethodName , outputParam )
@@ -468,4 +469,24 @@ private static SymbolReference getClientBaseException(GenerationContext context)
468
469
.symbol (serviceExceptionSymbol )
469
470
.build ();
470
471
}
472
+
473
+ /**
474
+ * Returns a map of error names to their {@link ShapeId}.
475
+ *
476
+ * @param context the generation context
477
+ * @param operation the operation shape to retrieve errors for
478
+ * @return map of error names to {@link ShapeId}
479
+ */
480
+ public static Map <String , ShapeId > getOperationErrors (GenerationContext context , OperationShape operation ) {
481
+ return operation .getErrors ().stream ()
482
+ .collect (Collectors .toMap (
483
+ shapeId -> shapeId .getName (context .getService ()),
484
+ Function .identity (),
485
+ (x , y ) -> {
486
+ if (!x .equals (y )) {
487
+ throw new CodegenException (String .format ("conflicting error shape ids: %s, %s" , x , y ));
488
+ }
489
+ return x ;
490
+ }, TreeMap ::new ));
491
+ }
471
492
}
0 commit comments