Skip to content

Commit cbd5c54

Browse files
committed
Update on "[ET-VK] Generalize MeanToSumDiv to any dtype"
This change is required for fp16 models. Differential Revision: [D58040777](https://our.internmc.facebook.com/intern/diff/D58040777/) [ghstack-poisoned]
2 parents efc4869 + 6a75ddc commit cbd5c54

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

53 files changed

+1280
-124
lines changed

.gitmodules

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,10 +49,6 @@
4949
[submodule "backends/vulkan/third-party/Vulkan-Headers"]
5050
path = backends/vulkan/third-party/Vulkan-Headers
5151
url = https://github.com/KhronosGroup/Vulkan-Headers
52-
[submodule "third-party/lm-evaluation-harness"]
53-
path = third-party/lm-evaluation-harness
54-
url = https://github.com/EleutherAI/lm-evaluation-harness
55-
branch = v0.4.1
5652
[submodule "kernels/optimized/third-party/eigen"]
5753
path = kernels/optimized/third-party/eigen
5854
url = https://gitlab.com/libeigen/eigen.git

backends/qualcomm/builders/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
op_slice_copy,
4545
op_softmax,
4646
op_space_to_depth,
47+
op_split_with_sizes,
4748
op_sqrt,
4849
op_squeeze,
4950
op_sub,
@@ -95,6 +96,7 @@
9596
op_slice_copy,
9697
op_softmax,
9798
op_space_to_depth,
99+
op_split_with_sizes,
98100
op_squeeze,
99101
op_sqrt,
100102
op_sub,

backends/qualcomm/builders/node_visitor.py

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -215,9 +215,8 @@ def get_data_type(
215215
self,
216216
tensor: torch.Tensor,
217217
quant_config: Dict,
218-
is_tensor: bool,
219218
) -> PyQnnWrapper.Qnn_TensorType_t:
220-
if quant_config and is_tensor:
219+
if quant_config:
221220
quant_range = quant_config["quant_max"] - quant_config["quant_min"]
222221
unsigned = quant_config["quant_min"] >= 0
223222
if quant_range <= torch.iinfo(torch.int8).max - torch.iinfo(torch.int8).min:
@@ -234,8 +233,8 @@ def get_data_type(
234233
else:
235234
quant_config["dtype"] = torch.int16
236235
return QNN_QUANT_TYPE_MAP[quant_config["dtype"]]
237-
else:
238-
return QNN_TENSOR_TYPE_MAP[tensor.dtype]
236+
237+
return QNN_TENSOR_TYPE_MAP[tensor.dtype]
239238

240239
def define_custom_tensor_wrapper(
241240
self,
@@ -247,10 +246,11 @@ def define_custom_tensor_wrapper(
247246
dims: torch.Size,
248247
tensor: torch.Tensor,
249248
is_fake_tensor: bool,
250-
nodes_to_wrappers: Dict[str, PyQnnWrapper.TensorWrapper],
249+
nodes_to_wrappers: Dict[str, Dict[int, PyQnnWrapper.TensorWrapper]],
250+
wrapper_idx: int = 0,
251251
) -> PyQnnWrapper.TensorWrapper:
252-
if node_name in nodes_to_wrappers:
253-
return nodes_to_wrappers[node_name]
252+
if cached := nodes_to_wrappers[node_name].get(wrapper_idx, None):
253+
return cached
254254
if is_fake_tensor:
255255
tensor_wrapper = PyQnnWrapper.TensorWrapper(
256256
node_name,
@@ -266,18 +266,19 @@ def define_custom_tensor_wrapper(
266266
else:
267267
# Can implement non-fake tensor when there is a need
268268
return None
269-
nodes_to_wrappers[node_name] = tensor_wrapper
269+
nodes_to_wrappers[node_name][wrapper_idx] = tensor_wrapper
270270
return tensor_wrapper
271271

272272
def define_tensor(
273273
self,
274274
node: torch.fx.Node,
275275
tensor: torch.Tensor,
276276
tensor_type: PyQnnWrapper.Qnn_TensorType_t,
277-
nodes_to_wrappers: Dict[str, PyQnnWrapper.TensorWrapper],
277+
nodes_to_wrappers: Dict[str, Dict[int, PyQnnWrapper.TensorWrapper]],
278278
is_input_tensor: bool,
279279
node_name: str = None,
280280
is_tensor: bool = True,
281+
wrapper_idx: int = 0,
281282
) -> PyQnnWrapper.TensorWrapper:
282283
"""
283284
Convert torch.Tensor to TensorWrapper
@@ -293,8 +294,8 @@ def define_tensor(
293294
if node_name is None:
294295
node_name = node.name
295296

296-
if node_name in nodes_to_wrappers:
297-
return nodes_to_wrappers[node_name]
297+
if cached := nodes_to_wrappers[node_name].get(wrapper_idx, None):
298+
return cached
298299
tensor_name = node.name
299300
if is_graph_output(node):
300301
tensor_name = "output_" + tensor_name
@@ -303,7 +304,7 @@ def define_tensor(
303304
quant_encoding, quant_configs = self.get_quant_encoding_conf(
304305
node, is_input_tensor
305306
)
306-
dtype = self.get_data_type(tensor, quant_configs, is_tensor)
307+
dtype = self.get_data_type(tensor, quant_configs)
307308
if isinstance(tensor, torch._subclasses.fake_tensor.FakeTensor):
308309
tensor_wrapper = PyQnnWrapper.TensorWrapper(
309310
tensor_name,
@@ -334,13 +335,13 @@ def define_tensor(
334335
tensor.detach().numpy(),
335336
True,
336337
)
337-
nodes_to_wrappers[node_name] = tensor_wrapper
338+
nodes_to_wrappers[node_name][wrapper_idx] = tensor_wrapper
338339
return tensor_wrapper
339340

340341
def define_node(
341342
self,
342343
node: torch.fx.Node,
343-
nodes_to_wrappers: Dict[str, PyQnnWrapper.TensorWrapper],
344+
nodes_to_wrappers: Dict[str, Dict[int, PyQnnWrapper.TensorWrapper]],
344345
) -> PyQnnWrapper.PyQnnOpWrapper:
345346
"""Convert torch.fx.Node to OpWrapper"""
346347
raise NotImplementedError("NodeVisitor must be extended!")
@@ -372,10 +373,8 @@ def generate_node_to_external_map(
372373
if is_graph_input(node, edge_program):
373374
node_to_external_map[node] = len(node_to_external_map)
374375
for node in edge_program.graph_module.graph.nodes:
375-
if node.op == "output":
376-
for output_nodes in node.args:
377-
for output_node in output_nodes:
378-
node_to_external_map[output_node] = len(node_to_external_map)
376+
if is_graph_output(node):
377+
node_to_external_map[node] = len(node_to_external_map)
379378
return node_to_external_map
380379

381380

backends/qualcomm/builders/op_conv2d.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ def _define_conv1d(
108108
is_input_tensor=True,
109109
)
110110
unsqueeze_output_tensor = unsqueeze_input_tensor.unsqueeze(1).contiguous()
111-
dtype = self.get_data_type(unsqueeze_output_tensor, input_quant_configs, True)
111+
dtype = self.get_data_type(unsqueeze_output_tensor, input_quant_configs)
112112
unsqueeze_output_tensor_wrapper = self.define_custom_tensor_wrapper(
113113
node_name=node.name + "_unsqueeze",
114114
tensor_type=PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
@@ -186,7 +186,7 @@ def _define_conv1d(
186186
)
187187
conv_output_tensor = self.get_tensor(node, node)
188188
conv_output_tensor = conv_output_tensor.unsqueeze(1).contiguous()
189-
dtype = self.get_data_type(conv_output_tensor, input_quant_configs, True)
189+
dtype = self.get_data_type(conv_output_tensor, input_quant_configs)
190190
conv_output_tensor_wrapper = self.define_custom_tensor_wrapper(
191191
node_name=node.name + "_squeeze",
192192
tensor_type=PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,

backends/qualcomm/builders/op_skip_ops.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,5 +46,9 @@ def define_node(
4646
raise AssertionError(
4747
f"Invalid number of index for {node.name }: {len(node.args[1])}"
4848
)
49-
nodes_to_wrappers[node.name] = nodes_to_wrappers.get(node.args[0].name)
49+
idx = node.args[1]
50+
# Wrap the selected wrapper in a dict to fit the format of nodes_to_wrappers: Dict[str, Dict[int, PyQnnWrapper.TensorWrapper]]
51+
nodes_to_wrappers[node.name] = {
52+
0: nodes_to_wrappers.get(node.args[0].name).get(idx)
53+
}
5054
return
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
# Copyright (c) Qualcomm Innovation Center, Inc.
2+
# All rights reserved
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
from typing import cast, Dict, List
7+
8+
import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
9+
10+
import numpy as np
11+
import torch
12+
13+
from .node_visitor import NodeVisitor, register_node_visitor
14+
from .qnn_constants import OpSplit, QNN_OP_PACKAGE_NAME_QTI_AISW
15+
16+
17+
@register_node_visitor
class SplitWithSizes(NodeVisitor):
    """Node visitor that lowers ``aten.split_with_sizes.default`` to a QNN ``Split`` op."""

    target = ["aten.split_with_sizes.default"]

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        nodes_to_wrappers: Dict[str, Dict[int, PyQnnWrapper.TensorWrapper]],
    ) -> PyQnnWrapper.PyQnnOpWrapper:
        """Build the QNN Split op wrapper for a ``split_with_sizes`` node.

        Args:
            node: The ``split_with_sizes`` node. ``node.args[0]`` is the input
                node, ``node.args[1]`` the list of chunk sizes and
                ``node.args[2]`` (optional) the dimension to split along.
            nodes_to_wrappers: Cache of already-defined tensor wrappers, passed
                through to ``define_tensor``.

        Returns:
            The op wrapper with one input tensor, one output tensor per entry
            of ``node.meta["val"]``, and the ``split_index`` / ``axis`` params.
        """
        input_node = node.args[0]
        input_tensor = self.get_tensor(input_node, node)

        input_tensor_wrapper = self.define_tensor(
            input_node,
            input_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
            is_input_tensor=True,
        )
        input_tensor_wrappers = [input_tensor_wrapper]

        # split_with_sizes returns a tuple since it has multiple outputs; each
        # output gets its own wrapper, distinguished by wrapper_idx.
        output_tensor_wrappers = []
        for index in range(len(node.meta["val"])):
            output_tensor = self.get_tensor(node, node, index)
            output_tensor_wrapper = self.define_tensor(
                node,
                output_tensor,
                PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
                nodes_to_wrappers,
                is_input_tensor=False,
                wrapper_idx=index,
            )
            output_tensor_wrappers.append(output_tensor_wrapper)

        chunks = cast(List[int], node.args[1])
        # Edge represents chunks by specifying the size of each chunk.
        # QNN represents chunks by specifying the index at which to split, so
        # accumulate the sizes of all chunks but the last into running offsets.
        split_indices = []
        offset = 0
        for chunk_size in chunks[:-1]:
            offset += chunk_size
            split_indices.append(offset)

        split_indices_shape = [len(split_indices)]

        # `dim` defaults to 0 in the aten schema, so it may be absent from
        # args (or passed as a keyword) instead of sitting at args[2].
        if len(node.args) > 2:
            dim = cast(int, node.args[2])
        else:
            dim = cast(int, node.kwargs.get("dim", 0))
        if dim < 0:
            dim = dim % len(input_tensor.shape)

        # When a layout transform recorded an axis order, remap the split
        # dimension to its position in that order.
        if "axis_order" in node.meta:
            dim = node.meta["axis_order"].index(dim)

        split_op = PyQnnWrapper.PyQnnOpWrapper(
            node.name,
            QNN_OP_PACKAGE_NAME_QTI_AISW,
            OpSplit.op_name,
        )
        split_op.AddInputTensors(input_tensor_wrappers)
        split_op.AddOutputTensors(output_tensor_wrappers)
        split_op.AddTensorParam(
            OpSplit.param_split_index,
            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
            len(split_indices_shape),
            split_indices_shape,
            np.array(split_indices, dtype=np.uint32),
            True,
        )

        split_op.AddScalarParam(
            OpSplit.param_axis,
            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
            {"data": np.uint32(dim)},
        )
        return split_op

backends/qualcomm/builders/qnn_constants.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,12 @@ class Mode(IntEnum):
290290
CRD = 1
291291

292292

293+
# Name and parameter keys of the QNN "Split" operator.
# Declared as a frozen, init-less dataclass to match the other op-constant
# holders in this module (e.g. OpSqueeze).
@dataclass(init=False, frozen=True)
class OpSplit:
    # Operator type name passed to the op wrapper.
    op_name: str = "Split"
    # Key of the scalar parameter selecting the dimension to split along.
    param_axis: str = "axis"
    # Key of the tensor parameter holding the split offsets.
    param_split_index: str = "split_index"
297+
298+
293299
@dataclass(init=False, frozen=True)
294300
class OpSqueeze:
295301
op_name: str = "Squeeze"

backends/qualcomm/builders/utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,10 @@ def is_graph_output(tensor: torch.fx.Node) -> bool:
7575
tensor: EdgeIR Tensor that is being checked for graph output
7676
"""
7777
for user in tensor.users.keys():
78-
if user.op == "output":
78+
# getitem node is skipped, see op_skip_ops.py
79+
if user.op == "output" or (
80+
user.target.__name__ == "getitem" and is_graph_output(user)
81+
):
7982
return True
8083
return False
8184

backends/qualcomm/partition/qnn_partitioner.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66
import copy
7+
from collections import defaultdict
78
from typing import Any, Dict, List
89

910
import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager
@@ -49,7 +50,7 @@ def __init__(
4950
)
5051

5152
self.skip_node_id_set = skip_node_id_set
52-
self.nodes_to_wrappers = {}
53+
self.nodes_to_wrappers = self.nodes_to_wrappers = defaultdict(dict)
5354
self.qnn_manager = PyQnnManager.QnnManager(
5455
generate_qnn_executorch_option(compiler_specs)
5556
)

backends/qualcomm/passes/convert_to_linear.py

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from executorch.exir.dialects.edge._ops import EdgeOpOverload as edge_op
1515
from executorch.exir.pass_base import ExportPass, PassResult
1616
from executorch.exir.passes import dead_code_elimination_pass
17+
1718
from torch.fx.passes.utils.source_matcher_utils import (
1819
get_source_partitions,
1920
SourcePartition,
@@ -92,6 +93,7 @@ def _convert_to_linear(
9293
if bias_node:
9394
args.append(bias_node)
9495

96+
# We need a view copy node after linear op
9597
with gm.graph.inserting_before(output):
9698
linear_node = gm.graph.create_node(
9799
"call_function", self.linear, tuple(args)
@@ -104,6 +106,52 @@ def _convert_to_linear(
104106
for user in fn_node.users.copy():
105107
user.replace_input_with(fn_node, linear_node)
106108

109+
# Since QNN has no keep dims for linear op, we will need to add squeeze and unsqueeze around linear node
110+
# TODO: Find a more general conditional statement.
111+
if (
112+
fn_node.target == self.add
113+
and linear_node.meta["val"].dim() == 3
114+
and linear_node.meta["val"].shape[0] == 1
115+
):
116+
squeeze_dim = linear_node.meta["val"].shape[1:]
117+
linear_node.meta["val"] = torch.squeeze(linear_node.meta["val"], 0)
118+
with gm.graph.inserting_after(input_node):
119+
input_users = list(input_node.users.keys())
120+
squeeze_dim = linear_node.meta["val"].shape
121+
squeeze_view_copy_node = gm.graph.create_node(
122+
"call_function",
123+
self.view_copy,
124+
(
125+
input_node,
126+
squeeze_dim,
127+
),
128+
)
129+
squeeze_view_copy_node.meta = linear_node.meta
130+
for user in input_users:
131+
if user == linear_node:
132+
user.replace_input_with(input_node, squeeze_view_copy_node)
133+
with gm.graph.inserting_after(output):
134+
output_users = list(linear_node.users.keys())
135+
unsqueeze_dim = output.args[0].meta["val"].shape
136+
unsqueeze_view_copy_node = gm.graph.create_node(
137+
"call_function",
138+
self.view_copy,
139+
(
140+
linear_node,
141+
unsqueeze_dim,
142+
),
143+
)
144+
unsqueeze_view_copy_node.meta = output.args[0].meta
145+
for user in output_users:
146+
user.replace_input_with(linear_node, unsqueeze_view_copy_node)
147+
if "quant_attrs" in linear_node.meta:
148+
squeeze_view_copy_node.meta["quant_attrs"] = linear_node.meta[
149+
"quant_attrs"
150+
]
151+
unsqueeze_view_copy_node.meta["quant_attrs"] = linear_node.meta[
152+
"quant_attrs"
153+
]
154+
107155
def _extract_mm_ops(self, partitioned_nodes: List[edge_op]) -> List[torch.fx.Node]:
108156
mm_node = [n for n in partitioned_nodes if n.target == self.mm][0]
109157
# weight -> permute -> input of mm
@@ -133,7 +181,10 @@ def _extract_bmm_ops(self, partitioned_nodes: List[edge_op]) -> List[torch.fx.No
133181
ret = [input_node, weight_node, bmm_node]
134182
if add_node:
135183
bias_node = add_node[0].args[1]
136-
ret += bias_node
184+
ret = [input_node, weight_node, add_node[0], bias_node]
185+
else:
186+
ret = [input_node, weight_node, bmm_node]
187+
137188
return ret
138189

139190
def _convert(self, graph_module: torch.fx.GraphModule):

backends/qualcomm/passes/layout_transform.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ class LayoutTransform(ExportPass):
6060
exir_ops.edge.aten.sub.Tensor,
6161
exir_ops.edge.aten.sum.dim_IntList,
6262
exir_ops.edge.aten._to_copy.default,
63+
exir_ops.edge.aten.split_with_sizes.default,
6364
*q_ops,
6465
*dq_ops,
6566
_operator.getitem,
@@ -142,7 +143,7 @@ def is_edge_condition(self, node):
142143
),
143144
(
144145
node.op != "output"
145-
and not isinstance(node.meta["val"], tuple)
146+
and not isinstance(node.meta["val"], (tuple, list))
146147
and len(node.meta["val"].shape) == 0
147148
),
148149
is_parameter(node, self.edge_program),

0 commit comments

Comments
 (0)