Commit bf9cd34

Qualcomm AI Engine Direct - Ensure that math invariant ops don't change scale and offset (#11989)
Summary:
- After QNN 2.35, op validation checks that math-invariant ops do not change scale and offset. Replace annotate_single_in_single_out with annotate_single_in_share_out so the output shares the input's quantization encoding instead of changing it.
- Fix the error for internal CI.
- Fix the bug for the batch norm op.

cc: @haowhsu-quic, @cccclai, @winskuo-quic
1 parent 292c7b4 commit bf9cd34
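
The quantizer-side change is not shown in the diffs below, but the idea behind a share-out style annotation can be illustrated with a minimal, self-contained sketch (hypothetical helper and dict layout, not the ExecuTorch source): for a math-invariant op such as reshape or transpose, the output must reuse the input's (scale, offset) rather than receiving an independently observed encoding.

    def share_quant_attrs(input_meta: dict, output_meta: dict) -> None:
        # Reuse the input encoding verbatim; after QNN 2.35, op validation rejects
        # math-invariant ops whose output scale/offset differ from the input's.
        output_meta["quant_attrs"] = dict(input_meta["quant_attrs"])

    inp = {"quant_attrs": {"scale": 0.02, "offset": -128}}
    out = {}
    share_quant_attrs(inp, out)
    assert out["quant_attrs"] == inp["quant_attrs"]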

File tree (5 files changed: +246, -40 lines)
backends/qualcomm/_passes/remove_redundancy.py

Lines changed: 1 addition & 0 deletions
@@ -29,6 +29,7 @@ def __init__(self, quantization_capture=False):
             # remove channel_last / contiguous _to_copy if '_skip_dim_order' is set to True
             exir_ops.edge.aten._to_copy.default: self._to_copy_op_condition,
             torch.ops.aten._assert_tensor_metadata.default: self._default_condition,
+            torch.ops.aten._assert_scalar.default: self._default_condition,
         }
         self.redundant_ops_annotation = {
             torch.ops.aten._assert_tensor_metadata.default: self._default_condition,

backends/qualcomm/builders/op_batch_norm.py

Lines changed: 1 addition & 1 deletion
@@ -128,7 +128,7 @@ def define_node(
         bias_tensor = self.try_dequantize(
             bias_node, get_parameter(bias_node, self.edge_program)
         )
-        amount = (filter_tensor * mean_tensor) / torch.sqrt(var_tensor + eps)
+        amount = filter_tensor * mean_tensor
         bias_tensor = bias_tensor - amount
         self.update_encoding(bias_node, bias_tensor, eps)
         bias_tensor_wrapper = self.define_tensor(
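
For context, the standard batch-norm folding identity is y = gamma * (x - mean) / sqrt(var + eps) + beta = gamma' * x + (beta - gamma' * mean), with gamma' = gamma / sqrt(var + eps). The one-line fix suggests that filter_tensor already holds gamma' at this point in the builder, so dividing by sqrt(var + eps) again would double-count the variance term. A minimal numeric check of the folding under that assumption (hypothetical values, not the ExecuTorch source):

    import torch

    x = torch.randn(4)
    gamma, beta = torch.tensor(1.5), torch.tensor(0.3)
    mean, var, eps = torch.tensor(0.2), torch.tensor(0.9), 1e-5

    # Reference batch norm.
    y_ref = gamma * (x - mean) / torch.sqrt(var + eps) + beta

    # Folded form: the filter already includes 1/sqrt(var + eps),
    # so the bias correction is just filter * mean.
    filter_tensor = gamma / torch.sqrt(var + eps)
    amount = filter_tensor * mean
    bias_tensor = beta - amount
    y_folded = filter_tensor * x + bias_tensor

    assert torch.allclose(y_ref, y_folded)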
Lines changed: 205 additions & 23 deletions
@@ -1,13 +1,22 @@
+import warnings
 from typing import Dict
 
 import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
+import numpy as np
 import torch
 
-from executorch.backends.qualcomm.utils.constants import QCOM_QUANT_ATTRS
+from executorch.backends.qualcomm.utils.constants import QCOM_DATA, QCOM_QUANT_ATTRS
+from executorch.exir.dialects._ops import ops as exir_ops
 
-from .node_visitor import NodeVisitor
+from .node_visitor import NodeVisitor, QNN_TENSOR_TYPE_MAP
 from .node_visitor_manager import register_node_visitor
-from .qnn_constants import OpScatterNd, QNN_OP_PACKAGE_NAME_QTI_AISW
+from .qnn_constants import (
+    OpConcat,
+    OpReshape,
+    OpScatterNd,
+    OpTile,
+    QNN_OP_PACKAGE_NAME_QTI_AISW,
+)
 
 
 @register_node_visitor
@@ -22,6 +31,7 @@ def define_node(
         node: torch.fx.Node,
         nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
     ) -> PyQnnWrapper.PyQnnOpWrapper:
+        op_wrapper_list = []
         input_node = self.get_node(node.args[0])
         # Because the args[0] of index_put op doesn't annotate, need to fill in the quant_attr with the node here.
         if quant_attrs := node.meta.get(QCOM_QUANT_ATTRS):
@@ -35,38 +45,206 @@ def define_node(
             PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
             nodes_to_wrappers,
         )
-        indicies_node = node.args[1]
-        indices_list = [
-            self.get_tensor(idx, idx) for idx in indicies_node if idx is not None
-        ]
-
-        # Unpack the tuple
-        indices_unpacked = [torch.flatten(idx) for idx in indices_list]
-
-        # Convert to 2-D tensor
-        indices_qnn = torch.cat(indices_unpacked).unsqueeze(0)
-        indice_node = [n for n in indicies_node if isinstance(n, torch.fx.Node)]
-        # TODO consider to write a pass to combine to one input tensor for indices
-        assert len(indice_node) == 1, "Not support multiple indices tensor"
 
+        indicies_node = node.args[1]
+        index_node_dim = None
+        index_nodes = []
+        index_tensors = []
+        target_index = []
+        # If there is None in a list, it means all range at that dimension
+        # E.g., indicies_node: [None, None, aten__to_copy_default_1]
+        if isinstance(indicies_node, list):
+            for index, idx_node in enumerate(indicies_node):
+                # First, collect the indice_node and index of None to construct the shape of index node
+                # E.g., shape of input: [1, 1024, 12, 64]
+                # For "None" axis (assume indicies_node: [None, None, aten__to_copy_default_1]),
+                # target_index: [1, 1024, x], x is the shape of index_tensor, index_node_dim: 2
+                if isinstance(idx_node, torch.fx.Node):
+                    index_nodes.append(idx_node)
+                    index_tensors.append(self.get_tensor(idx_node, idx_node))
+                    target_index.extend(index_tensors[-1].size())
+                    index_node_dim = index
+                elif idx_node is None and index_node_dim is None:
+                    # E.g., indicies_node: [None, aten__to_copy_default_1, None]
+                    # Don't need to consider "None" after index_node.
+                    target_index.append(input_tensor.size(index))
+                else:
+                    warnings.warn(
+                        f"[QNN Delegate Op Builder]: Get the index {idx_node} that is neither a node nor None",
+                        stacklevel=1,
+                    )
+                    return
+        # Assume that there is only one node in list
+        assert len(index_nodes) == 1, "Not support multiple indices tensor"
+        indice_node = index_nodes[0]
+        indice_tensor = index_tensors[0]
         indices_tensor_wrapper = self.define_tensor(
-            indice_node[0],
+            indice_node,
             node,
-            indices_qnn,
+            indice_tensor,
             PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
             nodes_to_wrappers,
         )
-        value_node = self.get_node(node.args[2])
 
-        value_tensor = self.get_tensor(value_node, node)
+        # Need to reconstruct the index tensor.
+        # E.g., based on ScatterND Op Def in QNN Docs.
+        # Given that
+        # shape of input: [1, 12, 1024, 64]
+        # indicies_node: [None, None, aten__to_copy_default_1]
+        # shape of aten__to_copy_default_1: [1]
+        # The shape of index tensor should be [1, 12, 1, 3]
+        # The index tensor is treated as 4-dimensional tensor of 3-tuples,
+        # where each 3-tuple is a partial-index into input
+        # Reference code for QNN ScatterNd:
+        # output = np.copy(input)
+        # update_indices = indices.shape[:-1]
+        # for idx in np.ndindex(update_indices):
+        #     output[indices[idx]] = updates[idx]
+
+        # Append one dimension to specify x-tuple
+        index_shape = target_index + [1]
+        # Reshape the index_node for tile op
+        reshape_shape = [
+            shape if id == index_node_dim else 1 for id, shape in enumerate(index_shape)
+        ]
+        reshape_output_tensor = indice_tensor.reshape(reshape_shape)
+        reshape_output_tensor_wrapper = self.define_custom_tensor_wrapper(
+            node_name=node.name + "_reshape",
+            tensor_type=PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            dtype=QNN_TENSOR_TYPE_MAP[reshape_output_tensor.dtype],
+            quant_encoding=PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_UNDEFINED,
+            quant_configs={},
+            dims=reshape_output_tensor.size(),
+            tensor=reshape_output_tensor,
+            is_fake_tensor=True,
+            nodes_to_wrappers=nodes_to_wrappers,
+        )
+        reshape_op = PyQnnWrapper.PyQnnOpWrapper(
+            node.name,
+            QNN_OP_PACKAGE_NAME_QTI_AISW,
+            OpReshape.op_name,
+        )
+        reshape_op.AddInputTensors([indices_tensor_wrapper])
+        reshape_op.AddOutputTensors([reshape_output_tensor_wrapper])
+        op_wrapper_list.append(reshape_op)
+        index_put_index_input_tensor_wrapper = reshape_output_tensor_wrapper
+
+        # Tile the index_node and concat the target index
+        if None in indicies_node:
+            tile_output_tensor = reshape_output_tensor.expand(index_shape)
+            # Tile the index_node to align with the shape of target_index
+            # Only need to tile the dim of None axis
+            # E.g., indicies_node: [None, None, aten__to_copy_default_1]
+            # Should tile the first two dimension.
+            multiples = [
+                shape if id != index_node_dim else 1
+                for id, shape in enumerate(index_shape)
+            ]
+            multiples_shape = [len(index_shape)]
+            tile_output_tensor_wrapper = self.define_custom_tensor_wrapper(
+                node_name=node.name + "_tile",
+                tensor_type=PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+                dtype=QNN_TENSOR_TYPE_MAP[tile_output_tensor.dtype],
+                quant_encoding=PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_UNDEFINED,
+                quant_configs={},
+                dims=tile_output_tensor.size(),
+                tensor=tile_output_tensor,
+                is_fake_tensor=True,
+                nodes_to_wrappers=nodes_to_wrappers,
+            )
+            tile_op = PyQnnWrapper.PyQnnOpWrapper(
+                node.name,
+                QNN_OP_PACKAGE_NAME_QTI_AISW,
+                OpTile.op_name,
+            )
+            tile_op.AddInputTensors([reshape_output_tensor_wrapper])
+            tile_op.AddOutputTensors([tile_output_tensor_wrapper])
+            tile_op.AddTensorParam(
+                OpTile.param_multiples,
+                PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
+                len(multiples_shape),
+                multiples_shape,
+                np.array(multiples, dtype=np.uint32),
+                True,
+            )
+            op_wrapper_list.append(tile_op)
+
+            # Repeat index for "None" axis in indicies_node
+            ranges = [
+                torch.arange(dim, dtype=indice_tensor.dtype)
+                for dim in target_index[:-1]
+            ]
+            target_index_shape = target_index + [len(ranges)]
+            target_index_tensor = torch.cartesian_prod(*ranges)
+            reshape_target_index_shape = [
+                shape if id != index_node_dim else 1
+                for id, shape in enumerate(target_index_shape)
+            ]
+            target_index_tensor = target_index_tensor.reshape(
+                reshape_target_index_shape
+            )
+            target_index_tensor = target_index_tensor.expand(
+                target_index_shape
+            ).contiguous()
+            target_index_node = torch.fx.Node(
+                node.graph,
+                node.name + "_target_index",
+                "call_function",
+                exir_ops.edge.aten.tensor.default,
+                (),  # args
+                {},  # kwargs
+            )
+            target_index_tensor_wrapper = self.define_tensor(
+                target_index_node,
+                node,
+                target_index_tensor,
+                PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
+                nodes_to_wrappers,
+            )
 
+            # Concat target_index and tile output to reconstruct index_node
+            # Cannot use QNN Pack (stack) since QNN Pack is not support int32 dtype
+            concat_output_tensor = torch.concat(
+                (target_index_tensor, tile_output_tensor), dim=-1
+            )
+            concat_output_tensor_wrapper = self.define_custom_tensor_wrapper(
+                node_name=node.name + "_concat",
+                tensor_type=PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+                dtype=QNN_TENSOR_TYPE_MAP[concat_output_tensor.dtype],
+                quant_encoding=PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_UNDEFINED,
+                quant_configs={},
+                dims=concat_output_tensor.size(),
+                tensor=concat_output_tensor,
+                is_fake_tensor=True,
+                nodes_to_wrappers=nodes_to_wrappers,
+            )
+            concat_op = PyQnnWrapper.PyQnnOpWrapper(
+                node.name,
+                QNN_OP_PACKAGE_NAME_QTI_AISW,
+                OpConcat.op_name,
+            )
+            concat_op.AddInputTensors(
+                [target_index_tensor_wrapper, tile_output_tensor_wrapper]
+            )
+            concat_op.AddOutputTensors([concat_output_tensor_wrapper])
+            concat_op.AddScalarParam(
+                OpConcat.param_axis,
+                PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
+                {QCOM_DATA: np.uint32(concat_output_tensor.dim() - 1)},
+            )
+            op_wrapper_list.append(concat_op)
+            index_put_index_input_tensor_wrapper = concat_output_tensor_wrapper
+
+        value_node = self.get_node(node.args[2])
+        value_tensor = self.get_tensor(value_node, node)
         value_tensor_wrapper = self.define_tensor(
             value_node,
             node,
             value_tensor,
             PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
             nodes_to_wrappers,
         )
+
         output_tensor = self.get_tensor(node, node)
         output_tensor_wrapper = self.define_tensor(
             node,
@@ -82,8 +260,12 @@ def define_node(
             OpScatterNd.op_name,
         )
         index_put_op.AddInputTensors(
-            [input_tensor_wrapper, indices_tensor_wrapper, value_tensor_wrapper]
+            [
+                input_tensor_wrapper,
+                index_put_index_input_tensor_wrapper,
+                value_tensor_wrapper,
+            ]
         )
         index_put_op.AddOutputTensors([output_tensor_wrapper])
-
-        return index_put_op
+        op_wrapper_list.append(index_put_op)
+        return op_wrapper_list
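
For reference, a standalone sketch of what the reshape/tile/cartesian-product/concat sequence above produces, using the shapes from the comments (input [1, 12, 1024, 64], indices [None, None, idx] with idx of shape [1]); all names and values here are hypothetical, and the rebuilt index tensor is checked against ordinary advanced-indexing assignment:

    import torch

    x = torch.zeros(1, 12, 1024, 64)
    idx = torch.tensor([5])                 # the single index node, shape [1]
    values = torch.randn(1, 12, 1, 64)      # updates for x[:, :, idx]

    # "None" axes contribute their full range; the indexed axis contributes idx.
    none_dims = [1, 12]
    grid = torch.cartesian_prod(*[torch.arange(d) for d in none_dims])  # shape [12, 2]
    grid = grid.reshape(1, 12, 1, 2)                                    # the "target index"
    tiled_idx = idx.reshape(1, 1, 1, 1).expand(1, 12, 1, 1)             # reshape + tile
    full_index = torch.cat((grid, tiled_idx), dim=-1)                   # shape [1, 12, 1, 3]

    # ScatterND-style reference: each trailing 3-tuple is a partial index into x.
    scatter_out = x.clone()
    for row, upd in zip(full_index.reshape(-1, 3).tolist(), values.reshape(-1, 64)):
        a, b, c = row
        scatter_out[a, b, c] = upd

    # Same result as index_put with indices [None, None, idx].
    expected = x.clone()
    expected[:, :, idx] = values
    assert torch.equal(scatter_out, expected)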
