Skip to content

Commit ba344c5

Browse files
committed
Update base for Update on "Take advantage of C++17 in scalar_type_util.h"
I generated a big ugly table because we couldn't make promoteTypes constexpr before we had C++17. Now we have C++17. Differential Revision: [D66181946](https://our.internmc.facebook.com/intern/diff/D66181946/) [ghstack-poisoned]
2 parents eae0b04 + dcacde0 commit ba344c5

File tree

25 files changed

+582
-60
lines changed

25 files changed

+582
-60
lines changed

.github/scripts/check_labels.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,15 +45,15 @@ def main() -> None:
4545

4646
try:
4747
if not has_required_labels(pr):
48-
print(LABEL_ERR_MSG)
48+
print(LABEL_ERR_MSG, flush=True)
4949
add_label_err_comment(pr)
5050
if args.exit_non_zero:
51-
sys.exit(1)
51+
raise RuntimeError("PR does not have required labels")
5252
else:
5353
delete_all_label_err_comments(pr)
5454
except Exception as e:
5555
if args.exit_non_zero:
56-
sys.exit(1)
56+
raise RuntimeError(f"Error checking labels: {e}") from e
5757

5858
sys.exit(0)
5959

.github/scripts/github_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,10 +72,10 @@ def gh_fetch_url(
7272
headers: Optional[Dict[str, str]] = None,
7373
data: Union[Optional[Dict[str, Any]], str] = None,
7474
method: Optional[str] = None,
75-
reader: Callable[[Any], Any] = lambda x: x.read(),
75+
reader: Callable[[Any], Any] = json.load,
7676
) -> Any:
7777
return gh_fetch_url_and_headers(
78-
url, headers=headers, data=data, reader=json.load, method=method
78+
url, headers=headers, data=data, reader=reader, method=method
7979
)[1]
8080

8181

@@ -169,7 +169,7 @@ def gh_post_commit_comment(
169169

170170
def gh_delete_comment(org: str, repo: str, comment_id: int) -> None:
171171
url = f"{GITHUB_API_URL}/repos/{org}/{repo}/issues/comments/{comment_id}"
172-
gh_fetch_url(url, method="DELETE")
172+
gh_fetch_url(url, method="DELETE", reader=lambda x: x.read())
173173

174174

175175
def gh_fetch_merge_base(org: str, repo: str, base: str, head: str) -> str:

.github/workflows/android-perf.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ jobs:
136136
fail-fast: false
137137
with:
138138
runner: linux.4xlarge
139-
docker-image: executorch-ubuntu-22.04-clang12-android
139+
docker-image: executorch-ubuntu-22.04-qnn-sdk
140140
submodules: 'true'
141141
timeout: 60
142142
upload-artifact: android-models

.github/workflows/trunk.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,7 @@ jobs:
302302
fail-fast: false
303303
with:
304304
runner: linux.2xlarge
305-
docker-image: executorch-ubuntu-22.04-clang12-android
305+
docker-image: executorch-ubuntu-22.04-qnn-sdk
306306
submodules: 'true'
307307
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
308308
timeout: 900

backends/arm/operators/op_add.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ def define_node(
8282

8383
if needs_rescale:
8484
# Scale output back to 8 bit
85+
# pyre-ignore
8586
tqutils.rescale_node_back_to_int8(node, add_output, scale, tosa_graph)
8687

8788

backends/cadence/aot/TARGETS

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ python_library(
3838
deps = [
3939
":passes",
4040
":utils",
41+
":ops_registrations",
4142
"//caffe2:torch",
4243
"//executorch/backends/cadence/aot/quantizer:fusion_pass",
4344
"//executorch/backends/cadence/aot/quantizer:quantizer",
@@ -71,6 +72,8 @@ python_library(
7172
],
7273
deps = [
7374
":utils",
75+
":fuse_ops",
76+
":simplify_ops",
7477
"//caffe2:torch",
7578
"//executorch/exir:pass_base",
7679
"//executorch/exir/dialects:lib",
@@ -132,6 +135,18 @@ python_library(
132135
],
133136
)
134137

138+
python_library(
139+
name = "graph_builder",
140+
srcs = [
141+
"graph_builder.py",
142+
],
143+
typing = True,
144+
deps = [
145+
"fbcode//caffe2:torch",
146+
"fbcode//executorch/exir:pass_base",
147+
],
148+
)
149+
135150
python_library(
136151
name = "fuse_ops",
137152
srcs = [
@@ -150,3 +165,34 @@ python_library(
150165
"//executorch/exir/passes:spec_prop_pass",
151166
],
152167
)
168+
169+
python_library(
170+
name = "simplify_ops",
171+
srcs = [
172+
"simplify_ops.py",
173+
],
174+
typing = True,
175+
deps = [
176+
":pass_utils",
177+
"//executorch/backends/cadence/aot:pass_utils",
178+
"//executorch/exir:pass_base",
179+
"//executorch/exir/dialects:lib",
180+
],
181+
)
182+
183+
python_unittest(
184+
name = "test_graph_builder",
185+
srcs = [
186+
"tests/test_graph_builder.py",
187+
],
188+
typing = True,
189+
deps = [
190+
"//caffe2:torch",
191+
"//executorch/backends/cadence/aot:graph_builder",
192+
"//executorch/backends/cadence/aot:pass_utils",
193+
"//executorch/exir:pass_base",
194+
"//executorch/exir/dialects:lib",
195+
"//later:lib",
196+
":ops_registrations"
197+
],
198+
)

backends/cadence/aot/compiler.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from pathlib import Path
1111
from typing import Callable, cast, Optional
1212

13+
import executorch.backends.cadence.aot.ops_registrations # noqa
1314
import torch
1415

1516
from executorch.backends.cadence.aot.passes import ReplaceSafeSoftmaxWithSoftmax
@@ -196,7 +197,26 @@ def export_to_edge(
196197
# Export the model and lower it to an EdgeProgramManager (in edge IR), and
197198
# apply passes specific to Cadence DSP execution. Return both to print the
198199
# differences.
199-
def export_to_cadence_edge_executorch(
200+
def export_to_cadence(
201+
model: torch.nn.Module,
202+
inputs: tuple[object, ...],
203+
dump_graphs: bool = False,
204+
output_dir: Optional[str] = None,
205+
opt_level: int = 1,
206+
) -> EdgeProgramManager:
207+
edge_prog_manager = export_to_edge(model, inputs)
208+
cadence_passes = get_cadence_passes(opt_level)
209+
210+
# Run a couple of required passes for quant/dequant ops
211+
cadence_prog_manager = edge_prog_manager.transform(
212+
cast(
213+
list[Callable[[torch.fx.GraphModule], Optional[PassResult]]], cadence_passes
214+
)
215+
)
216+
return cadence_prog_manager
217+
218+
219+
def export_to_executorch_gen_etrecord(
200220
model: torch.nn.Module,
201221
inputs: tuple[object, ...],
202222
dump_graphs: bool = False,

backends/cadence/aot/export_example.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
from executorch.backends.cadence.aot.compiler import (
1818
convert_pt2,
19-
export_to_cadence_edge_executorch,
19+
export_to_executorch_gen_etrecord,
2020
fuse_pt2,
2121
)
2222

@@ -86,8 +86,8 @@ def export_model(
8686
quantized_model = fuse_pt2(converted_model, quantizer)
8787

8888
# Get edge program after Cadence specific passes
89-
exec_prog: ExecutorchProgramManager = export_to_cadence_edge_executorch(
90-
quantized_model, example_inputs, working_dir
89+
exec_prog: ExecutorchProgramManager = export_to_executorch_gen_etrecord(
90+
quantized_model, example_inputs, output_dir=working_dir
9191
)
9292

9393
logging.info("Final exported graph:\n")

backends/cadence/aot/fuse_ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1022,7 +1022,7 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
10221022
return PassResult(graph_module, True)
10231023

10241024

1025-
class FuseOpsInGraph:
1025+
class CadenceFuseOpsInGraph:
10261026
passes = [
10271027
FuseMMWithAdd,
10281028
FuseBatchNormWithConv,

backends/cadence/aot/graph_builder.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
2+
3+
# pyre-strict
4+
5+
import logging
6+
from typing import Optional, Sequence, Union
7+
8+
import torch
9+
from executorch.exir.pass_base import ExportPass, NodeMetadata, ProxyValue
10+
from torch._subclasses import FakeTensor, FakeTensorMode
11+
from torch.fx.node import Argument, Target
12+
from torch.utils import _pytree as pytree
13+
14+
15+
class GraphBuilder(ExportPass):
16+
"""Utility class for creating a graph module with user-specified ops.
17+
18+
This class allows us to create test graph modules with any ops we want
19+
directly, rather than relying on decomposition or passes.
20+
21+
Usage:
22+
builder = GraphBuilder()
23+
# To insert placeholders, use builder.placeholder.
24+
x = builder.placeholder("x", torch.randn(1, 3, 224, 224))
25+
# To insert an op, use builder.call_operator.
26+
op = builder.call_operator(
27+
some_op,
28+
(x, other_args, ...),
29+
)
30+
# Insert outputs as a list of ProxyValues using builder.output.
31+
builder.output([op])
32+
# Get GraphModule from builder.
33+
gm = builder.get_graph_module()
34+
"""
35+
36+
def __init__(self) -> None:
37+
self.exporter = ExportPass()
38+
self.tracer: ExportPass.ExportTracer = self.ExportTracer(
39+
self, torch.fx.graph.CodeGen()
40+
)
41+
self.fake_tensor_mode = FakeTensorMode(allow_fallback_kernels=False)
42+
self.tracer.fake_tensor_mode = self.fake_tensor_mode
43+
44+
# This will be called to create nodes in tracer.
45+
self.interpreter = torch.fx.Interpreter(
46+
torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph())
47+
)
48+
49+
# pyre-ignore[14]: Inconsistent override.
50+
def placeholder(
51+
self, target: str, fake_tensor: Union[FakeTensor, torch.Tensor]
52+
) -> ProxyValue:
53+
if not isinstance(fake_tensor, FakeTensor):
54+
fake_tensor = self.fake_tensor_mode.from_tensor(fake_tensor)
55+
logging.info(f"Creating placeholder {target} => {fake_tensor.shape}")
56+
placeholder = super().placeholder(target, fake_tensor, NodeMetadata({}))
57+
return placeholder
58+
59+
# pyre-ignore[14]: Inconsistent override.
60+
def output(self, results: list[ProxyValue]) -> ProxyValue:
61+
logging.info(f"Creating outputs {results}")
62+
return super().output(results, NodeMetadata({}))
63+
64+
def get_graph_module(self) -> torch.fx.GraphModule:
65+
return torch.fx.GraphModule(self.tracer.root, self.tracer.graph)
66+
67+
def call_operator(
68+
self,
69+
op, # pyre-ignore
70+
args: tuple[Argument, ...],
71+
kwargs: Optional[dict[str, Argument]] = None,
72+
meta: Optional[NodeMetadata] = None,
73+
) -> ProxyValue:
74+
if meta is None:
75+
meta = NodeMetadata({})
76+
if kwargs is None:
77+
kwargs = {}
78+
return super().call_operator(op, args, kwargs, meta)
79+
80+
81+
def single_op_builder(
82+
placeholders: Sequence[Union[torch.Tensor, FakeTensor]],
83+
op: Target,
84+
args: Sequence[Argument],
85+
kwargs: Optional[dict[str, Argument]] = None,
86+
) -> torch.fx.GraphModule:
87+
"""Create a graph module with a single op.
88+
89+
Args:
90+
placeholders: Placeholders to be used as inputs to the GraphModule.
91+
op: The op to be inserted.
92+
args: The args to be passed to the op.
93+
kwargs: The kwargs to be passed to the op.
94+
95+
Returns:
96+
A graph module with a single op
97+
"""
98+
builder = GraphBuilder()
99+
op_to_placeholder_dict = {
100+
p: builder.placeholder(f"p_{i}", p) for i, p in enumerate(placeholders)
101+
}
102+
proxy_args, proxy_kwargs = pytree.tree_map_only(
103+
(torch.Tensor, FakeTensor), lambda x: op_to_placeholder_dict[x], (args, kwargs)
104+
)
105+
node = builder.call_operator(op, proxy_args, proxy_kwargs)
106+
builder.output([node])
107+
return builder.get_graph_module()

backends/cadence/aot/pass_utils.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,3 +89,12 @@ def get_node_names_list_from_gm(
8989
continue
9090
graph_nodes.append(node.name)
9191
return graph_nodes
92+
93+
94+
def count_node(graph_module: torch.fx.GraphModule, target: torch.fx.node.Target) -> int:
95+
"""Count the number of nodes with target `target` in the graph."""
96+
total = 0
97+
for node in graph_module.graph.nodes:
98+
if node.op == "call_function" and node.target == target:
99+
total += 1
100+
return total

backends/cadence/aot/passes.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,13 @@
1111
import torch
1212
import torch.fx
1313
import torch.utils._pytree as pytree
14+
from executorch.backends.cadence.aot.fuse_ops import CadenceFuseOpsInGraph
1415
from executorch.backends.cadence.aot.pass_utils import (
1516
CadencePassAttribute,
1617
create_cadence_pass_filter,
1718
register_cadence_pass,
1819
)
20+
from executorch.backends.cadence.aot.simplify_ops import CadenceSimplifyOpsInGraph
1921
from executorch.backends.cadence.aot.utils import get_edge_overload_packet
2022
from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform
2123
from executorch.exir.dialects._ops import ops as exir_ops
@@ -346,10 +348,23 @@ def get_passes_in_default_order() -> List[Type[PassType]]:
346348
ReplaceScalarTensorWithFullPass,
347349
RemoveCloneOpsTransformImported,
348350
RemoveNopExpandOpPass,
351+
CadenceFuseOpsInGraph.passes,
349352
ReplaceSqueezeAndUnsqueezeWithViewPass,
350353
ReplacePT2QuantWithCadenceQuantPass,
351354
ReplacePT2DequantWithCadenceDequantPass,
355+
CadenceSimplifyOpsInGraph.passes,
352356
# TODO: add the rest of the passes here.
357+
# InitializePipeline,
358+
# RemoveRedundantOps.passes,
359+
# ReorderOpsInGraph.passes,
360+
# RemoveJarvisNops.passes,
361+
# CadenceFuseOpsInGraph.passes,
362+
# ReplaceOpsInGraph.passes,
363+
# SimplifyOpsInGraph.passes,
364+
# FinalizePipeline,
365+
# FuseFullThenReshapePass,
366+
# FuseTransposeOpPairsPass,
367+
# RemoveNopSliceOrViewOpPass,
353368
]
354369
return pytree.tree_flatten(passes)[0]
355370

0 commit comments

Comments
 (0)