@@ -323,7 +323,7 @@ def canonicalize_program(obj):
323
323
update_spill_fill_size (obj )
324
324
325
325
326
- def get_decomp_table () -> Dict [torch ._ops .OperatorBase , Callable ]:
326
+ def get_decomp_table (passes_job ) -> Dict [torch ._ops .OperatorBase , Callable ]:
327
327
source_decompositions = core_aten_decompositions ()
328
328
# The below super ops are supported by QNN
329
329
skip_decompositions = [
@@ -337,10 +337,17 @@ def get_decomp_table() -> Dict[torch._ops.OperatorBase, Callable]:
337
337
torch .ops .pt2e_quant .quantize_affine .default ,
338
338
torch .ops .pt2e_quant .dequantize_affine .default ,
339
339
torch .ops .aten ._safe_softmax .default ,
340
- torch .ops .aten .stack .default , # TODO: Might need to remove this later due to Mimi. QNN does not support int io for stack op.
340
+ torch .ops .aten .stack .default ,
341
341
torch .ops .aten .unbind .int ,
342
342
]
343
343
344
+ # If we want to annotate the decomposed ops, then we should decompose the operation.
345
+ if passes_job and passes_job .get (AnnotateDecomposed , False ):
346
+ skip_decompositions = [
347
+ skip_decomp_op
348
+ for skip_decomp_op in skip_decompositions
349
+ if skip_decomp_op not in AnnotateDecomposed .decomp_ops
350
+ ]
344
351
remove_decompositions (source_decompositions , skip_decompositions )
345
352
346
353
return source_decompositions
@@ -468,7 +475,7 @@ def capture_program(
468
475
module = _preprocess_module (module , inputs )
469
476
ep = torch .export .export (module , inputs , dynamic_shapes = dynamic_shapes , strict = True )
470
477
# TODO: Handle stack op. If we want to run annotate_decomposed pass for stack op, we need to make stack op decompose, which means we need to find a method to remove it from skip_decomp table
471
- decomposed_ep = ep .run_decompositions (get_decomp_table ())
478
+ decomposed_ep = ep .run_decompositions (get_decomp_table (passes_job ))
472
479
core_ep = ExirExportedProgram (decomposed_ep , False )
473
480
core_ep .transform (TensorI64toI32 (edge_program = core_ep ))
474
481
edge_ep = core_ep .to_edge (qnn_edge_config ())
0 commit comments