Skip to content

Commit f277d44

Browse files
committed
Skip decompose op
1 parent 59df9ef commit f277d44

File tree

3 files changed

+25
-6
lines changed

3 files changed

+25
-6
lines changed

backends/qualcomm/_passes/annotate_decomposed.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ class AnnotateDecomposed(ExportPass):
1717
generated after quantization process.
1818
"""
1919

20+
decomp_ops = [torch.ops.aten.stack.default, torch.ops.aten.unbind.int]
21+
2022
def __init__(self, edge_program: torch.export.ExportedProgram):
2123
super(AnnotateDecomposed, self).__init__()
2224
self.edge_program = edge_program
@@ -32,7 +34,7 @@ def _annotate_unbind(self, graph_module: torch.fx.GraphModule):
3234
n.meta[QCOM_QUANT_ATTRS] = quant_attrs.copy()
3335

3436
def _annotate_stack(self, graph_module: torch.fx.GraphModule):
35-
partitions = get_source_partitions(graph_module.graph, [torch.stack])
37+
partitions = get_source_partitions(graph_module.graph, [torch.stack, "stack"])
3638
for _, src_partitions in partitions.items():
3739
for src_partition in src_partitions:
3840
output = src_partition.output_nodes[0]

backends/qualcomm/utils/utils.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -323,7 +323,7 @@ def canonicalize_program(obj):
323323
update_spill_fill_size(obj)
324324

325325

326-
def get_decomp_table() -> Dict[torch._ops.OperatorBase, Callable]:
326+
def get_decomp_table(passes_job) -> Dict[torch._ops.OperatorBase, Callable]:
327327
source_decompositions = core_aten_decompositions()
328328
# The below super ops are supported by QNN
329329
skip_decompositions = [
@@ -337,10 +337,17 @@ def get_decomp_table() -> Dict[torch._ops.OperatorBase, Callable]:
337337
torch.ops.pt2e_quant.quantize_affine.default,
338338
torch.ops.pt2e_quant.dequantize_affine.default,
339339
torch.ops.aten._safe_softmax.default,
340-
torch.ops.aten.stack.default, # TODO: Might need to remove this later due to Mimi. QNN does not support int io for stack op.
340+
torch.ops.aten.stack.default,
341341
torch.ops.aten.unbind.int,
342342
]
343343

344+
# If we want to annotate the decomposed ops, then we should decompose the operation.
345+
if passes_job and passes_job.get(AnnotateDecomposed, False):
346+
skip_decompositions = [
347+
skip_decomp_op
348+
for skip_decomp_op in skip_decompositions
349+
if skip_decomp_op not in AnnotateDecomposed.decomp_ops
350+
]
344351
remove_decompositions(source_decompositions, skip_decompositions)
345352

346353
return source_decompositions
@@ -468,7 +475,7 @@ def capture_program(
468475
module = _preprocess_module(module, inputs)
469476
ep = torch.export.export(module, inputs, dynamic_shapes=dynamic_shapes, strict=True)
470477
# TODO: Handle stack op. If we want to run annotate_decomposed pass for stack op, we need to make stack op decompose, which means we need to find a method to remove it from skip_decomp table
471-
decomposed_ep = ep.run_decompositions(get_decomp_table())
478+
decomposed_ep = ep.run_decompositions(get_decomp_table(passes_job))
472479
core_ep = ExirExportedProgram(decomposed_ep, False)
473480
core_ep.transform(TensorI64toI32(edge_program=core_ep))
474481
edge_ep = core_ep.to_edge(qnn_edge_config())

examples/models/llama/export_llama_lib.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -794,9 +794,19 @@ def _to_edge_and_lower_llama( # noqa: C901
794794
)
795795
)
796796
# pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.utils.utils`
797-
from executorch.backends.qualcomm.utils.utils import _transform, tag_quant_io
797+
from executorch.backends.qualcomm._passes.annotate_decomposed import (
798+
AnnotateDecomposed,
799+
)
800+
from executorch.backends.qualcomm.utils.constants import QCOM_PASS_ACTIVATE_KEY
801+
from executorch.backends.qualcomm.utils.utils import (
802+
_transform,
803+
get_capture_program_passes,
804+
tag_quant_io,
805+
)
798806

799-
_transform(builder_exported_to_edge.edge_manager.exported_program())
807+
passes_job = get_capture_program_passes()
808+
passes_job[AnnotateDecomposed][QCOM_PASS_ACTIVATE_KEY] = True
809+
_transform(builder_exported_to_edge.edge_manager.exported_program(), passes_job)
800810

801811
if args.num_sharding > 0:
802812
model_sharding.split_graph(

0 commit comments

Comments
 (0)