
Commit 49805bd

Qualcomm AI Engine Direct - xr model enablement (mld_f)
Differential Revision: D74590962
Pull Request resolved: #10546
1 parent: ebda84d

File tree

13 files changed: +423 −65 lines


backends/qualcomm/_passes/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -20,6 +20,7 @@
 from .expand_broadcast_tensor_shape import ExpandBroadcastTensorShape
 from .fixed_linear_keep_dim import FixedLinearKeepDim
 from .fold_qdq import FoldQDQ
+from .fuse_consecutive_cast import FuseConsecutiveCast
 from .fuse_consecutive_transpose import FuseConsecutiveTranspose
 from .i64_to_i32 import I64toI32
 from .insert_io_qdq import InsertIOQDQ
@@ -54,6 +55,7 @@
     ExpandBroadcastTensorShape,
     FixedLinearKeepDim,
     FoldQDQ,
+    FuseConsecutiveCast,
     FuseConsecutiveTranspose,
     I64toI32,
     InsertIOQDQ,
backends/qualcomm/_passes/fuse_consecutive_cast.py

Lines changed: 116 additions & 0 deletions
@@ -0,0 +1,116 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import torch
+
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import ExportPass, PassResult
+from executorch.exir.passes import dead_code_elimination_pass
+
+
+class FuseConsecutiveCast(ExportPass):
+    """
+    This pass fuses consecutive cast ops into one, or removes them entirely,
+    to reduce runtime overhead.
+    To simplify the fusion logic, we ensure each cast node's output has at
+    most one cast user by cloning casts.
+    Example:
+        Before cloning casts:
+            relu -> cast1 ─> cast2
+                      |────> cast3
+
+        After cloning casts:
+            relu ─> cast1 ──────> cast2
+              |───> cast4(new) ─> cast3
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.op_map = {
+            exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
+            exir_ops.edge.aten._to_copy.default,
+        }
+        self.visited = set()
+        self.nodes = []
+
+    def _canonicalize_cast(
+        self, graph_module: torch.fx.GraphModule
+    ) -> torch.fx.GraphModule:
+        # replace all i64 cast nodes with the i32 version
+        graph = graph_module.graph
+        for n in graph_module.graph.nodes:
+            if n.target in self.op_map and n.meta["val"].dtype == torch.int64:
+                users = list(n.users)
+                for user in users:
+                    # bypass the graph output node to meet the original convention
+                    if user.op == "output":
+                        continue
+
+                    with graph.inserting_after(n):
+                        cast_node = graph.create_node(
+                            "call_function",
+                            exir_ops.edge.aten._to_copy.default,
+                            n.args,
+                            kwargs={"dtype": torch.int32},
+                        )
+                        cast_node.meta = n.meta
+                        cast_node.meta["val"] = cast_node.meta["val"].to(torch.int32)
+                        user.replace_input_with(n, cast_node)
+
+        graph.eliminate_dead_code()
+
+        # clone nodes for future fusion
+        for n in graph_module.graph.nodes:
+            # make sure we're handling a cast node instead of a convert node
+            if n.target in self.op_map and n.kwargs.get("dtype", None) is not None:
+                users = [user for user in list(n.users) if user.target in self.op_map]
+                if len(users) > 1:
+                    for i in range(1, len(users)):
+                        with graph.inserting_after(n):
+                            clone_cast_node = graph.create_node(
+                                "call_function",
+                                exir_ops.edge.aten._to_copy.default,
+                                n.args,
+                                kwargs=n.kwargs,
+                            )
+                            clone_cast_node.meta = n.meta
+                            users[i].replace_input_with(n, clone_cast_node)
+
+    def _traverse(self, node):
+        if node in self.visited or node.target not in self.op_map:
+            return
+
+        self.nodes.append(node)
+        self.visited.add(node)
+        next_users = [n for n in list(node.users) if n.target in self.op_map]
+
+        assert (
+            len(next_users) <= 1
+        ), "Each cast node should have at most 1 cast output node after _clone_cast"
+        if not next_users:
+            return
+        else:
+            self._traverse(list(node.users)[0])
+
+    def _fuse(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
+        for n in graph_module.graph.nodes:
+            self._traverse(n)
+            # TODO: how to handle the following scenario (won't happen for a quantized graph)
+            # fp -> to(i32) -> to(fp)
+            if len(self.nodes) > 1:
+                input_node, output_node = self.nodes[0], self.nodes[-1]
+                output_node.replace_input_with(output_node.args[0], input_node.args[0])
+
+            # clear the current stack
+            self.nodes = []
+
+    def call(self, graph_module: torch.fx.GraphModule):
+        self._canonicalize_cast(graph_module)
+        self._fuse(graph_module)
+        graph_module.recompile()
+        dead_code_elimination_pass(graph_module)
+        return PassResult(graph_module, True)
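
For context, here is a minimal sketch of how the pass behaves end to end. The toy module and the export-to-edge flow around it are illustrative assumptions, not part of this commit:

```python
# Minimal sketch (assumed flow, not from this commit): a chain of two casts
# should collapse to a single cast node after FuseConsecutiveCast runs.
import torch
from executorch.backends.qualcomm._passes import FuseConsecutiveCast
from executorch.exir import to_edge


class CastChain(torch.nn.Module):
    def forward(self, x):
        # int64 -> int32 -> int64: a fusible chain; after the pass only a
        # single cast should remain between the input and the output
        return x.to(torch.int32).to(torch.int64)


ep = torch.export.export(CastChain(), (torch.arange(4),))
gm = to_edge(ep).exported_program().graph_module
print(gm.graph)            # two consecutive cast nodes
FuseConsecutiveCast()(gm)  # ExportPass instances are callable on a GraphModule
print(gm.graph)            # the chain is collapsed to one cast
```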

backends/qualcomm/_passes/i64_to_i32.py

Lines changed: 29 additions & 0 deletions
@@ -31,6 +31,14 @@ class I64toI32(ExportPass):
         exir_ops.edge.aten.full.default,
         exir_ops.edge.aten.scalar_tensor.default,
     }
+    # This dict ensures these ops' inputs stay int64, due to PyTorch restrictions.
+    # For example, the scatter op only accepts args[2], the index, as int64.
+    # Key: ops whose input must be cast to i64
+    # Value: the arg indices that need a casting op
+    I64_IN_OPS = {
+        exir_ops.edge.aten.gather.default: [2],
+        exir_ops.edge.aten.scatter.src: [2],
+    }
     copy_op = exir_ops.edge.aten._to_copy.default

     def __init__(
@@ -141,11 +149,32 @@ def _cast_constant_to_int32(self, graph_module: torch.fx.GraphModule):
                 n.replace_all_uses_with(to_dst_node)
                 to_dst_node.args = (n,)

+    def _cast_op_args_to_i64(self, graph_module: torch.fx.GraphModule):
+        # inputs will be cast to i32 during call_operator dtype propagation;
+        # insert an i64 cast node to prevent PyTorch's operator validation failure
+        for node in graph_module.graph.nodes:
+            if node.target in self.I64_IN_OPS:
+                with graph_module.graph.inserting_before(node):
+                    arg_indices = self.I64_IN_OPS[node.target]
+                    for arg_index in arg_indices:
+                        input_node = node.args[arg_index]
+                        cast_i64_node = graph_module.graph.create_node(
+                            "call_function",
+                            self.copy_op,
+                            (input_node,),
+                            {"dtype": torch.int64},
+                        )
+                        cast_i64_node.meta["val"] = node.meta["val"].to(torch.int64)
+                        args_list = list(node.args)
+                        args_list[arg_index] = cast_i64_node
+                        node.args = tuple(args_list)
+
     def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
         # Record the original output dtype to ensure that if the user expects
         # int64 output, it is converted back to int64 when it was cast from int64 to int32.
         self._record_original_output_dtype(graph_module)
         self._cast_constant_to_int32(graph_module)
+        self._cast_op_args_to_i64(graph_module)
         graph_module = super().call(graph_module).graph_module
         self._preserve_output_dtype(graph_module)
         graph_module.recompile()
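
The PyTorch restriction this new helper works around is easy to reproduce. A standalone snippet (not part of the diff) showing gather rejecting a non-int64 index:

```python
# Standalone illustration: PyTorch's gather validates that the index tensor
# is int64, which is why I64toI32 must cast these args back to i64.
import torch

x = torch.randn(3, 4)
idx_i32 = torch.tensor([[0, 1], [2, 0], [1, 3]], dtype=torch.int32)

try:
    torch.gather(x, 1, idx_i32)
except RuntimeError as e:
    print(e)  # complains that the index dtype must be int64

print(torch.gather(x, 1, idx_i32.to(torch.int64)))  # succeeds after the cast
```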

backends/qualcomm/_passes/qnn_pass_manager.py

Lines changed: 3 additions & 0 deletions
@@ -25,6 +25,7 @@
     ExpandBroadcastTensorShape,
     FixedLinearKeepDim,
     FoldQDQ,
+    FuseConsecutiveCast,
     FuseConsecutiveTranspose,
     I64toI32,
     InsertIOQDQ,
@@ -182,6 +183,7 @@ def transform_for_to_edge_pipeline(

     # Before quantizer
     def transform_for_annotation_pipeline(self, graph_module: GraphModule):
+        self.add_pass(RemoveRedundancy(quantization_capture=True))
         self.add_pass(ReduceDynamicRange())
         self.add_pass(RecomposePixelUnshuffle(quantization_capture=True))
         self.add_pass(ReplaceArangeArgs())
@@ -214,5 +216,6 @@ def transform_for_preprocess_pipeline(self, exported_program: ExportedProgram):
         self.add_pass(InsertRequantize())
         self.add_pass(InsertIOQDQ(exported_program))
         self.add_pass(LayoutTransform(exported_program, insert_permute=True))
+        self.add_pass(FuseConsecutiveCast())
         self.add_pass(FuseConsecutiveTranspose())
         return self._transform(exported_program.graph_module)
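
A sketch of where the annotation change lands in practice. The toy module and the capture flow are illustrative assumptions, and the sketch assumes the annotation pipeline returns the transformed module, as the preprocess pipeline above does:

```python
# Sketch (assumed capture flow): with RemoveRedundancy(quantization_capture=True)
# at the head of the annotation pipeline, _assert_tensor_metadata guard nodes
# never reach the quantizer.
import torch
from executorch.backends.qualcomm._passes.qnn_pass_manager import QnnPassManager


class Toy(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x).to(torch.float16)


gm = torch.export.export(Toy(), (torch.randn(2, 2),)).module()
gm = QnnPassManager().transform_for_annotation_pipeline(gm)
assert not any(
    n.target is torch.ops.aten._assert_tensor_metadata.default
    for n in gm.graph.nodes
)
```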

backends/qualcomm/_passes/remove_redundancy.py

Lines changed: 15 additions & 2 deletions
@@ -14,9 +14,9 @@ class RemoveRedundancy(ExportPass):
     Trim certain operators to reduce unnecessary overhead.
     """

-    def __init__(self):
+    def __init__(self, quantization_capture=False):
         super(RemoveRedundancy, self).__init__()
-        self.redundant_ops = {
+        self.redundant_ops_general = {
             torch.clone: self._default_condition,
             torch.ops.aten.clone.default: self._default_condition,
             exir_ops.edge.aten.clone.default: self._default_condition,
@@ -28,7 +28,16 @@ def __init__(self):
             exir_ops.edge.dim_order_ops._to_dim_order_copy.default: self._dim_order_op_condition,
             # remove channel_last / contiguous _to_copy if '_skip_dim_order' is set to True
             exir_ops.edge.aten._to_copy.default: self._to_copy_op_condition,
+            torch.ops.aten._assert_tensor_metadata.default: self._default_condition,
         }
+        self.redundant_ops_annotation = {
+            torch.ops.aten._assert_tensor_metadata.default: self._default_condition,
+        }
+        self.redundant_ops = (
+            self.redundant_ops_annotation
+            if quantization_capture
+            else self.redundant_ops_general
+        )

     def _dim_order_op_condition(self, node):
         dim_order = node.kwargs.get("dim_order")
@@ -50,6 +59,10 @@ def _remove(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
                 continue

             to_be_remove = n
+            # the assert_tensor_metadata op has no users
+            if len(n.users.keys()) == 0:
+                n.args = ()
+            # normal case
             for user_n in list(n.users.keys()):
                 user_n.replace_input_with(n, n.args[0])
             graph_module.graph.erase_node(to_be_remove)
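
Why these guard nodes show up at all: recent torch.export versions can insert _assert_tensor_metadata checks around dtype casts. A standalone illustration (the behavior is version-dependent, so treat this as an assumption rather than a guarantee):

```python
# Standalone illustration: exporting a dtype cast can introduce
# _assert_tensor_metadata guard nodes (observed on recent PyTorch builds;
# exact behavior varies by version). These carry no value for QNN lowering,
# so the annotation-mode RemoveRedundancy strips them.
import torch


class Cast(torch.nn.Module):
    def forward(self, x):
        return x.to(torch.float16)


ep = torch.export.export(Cast(), (torch.randn(4),))
print([
    n for n in ep.graph.nodes
    if n.target is torch.ops.aten._assert_tensor_metadata.default
])
```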

backends/qualcomm/builders/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -32,6 +32,7 @@
     op_expand,
     op_full,
     op_full_like,
+    op_gather,
     op_ge,
     op_gelu,
     op_group_norm,
@@ -120,6 +121,7 @@
     op_expand,
     op_full,
     op_full_like,
+    op_gather,
     op_ge,
     op_gelu,
     op_group_norm,

backends/qualcomm/builders/op_argmin.py

Lines changed: 10 additions & 43 deletions
@@ -10,8 +10,8 @@
 import torch
 from executorch.backends.qualcomm.utils.constants import QCOM_AXIS_ORDER, QCOM_DATA

-from .node_visitor import NodeVisitor, QNN_TENSOR_TYPE_MAP, register_node_visitor
-from .qnn_constants import OpArgmin, OpCast, QNN_OP_PACKAGE_NAME_QTI_AISW
+from .node_visitor import NodeVisitor, register_node_visitor
+from .qnn_constants import OpArgmin, QNN_OP_PACKAGE_NAME_QTI_AISW


 @register_node_visitor
@@ -26,7 +26,6 @@ def define_node(
         node: torch.fx.Node,
         nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
     ) -> PyQnnWrapper.PyQnnOpWrapper:
-        op_wrapper_list = []
         input_node = self.get_node(node.args[0])
         input_tensor = self.get_tensor(input_node, node)
         output_tensor = self.get_tensor(node, node)
@@ -38,26 +37,14 @@ def define_node(
             nodes_to_wrappers,
         )
         argmin_input_tensors = [argmin_inp_tensor_wrapper]
-
-        # arg output is index, do not quantize it.
-        node.meta.pop("quant_attrs", None)
-        input_quant_encoding, input_quant_configs = self.get_quant_encoding_conf(
-            input_node, node
-        )
-
-        argmin_intermediate_tensor_wrapper = self.define_custom_tensor_wrapper(
-            node_name=node.name + "_cast",
-            tensor_type=PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
-            dtype=QNN_TENSOR_TYPE_MAP[torch.int32],
-            quant_encoding=input_quant_encoding,
-            quant_configs=input_quant_configs,
-            dims=output_tensor.size(),
-            tensor=output_tensor,
-            is_fake_tensor=True,
-            nodes_to_wrappers=nodes_to_wrappers,
+        argmin_out_tensor_wrapper = self.define_tensor(
+            node,
+            node,
+            output_tensor.to(torch.int32),
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
         )
-
-        argmin_output_tensors = [argmin_intermediate_tensor_wrapper]
+        argmin_output_tensors = [argmin_out_tensor_wrapper]

         dim = cast(int, node.args[1])
         if dim < 0:
@@ -87,24 +74,4 @@ def define_node(
             {QCOM_DATA: keep_dims},
         )

-        op_wrapper_list.append(argmin_op)
-
-        cast_op = PyQnnWrapper.PyQnnOpWrapper(
-            node.name + "_cast",
-            QNN_OP_PACKAGE_NAME_QTI_AISW,
-            OpCast.op_name,
-        )
-
-        output_tensor_wrapper = self.define_tensor(
-            node,
-            node,
-            output_tensor,
-            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
-            nodes_to_wrappers,
-        )
-
-        cast_op.AddInputTensors([argmin_intermediate_tensor_wrapper])
-        cast_op.AddOutputTensors([output_tensor_wrapper])
-        op_wrapper_list.append(cast_op)
-
-        return op_wrapper_list
+        return argmin_op
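
The dropped OpCast reflects a dtype mismatch that no longer needs patching inside the builder: torch.argmin produces int64 indices, while this builder registers the QNN Argmin output as int32, leaving the surrounding i64/i32 passes to reconcile the difference. A quick check of the PyTorch side:

```python
# torch.argmin always returns int64 indices; the builder now defines the
# output tensor as int32 directly instead of appending a QNN Cast op.
import torch

x = torch.randn(4, 8)
print(torch.argmin(x, dim=1, keepdim=True).dtype)  # torch.int64
```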
