Skip to content

Commit b860608

Browse files
committed
Update on "[ET-VK] Add pass to remove local_scalar_dense"
## Context Scalar tensors (i.e. tensors with only 1 element) are often passed in to functions as scalars via ``` scalar_tensor[0].item() ``` This translates to the following chain in the graph ``` index_select = index_select(scalar_tensor, ...) scalar = local_scalar_dense(index_select) ``` This diff introduces a pass to remove the `local_scalar_dense` "chain" in favor of passing in the input tensor directly. Note that this replacement only occurs if the original tensor is a scalar tensor. In the Vulkan backend, these scalar tensors will be represented as symbolic integers instead of actual tensors, which is why this replacement is valid. However, it may not be a valid replacement for other backends. Differential Revision: [D63913432](https://our.internmc.facebook.com/intern/diff/D63913432/) [ghstack-poisoned]
2 parents cbb4a84 + 9994ebd commit b860608

File tree

111 files changed

+2649
-677
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

111 files changed

+2649
-677
lines changed

.ci/docker/ci_commit_pins/pytorch.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
4b2970f7cd3cdd56883cacf116a8693862f89db5
1+
d1b87e26e5c4343f5b56bb1e6f89b479b389bfac

.ci/docker/requirements-ci.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
mpmath==1.3.0
2-
numpy==1.21.3; python_version == '3.10'
2+
numpy==1.22.0; python_version == '3.10'
33
numpy==1.23.2; python_version == '3.11'
44
numpy; python_version >= '3.12'
55
PyYAML==6.0.1

backends/cadence/aot/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ python_library(
2222
deps = [
2323
"fbsource//third-party/pypi/tabulate:tabulate",
2424
"//caffe2:torch",
25+
"//executorch/exir:lib",
2526
"//executorch/exir:memory",
2627
"//executorch/exir/dialects:lib",
2728
"//executorch/exir/dialects/edge:lib",

backends/cadence/aot/compiler.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@
3636
from torch.export import export
3737
from torch.export.exported_program import ExportedProgram
3838

39+
from .utils import print_ops_info
40+
3941

4042
# Note: this is not meant as a primary API since it can create inconsistencies
4143
# if the quantizer here is different from the quantizer used to convert. It is
@@ -193,16 +195,17 @@ def export_to_edge(
193195

194196

195197
# Export the model and lower it to an EdgeProgramManager (in edge IR), and
196-
# apply passes specific to Cadence DSP execution.
198+
# apply passes specific to Cadence DSP execution. Return both to print the
199+
# differences.
197200
def export_to_cadence(
198201
model: torch.nn.Module,
199202
inputs: tuple[object, ...],
200203
dump_graphs: bool = False,
201204
) -> EdgeProgramManager:
202-
edge_program_manager = export_to_edge(model, inputs)
205+
edge_prog_manager = export_to_edge(model, inputs)
203206

204207
# Run a couple required passes for quant/dequant ops
205-
cadence_program_manager = edge_program_manager.transform(
208+
cadence_prog_manager = edge_prog_manager.transform(
206209
[
207210
InitializePipeline(),
208211
RemoveZeroSizedCatArgsPass(),
@@ -216,4 +219,10 @@ def export_to_cadence(
216219
]
217220
)
218221

219-
return cadence_program_manager
222+
# Print some information to terminal
223+
print_ops_info(
224+
edge_prog_manager.exported_program().graph_module,
225+
cadence_prog_manager.exported_program().graph_module,
226+
)
227+
228+
return cadence_prog_manager

backends/cadence/aot/export_example.py

Lines changed: 14 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -10,61 +10,26 @@
1010
import tempfile
1111

1212
from executorch.backends.cadence.aot.ops_registrations import * # noqa
13-
import os
1413
from typing import Any, Tuple
1514

1615
from executorch.backends.cadence.aot.compiler import (
1716
convert_pt2,
1817
export_to_cadence,
19-
export_to_edge,
20-
quantize_pt2,
18+
fuse_pt2,
2119
)
2220
from executorch.backends.cadence.aot.quantizer.quantizer import CadenceQuantizer
2321
from executorch.backends.cadence.runtime import runtime
2422
from executorch.backends.cadence.runtime.executor import BundledProgramManager
2523
from executorch.exir import ExecutorchProgramManager
2624
from torch import nn
2725

28-
from .utils import print_ops_info
26+
from .utils import save_bpte_program, save_pte_program
2927

3028

3129
FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
3230
logging.basicConfig(level=logging.INFO, format=FORMAT)
3331

3432

35-
def _save_pte_program(
36-
prog: ExecutorchProgramManager, model_name: str, output_dir: str = ""
37-
) -> None:
38-
if model_name.endswith(".pte"):
39-
filename = model_name
40-
else:
41-
filename = os.path.join(output_dir, f"{model_name}.pte")
42-
43-
try:
44-
with open(filename, "wb") as file:
45-
prog.write_to_file(file)
46-
logging.info(f"Saved exported program to {filename}")
47-
except Exception as e:
48-
logging.error(f"Error while saving to {filename}: {e}")
49-
50-
51-
def _save_bpte_program(
52-
buffer: bytes,
53-
model_name: str,
54-
output_dir: str = "",
55-
) -> None:
56-
if model_name.endswith(".bpte"):
57-
filename = model_name
58-
else:
59-
filename = os.path.join(output_dir, f"{model_name}.bpte")
60-
try:
61-
with open(filename, "wb") as f:
62-
f.write(buffer)
63-
logging.info(f"Saved exported program to {filename}")
64-
except Exception as e:
65-
logging.error(f"Error while saving to {output_dir}: {e}")
66-
67-
6833
def export_model(
6934
model: nn.Module,
7035
example_inputs: Tuple[Any, ...],
@@ -74,32 +39,28 @@ def export_model(
7439
working_dir = tempfile.mkdtemp(dir="/tmp")
7540
logging.debug(f"Created work directory {working_dir}")
7641

77-
# convert the model (also called in quantize_pt2)
78-
converted_model = convert_pt2(model, example_inputs, CadenceQuantizer())
42+
# Instantiate the quantizer
43+
quantizer = CadenceQuantizer()
7944

80-
# Get reference outputs from quantized_model
81-
ref_outputs = converted_model(*example_inputs)
45+
# Convert the model
46+
converted_model = convert_pt2(model, example_inputs, quantizer)
8247

83-
# Quantize the model
84-
quantized_model = quantize_pt2(model, example_inputs)
48+
# Get reference outputs from converted model
49+
ref_outputs = converted_model(*example_inputs)
8550

86-
# Get edge program (also called in export_to_cadence)
87-
edge_prog_manager = export_to_edge(quantized_model, example_inputs)
51+
# Quantize the model (note: quantizer needs to be the same as
52+
# the one used in convert_pt2)
53+
quantized_model = fuse_pt2(converted_model, quantizer)
8854

8955
# Get edge program after Cadence specific passes
9056
cadence_prog_manager = export_to_cadence(quantized_model, example_inputs)
9157

58+
# Get executorch program after Cadence specific passes
9259
exec_prog: ExecutorchProgramManager = cadence_prog_manager.to_executorch()
9360

9461
logging.info("Final exported graph:\n")
9562
exec_prog.exported_program().graph_module.graph.print_tabular()
9663

97-
# Print some information to terminal
98-
print_ops_info(
99-
edge_prog_manager.exported_program().graph_module,
100-
cadence_prog_manager.exported_program().graph_module,
101-
)
102-
10364
forward_test_data = BundledProgramManager.bundled_program_test_data_gen(
10465
method="forward", inputs=example_inputs, expected_outputs=ref_outputs
10566
)
@@ -110,9 +71,9 @@ def export_model(
11071
forward_test_data,
11172
)
11273
# Save the program as pte (default name is CadenceDemoModel.pte)
113-
_save_pte_program(exec_prog, file_name, working_dir)
74+
save_pte_program(exec_prog, file_name, working_dir)
11475
# Save the program as bpte (default name is CadenceDemoModel.bpte)
115-
_save_bpte_program(buffer, file_name, working_dir)
76+
save_bpte_program(buffer, file_name, working_dir)
11677

11778
logging.debug(
11879
f"Executorch bundled program buffer saved to {file_name} is {len(buffer)} total bytes"

backends/cadence/aot/utils.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,12 @@
88

99
import logging
1010
import operator
11+
import os
1112
from typing import Dict, List, Tuple
1213

1314
import torch
14-
from executorch.exir import memory
15+
16+
from executorch.exir import ExecutorchProgramManager, memory
1517
from executorch.exir.dialects._ops import ops as exir_ops
1618
from executorch.exir.dialects.edge._ops import EdgeOpOverload, EdgeOpOverloadPacket
1719
from tabulate import tabulate
@@ -185,3 +187,36 @@ def model_gm_has_SDPA(model_gm: torch.fx.GraphModule) -> bool:
185187
if node.target == torch.ops.aten.scaled_dot_product_attention.default:
186188
return True
187189
return False
190+
191+
192+
def save_pte_program(
193+
prog: ExecutorchProgramManager, model_name: str, output_dir: str = ""
194+
) -> None:
195+
if model_name.endswith(".pte"):
196+
filename = model_name
197+
else:
198+
filename = os.path.join(output_dir, f"{model_name}.pte")
199+
200+
try:
201+
with open(filename, "wb") as file:
202+
prog.write_to_file(file)
203+
logging.info(f"Saved exported program to {filename}")
204+
except Exception as e:
205+
logging.error(f"Error while saving to {filename}: {e}")
206+
207+
208+
def save_bpte_program(
209+
buffer: bytes,
210+
model_name: str,
211+
output_dir: str = "",
212+
) -> None:
213+
if model_name.endswith(".bpte"):
214+
filename = model_name
215+
else:
216+
filename = os.path.join(output_dir, f"{model_name}.bpte")
217+
try:
218+
with open(filename, "wb") as f:
219+
f.write(buffer)
220+
logging.info(f"Saved exported program to {filename}")
221+
except Exception as e:
222+
logging.error(f"Error while saving to {output_dir}: {e}")

backends/qualcomm/aot/wrappers/TensorWrapper.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,9 @@ TensorWrapper::TensorWrapper(
9191
if (data != nullptr) {
9292
QNN_VER_PTR(tensor_)->clientBuf.dataSize = bytes;
9393

94-
if (copy_data) {
94+
if (tensor_type != QNN_TENSOR_TYPE_STATIC) {
95+
QNN_VER_PTR(tensor_)->clientBuf.data = nullptr;
96+
} else if (copy_data) {
9597
owned_data_ = std::make_unique<char[]>(bytes);
9698
const char* src_data = static_cast<const char*>(data);
9799
std::memcpy(owned_data_.get(), src_data, bytes);

backends/qualcomm/builders/op_avg_pool2d.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,8 @@ def define_node(
5151
filter_size = filter_size + filter_size
5252
filter_size_shape = [len(filter_size)]
5353

54-
# stride info
55-
stride = cast(List[int], node.args[2])
54+
# stride info - default to kernel_size if not given
55+
stride = cast(List[int], node.args[2]) if len(node.args) > 2 else filter_size
5656
if len(stride) == 1:
5757
stride = stride + stride
5858
stride_shape = [len(stride)]

backends/qualcomm/passes/convert_to_linear.py

Lines changed: 24 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -109,49 +109,50 @@ def _convert_to_linear(
109109

110110
# Since QNN has no keep dims for linear op, we will need to add squeeze and unsqueeze around linear node
111111
# TODO: Find a more general conditional statement.
112-
if (
113-
fn_node.target == self.add
114-
and linear_node.meta["val"].dim() == 3
115-
and linear_node.meta["val"].shape[0] == 1
116-
):
117-
squeeze_dim = linear_node.meta["val"].shape[1:]
118-
linear_node.meta["val"] = torch.squeeze(linear_node.meta["val"], 0)
112+
linear_output = linear_node.meta["val"]
113+
if linear_output.dim() == 3 and linear_output.shape[0] == 1:
119114
with gm.graph.inserting_after(input_node):
120115
input_users = list(input_node.users.keys())
121-
squeeze_dim = linear_node.meta["val"].shape
122-
squeeze_view_copy_node = gm.graph.create_node(
116+
input_tensor = input_node.meta["val"]
117+
squeeze_dim = input_tensor.shape[-2:]
118+
squeeze_node = gm.graph.create_node(
123119
"call_function",
124120
self.view_copy,
125121
(
126122
input_node,
127123
squeeze_dim,
128124
),
129125
)
130-
squeeze_view_copy_node.meta = linear_node.meta
126+
# meta needs to be copied elementwisely for fake-tensor
127+
# to be updated correctly and not affect meta of input_node
128+
for k, v in input_node.meta.items():
129+
squeeze_node.meta[k] = v
130+
squeeze_node.meta["val"] = input_tensor.reshape(squeeze_dim)
131131
for user in input_users:
132132
if user == linear_node:
133-
user.replace_input_with(input_node, squeeze_view_copy_node)
134-
with gm.graph.inserting_after(output):
133+
user.replace_input_with(input_node, squeeze_node)
134+
135+
with gm.graph.inserting_after(linear_node):
135136
output_users = list(linear_node.users.keys())
136-
unsqueeze_dim = output.args[0].meta["val"].shape
137-
unsqueeze_view_copy_node = gm.graph.create_node(
137+
unsqueeze_dim = linear_output.shape
138+
unsqueeze_node = gm.graph.create_node(
138139
"call_function",
139140
self.view_copy,
140141
(
141142
linear_node,
142143
unsqueeze_dim,
143144
),
144145
)
145-
unsqueeze_view_copy_node.meta = output.args[0].meta
146+
# meta needs to be copied elementwisely for fake-tensor
147+
# to be updated correctly and not affect meta of unsqueeze_node
148+
for k, v in linear_node.meta.items():
149+
unsqueeze_node.meta[k] = v
150+
# update linear node's shape
151+
linear_node.meta["val"] = linear_output.reshape(
152+
linear_output.shape[-2:]
153+
)
146154
for user in output_users:
147-
user.replace_input_with(linear_node, unsqueeze_view_copy_node)
148-
if QCOM_QUANT_ATTRS in linear_node.meta:
149-
squeeze_view_copy_node.meta[QCOM_QUANT_ATTRS] = linear_node.meta[
150-
QCOM_QUANT_ATTRS
151-
]
152-
unsqueeze_view_copy_node.meta[QCOM_QUANT_ATTRS] = linear_node.meta[
153-
QCOM_QUANT_ATTRS
154-
]
155+
user.replace_input_with(linear_node, unsqueeze_node)
155156

156157
def _extract_mm_ops(self, partitioned_nodes: List[edge_op]) -> List[torch.fx.Node]:
157158
mm_node = [n for n in partitioned_nodes if n.target == self.mm][0]
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
# Copyright (c) Qualcomm Innovation Center, Inc.
2+
# All rights reserved
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import torch
8+
from executorch.exir.dialects._ops import ops as exir_ops
9+
from executorch.exir.pass_base import ExportPass, PassResult
10+
from executorch.exir.passes import dead_code_elimination_pass
11+
12+
13+
class ExpandBroadcastTensorShape(ExportPass):
14+
"""
15+
Make tensors have same rank for layout-transform to work properly.
16+
"""
17+
18+
def __init__(self):
19+
super(ExpandBroadcastTensorShape, self).__init__()
20+
self.broadcast_op_targets = [
21+
exir_ops.edge.aten.add.Tensor,
22+
exir_ops.edge.aten.sub.Tensor,
23+
exir_ops.edge.aten.mul.Tensor,
24+
exir_ops.edge.aten.div.Tensor,
25+
]
26+
27+
def traverse_broadcast_node(self, graph_module: torch.fx.GraphModule):
28+
for node in graph_module.graph.nodes:
29+
if node.target in self.broadcast_op_targets:
30+
for arg in node.args:
31+
input_rank = len(arg.meta["val"].shape)
32+
output_rank = len(node.meta["val"].shape)
33+
if input_rank != output_rank:
34+
with graph_module.graph.inserting_after(arg):
35+
new_rank = [1] * (output_rank - input_rank) + list(
36+
arg.meta["val"].shape
37+
)
38+
users = list(arg.users.keys())
39+
reshape_node = graph_module.graph.create_node(
40+
"call_function",
41+
exir_ops.edge.aten.view_copy.default,
42+
(arg, tuple(new_rank)),
43+
)
44+
# meta needs to be copied elementwisely for fake-tensor
45+
# to be updated correctly and not affect meta of arg
46+
for k, v in arg.meta.items():
47+
reshape_node.meta[k] = v
48+
reshape_node.meta["val"] = reshape_node.meta["val"].reshape(
49+
new_rank
50+
)
51+
for user in users:
52+
user.replace_input_with(arg, reshape_node)
53+
54+
def call(self, graph_module: torch.fx.GraphModule):
55+
self.traverse_broadcast_node(graph_module)
56+
graph_module.recompile()
57+
dead_code_elimination_pass(graph_module)
58+
return PassResult(graph_module, True)

0 commit comments

Comments
 (0)