Skip to content

Commit 3377f9d

Browse files
committed
Qualcomm AI Engine Direct - Mimi Enablement Stage 2
1 parent 6adff9c commit 3377f9d

18 files changed

+701
-45
lines changed

backends/qualcomm/_passes/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from .convert_bmm_to_matmul import ConvertBmmToMatmul
1111
from .convert_conv1d_to_conv2d import ConvertConv1dToConv2d
1212
from .decompose_any import DecomposeAny
13+
from .decompose_cdist import DecomposeCDist
1314
from .decompose_einsum import DecomposeEinsum
1415
from .decompose_expm1 import DecomposeExpM1
1516
from .decompose_linalg_vector_norm import DecomposeLinalgVectorNorm
@@ -26,6 +27,7 @@
2627
from .recompose_pixel_unshuffle import RecomposePixelUnshuffle
2728
from .recompose_rms_norm import RecomposeRmsNorm
2829
from .reduce_dynamic_range import ReduceDynamicRange
30+
from .remove_0d_tensor import Remove0DTensor
2931
from .remove_redundancy import RemoveRedundancy
3032
from .replace_arange_args import ReplaceArangeArgs
3133
from .replace_index_put_input import ReplaceIndexPutInput
@@ -40,6 +42,7 @@
4042
ConvertBmmToMatmul,
4143
ConvertConv1dToConv2d,
4244
DecomposeAny,
45+
DecomposeCDist,
4346
DecomposeEinsum,
4447
DecomposeExpM1,
4548
DecomposeLinalgVectorNorm,
@@ -56,6 +59,7 @@
5659
RecomposePixelUnshuffle,
5760
RecomposeRmsNorm,
5861
ReduceDynamicRange,
62+
Remove0DTensor,
5963
RemoveRedundancy,
6064
ReplaceArangeArgs,
6165
ReplaceIndexPutInput,

backends/qualcomm/_passes/annotate_stack.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,16 @@ class AnnotateStack(ExportPass):
1717
generated after quantization process.
1818
"""
1919

20-
decomp_ops = [torch.ops.aten.unbind.int]
20+
decomp_ops = [torch.ops.aten.stack.default]
2121

2222
def __init__(self, edge_program: torch.export.ExportedProgram):
2323
super(AnnotateStack, self).__init__()
2424
self.edge_program = edge_program
2525

2626
def _annotate_stack(self, graph_module: torch.fx.GraphModule):
27-
partitions = get_source_partitions(graph_module.graph, [torch.stack, "stack"])
27+
partitions = get_source_partitions(
28+
graph_module.graph, [torch.stack, torch.ops.aten.stack.default, "stack"]
29+
)
2830
for _, src_partitions in partitions.items():
2931
for src_partition in src_partitions:
3032
output = src_partition.output_nodes[0]

backends/qualcomm/_passes/annotate_unbind.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,9 @@ def __init__(self, edge_program: torch.export.ExportedProgram):
2424
self.edge_program = edge_program
2525

2626
def _annotate_unbind(self, graph_module: torch.fx.GraphModule):
27-
partitions = get_source_partitions(graph_module.graph, [torch.unbind, "unbind"])
27+
partitions = get_source_partitions(
28+
graph_module.graph, [torch.unbind, torch.ops.aten.unbind.int, "unbind"]
29+
)
2830
for _, src_partitions in partitions.items():
2931
for src_partition in src_partitions:
3032
if src_partition.input_nodes[0].target in dq_ops:

backends/qualcomm/_passes/convert_conv1d_to_conv2d.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import torch
88
import torch.nn as nn
99
from executorch.backends.qualcomm.builders.utils import get_parameter, set_parameter
10+
from executorch.backends.qualcomm.utils.constants import QCOM_REQUANTIZE
1011
from executorch.exir.dialects._ops import ops as exir_ops
1112
from executorch.exir.pass_base import ExportPass, PassResult
1213

@@ -43,6 +44,7 @@ def call(self, graph_module: torch.fx.GraphModule):
4344
unsqueeze_node.meta = copy_meta(
4445
input_node.meta, lambda m: {**m, "val": m["val"].unsqueeze(2)}
4546
)
47+
4648
with graph_module.graph.inserting_after(unsqueeze_node):
4749

4850
filter_node = node.args[1]
@@ -92,6 +94,17 @@ def call(self, graph_module: torch.fx.GraphModule):
9294
),
9395
)
9496
squeeze_node.meta = copy_meta(node.meta)
97+
98+
if QCOM_REQUANTIZE in input_node.meta:
99+
input_node.meta.pop(QCOM_REQUANTIZE)
100+
unsqueeze_node.meta[QCOM_REQUANTIZE][node.name] = (
101+
unsqueeze_node.meta[QCOM_REQUANTIZE].pop(node.name)
102+
)
103+
if QCOM_REQUANTIZE in node.meta:
104+
squeeze_node.meta[QCOM_REQUANTIZE] = node.meta[
105+
QCOM_REQUANTIZE
106+
]
107+
conv2d_node.meta.pop(QCOM_REQUANTIZE, None)
95108
for user in node.users.copy():
96109
user.replace_input_with(node, squeeze_node)
97110
graph.eliminate_dead_code()
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
# Copyright (c) Qualcomm Innovation Center, Inc.
2+
# All rights reserved
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import torch
8+
from executorch.exir.pass_base import ExportPass, PassResult
9+
10+
11+
class CDist(torch.nn.Module):
12+
def __init__(self):
13+
super().__init__()
14+
15+
def forward(self, x, y):
16+
# Step 1: Compute differences
17+
diff = x.unsqueeze(2) - y.unsqueeze(1)
18+
19+
# Step 2: Square differences
20+
sq_diff = diff**2
21+
22+
# Step 3: Sum of squares
23+
sum_sq_diff = sq_diff.sum(dim=-1)
24+
25+
# Step 4: Square root
26+
distances = torch.sqrt(sum_sq_diff)
27+
28+
return distances
29+
30+
31+
class DecomposeCDist(ExportPass):
32+
"""
33+
Decompose for math equivalent op.
34+
"""
35+
36+
def __init__(self) -> None:
37+
super().__init__()
38+
39+
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
40+
graph = graph_module.graph
41+
for node in graph.nodes:
42+
model = CDist()
43+
if torch.ops.aten.cdist.default == node.target:
44+
decomposed_module = torch.export.export(
45+
model,
46+
(node.args[0].meta["val"], node.args[1].meta["val"]),
47+
strict=True,
48+
).module()
49+
with graph.inserting_before(node):
50+
# remap is used to map original node values to new node values,
51+
# which ensures that reference to nodes are correctly updated in the new graph
52+
remap = {"x": node.args[0], "y": node.args[1]}
53+
54+
for decomposed_node in decomposed_module.graph.nodes:
55+
# no need to copy existent 'output'
56+
if decomposed_node.op == "output":
57+
for user in node.users.copy():
58+
# remap
59+
user.replace_input_with(
60+
node,
61+
remap[decomposed_node.args[0][0]],
62+
)
63+
# no need to copy existent placeholders
64+
elif decomposed_node.op == "placeholder":
65+
# replace node map from string to graph node
66+
remap[decomposed_node] = remap.pop(decomposed_node.name)
67+
else:
68+
remap[decomposed_node] = graph.node_copy(
69+
decomposed_node,
70+
arg_transform=lambda x, remap=remap: remap[x],
71+
)
72+
73+
graph.erase_node(node)
74+
75+
graph.eliminate_dead_code()
76+
graph_module.recompile()
77+
return PassResult(graph_module, True)

backends/qualcomm/_passes/lift_constant_scalar_operands.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,13 @@ class TensorOpInfo:
5353
}
5454

5555

56-
SKIP_LIFT_OPS = {aten.full_like.default, aten.arange.start_step}
56+
SKIP_LIFT_OPS = {
57+
aten.full_like.default,
58+
aten.arange.start_step,
59+
aten.arange.default,
60+
aten.scalar_tensor.default,
61+
aten.elu.default,
62+
}
5763

5864

5965
class LiftConstantScalarOperands(ExportPass):

backends/qualcomm/_passes/qnn_pass_manager.py

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
ConvertBmmToMatmul,
1616
ConvertConv1dToConv2d,
1717
DecomposeAny,
18+
DecomposeCDist,
1819
DecomposeEinsum,
1920
DecomposeExpM1,
2021
DecomposeLinalgVectorNorm,
@@ -31,6 +32,7 @@
3132
RecomposePixelUnshuffle,
3233
RecomposeRmsNorm,
3334
ReduceDynamicRange,
35+
Remove0DTensor,
3436
RemoveRedundancy,
3537
ReplaceArangeArgs,
3638
ReplaceIndexPutInput,
@@ -70,7 +72,7 @@ def get_capture_program_passes():
7072
# If a pass is activated, it will be executed by default.
7173
default_passes_and_setting = [
7274
(AnnotateQuantAttrs, True),
73-
(AnnotateStack, False),
75+
(AnnotateStack, True),
7476
(AnnotateUnbind, True),
7577
(ConvertBmmToMatmul, True),
7678
(ConvertConv1dToConv2d, True),
@@ -82,6 +84,7 @@ def get_capture_program_passes():
8284
(LayoutTransform, True),
8385
(RecomposePixelUnshuffle, True),
8486
(RecomposeRmsNorm, False),
87+
(Remove0DTensor, True),
8588
(RemoveRedundancy, True),
8689
(ReplaceIndexPutInput, True),
8790
(TagQuantIO, False),
@@ -174,7 +177,23 @@ def transform_for_to_edge_pipeline(
174177

175178
return exported_program
176179

180+
# Before quantizer
181+
def transform_for_annotation_pipeline(self, graph_module: GraphModule):
182+
self.add_pass(ReduceDynamicRange())
183+
self.add_pass(RecomposePixelUnshuffle(quantization_capture=True))
184+
self.add_pass(ReplaceArangeArgs())
185+
self.add_pass(DecomposeCDist())
186+
self.add_pass(DecomposeScaledDotProductAttention())
187+
self.add_pass(DecomposeSilu())
188+
self.add_pass(DecomposeEinsum())
189+
self.add_pass(DecomposeExpM1())
190+
self.add_pass(DecomposeLinalgVectorNorm(quantization_capture=True))
191+
self.add_pass(ReplaceInfValues())
192+
self.add_pass(LiftConstantScalarOperands())
193+
return self._transform(graph_module)
194+
177195
def transform_for_export_pipeline(self, exported_program: ExportedProgram):
196+
self.add_pass(DecomposeCDist())
178197
self.add_pass(DecomposeScaledDotProductAttention())
179198
self.add_pass(DecomposeLinalgVectorNorm(quantization_capture=True))
180199
self.add_pass(DecomposeExpM1())
@@ -189,16 +208,3 @@ def transform_for_preprocess_pipeline(self, exported_program: ExportedProgram):
189208
self.add_pass(LayoutTransform(exported_program, insert_permute=True))
190209
self.add_pass(FuseConsecutiveTranspose())
191210
return self._transform(exported_program.graph_module)
192-
193-
def transform_for_annotation_pipeline(self, graph_module: GraphModule):
194-
self.add_pass(ReduceDynamicRange())
195-
self.add_pass(RecomposePixelUnshuffle(quantization_capture=True))
196-
self.add_pass(ReplaceArangeArgs())
197-
self.add_pass(DecomposeScaledDotProductAttention())
198-
self.add_pass(DecomposeSilu())
199-
self.add_pass(DecomposeEinsum())
200-
self.add_pass(DecomposeExpM1())
201-
self.add_pass(DecomposeLinalgVectorNorm(quantization_capture=True))
202-
self.add_pass(ReplaceInfValues())
203-
self.add_pass(LiftConstantScalarOperands())
204-
return self._transform(graph_module)
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
# Copyright (c) Qualcomm Innovation Center, Inc.
2+
# All rights reserved
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import torch
8+
from executorch.exir.dialects._ops import ops as exir_ops
9+
from executorch.exir.pass_base import ExportPass, PassResult
10+
11+
12+
class Remove0DTensor(ExportPass):
13+
"""
14+
QNN does not allow 0D tensor, we remove the node that will output an 0D tensor.
15+
Before adding operations to the list of nodes to be removed, please ensure that it will not change the logic.
16+
"""
17+
18+
remove_ops = {
19+
exir_ops.edge.aten.select.int,
20+
exir_ops.edge.aten.select_copy.int,
21+
}
22+
23+
def __init__(self, quantization_capture=False) -> None:
24+
super().__init__()
25+
26+
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
27+
graph = graph_module.graph
28+
for node in graph.nodes:
29+
if node.target in self.remove_ops and len(node.meta["val"].shape) == 0:
30+
for user_n in list(node.users.keys()):
31+
user_n.replace_input_with(node, node.args[0])
32+
graph.erase_node(node)
33+
34+
graph.eliminate_dead_code()
35+
graph_module.recompile()
36+
return PassResult(graph_module, True)

backends/qualcomm/builders/op_stack.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def define_node(
5151

5252
dim = 0 if len(node.args) == 1 else cast(int, node.args[1])
5353
if dim < 0:
54-
dim = dim % len(input_tensor.shape)
54+
dim = dim % len(output_tensor.shape)
5555
if QCOM_AXIS_ORDER in node.meta:
5656
dim = node.meta[QCOM_AXIS_ORDER].index(dim)
5757
stack_op = PyQnnWrapper.PyQnnOpWrapper(

backends/qualcomm/partition/qnn_partitioner.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
not_supported_operator,
3535
to_be_implemented_operator,
3636
)
37-
from .utils import generate_qnn_executorch_option, get_skip_decomp_table
37+
from .utils import filter_fn, generate_qnn_executorch_option, get_skip_decomp_table
3838

3939

4040
class QnnOperatorSupport(OperatorSupportBase):
@@ -181,5 +181,4 @@ def ops_to_not_decompose(
181181
self, ep: ExportedProgram
182182
) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]:
183183
do_not_decompose = get_skip_decomp_table()
184-
185-
return do_not_decompose, None
184+
return (do_not_decompose, filter_fn)

backends/qualcomm/partition/utils.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,21 @@ def generate_qnn_executorch_option(
2424
return qnn_compile_spec_buffer
2525

2626

27+
# Logic to determine whether to skip decompose and has higher priority than get_skip_decomp_table()
28+
def filter_fn(node: torch.fx.Node) -> bool:
29+
# QNN does not support int32/int64 IO for the following OPs.
30+
potential_i32_i64_io_ops = [
31+
torch.ops.aten.stack.default,
32+
torch.ops.aten.unbind.int,
33+
]
34+
if node.target in potential_i32_i64_io_ops and node.meta["val"].dtype in [
35+
torch.int32,
36+
torch.int64,
37+
]:
38+
return False
39+
return True
40+
41+
2742
def get_skip_decomp_table() -> List[torch._ops.OperatorBase]:
2843
do_not_decompose = [
2944
torch.ops.aten.adaptive_avg_pool2d.default,
@@ -40,7 +55,7 @@ def get_skip_decomp_table() -> List[torch._ops.OperatorBase]:
4055
torch.ops.aten._safe_softmax.default,
4156
torch.ops.aten.stack.default,
4257
# This request is ignored because it is in a blocklist. Refer to exir/program/_program.py
43-
# torch.ops.aten.unbind.int,
58+
torch.ops.aten.unbind.int,
4459
torch.ops.pt2e_quant.quantize_affine.default,
4560
torch.ops.pt2e_quant.dequantize_affine.default,
4661
]

0 commit comments

Comments
 (0)