Commit 5392bdb

Author: morelos (committed)

Update on "[ET-VK][Ops] aten.tan.default in unary_ops"

Adds tan to unary_ops, albeit trivially, as it doesn't already exist.

Differential Revision: [D75112807](https://our.internmc.facebook.com/intern/diff/D75112807/)

[ghstack-poisoned]

2 parents 5bbeea3 + 67e959a commit 5392bdb
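
For context, the commit title refers to registering the element-wise tangent op in the Vulkan backend's unary_ops. A minimal, hypothetical export sketch (module name and shapes are illustrative, and the Vulkan lowering itself is not shown here) of a graph that would contain aten.tan.default:

import torch

class TanModule(torch.nn.Module):  # illustrative toy module, not from this commit
    def forward(self, x):
        return torch.tan(x)

# torch.export traces the module into an ATen graph; the resulting graph
# contains aten.tan.default, the op this stack maps to a Vulkan unary op.
ep = torch.export.export(TanModule(), (torch.randn(2, 3),))
print(ep.graph)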

49 files changed (+686 −213 lines)

.github/workflows/build-presets.yml

Lines changed: 2 additions & 2 deletions

@@ -20,7 +20,7 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        preset: [macos-arm64, pybind, llm]
+        preset: [macos, ios, ios-simulator, pybind, llm]
     with:
       job-name: build
       ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
@@ -39,7 +39,7 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        preset: [pybind, llm]
+        preset: [linux, pybind, llm]
         runner: [linux.2xlarge, linux.arm64.2xlarge]
         docker-image: [executorch-ubuntu-22.04-clang12, executorch-ubuntu-22.04-gcc11-aarch64]
         # Excluding specific runner + docker image combinations that don't make sense:

.lintrunner.toml

Lines changed: 0 additions & 1 deletion

@@ -390,7 +390,6 @@ exclude_patterns = [
     "backends/arm/test/ops/**",
     "backends/vulkan/quantizer/**",
     "backends/vulkan/test/**",
-    "backends/cadence/aot/quantizer/**",
     "backends/qualcomm/quantizer/**",
     "examples/qualcomm/**",
     "backends/xnnpack/quantizer/**",

CMakePresets.json

Lines changed: 51 additions & 3 deletions

@@ -7,13 +7,13 @@
       "binaryDir": "${sourceDir}/cmake-out"
     },
     {
-      "name": "macos-arm64",
-      "displayName": "Build everything buildable on macOS arm64",
+      "name": "macos",
+      "displayName": "Build everything buildable on macOS",
       "inherits": ["common"],
       "generator": "Xcode",
       "cacheVariables": {
         "CMAKE_TOOLCHAIN_FILE": "${sourceDir}/third-party/ios-cmake/ios.toolchain.cmake",
-        "EXECUTORCH_BUILD_PRESET_FILE": "${sourceDir}/tools/cmake/preset/macos-arm64.cmake",
+        "EXECUTORCH_BUILD_PRESET_FILE": "${sourceDir}/tools/cmake/preset/macos.cmake",
         "PLATFORM": "MAC_ARM64",
         "DEPLOYMENT_TARGET": "10.15"
       },
@@ -23,6 +23,54 @@
         "rhs": "Darwin"
       }
     },
+    {
+      "name": "ios",
+      "displayName": "Build everything buildable on iOS",
+      "inherits": ["common"],
+      "generator": "Xcode",
+      "cacheVariables": {
+        "CMAKE_TOOLCHAIN_FILE": "${sourceDir}/third-party/ios-cmake/ios.toolchain.cmake",
+        "EXECUTORCH_BUILD_PRESET_FILE": "${sourceDir}/tools/cmake/preset/ios.cmake",
+        "PLATFORM": "OS64",
+        "DEPLOYMENT_TARGET": "17.0"
+      },
+      "condition": {
+        "lhs": "${hostSystemName}",
+        "type": "equals",
+        "rhs": "Darwin"
+      }
+    },
+    {
+      "name": "ios-simulator",
+      "displayName": "Build everything buildable on iOS simulator",
+      "inherits": ["common"],
+      "generator": "Xcode",
+      "cacheVariables": {
+        "CMAKE_TOOLCHAIN_FILE": "${sourceDir}/third-party/ios-cmake/ios.toolchain.cmake",
+        "EXECUTORCH_BUILD_PRESET_FILE": "${sourceDir}/tools/cmake/preset/ios.cmake",
+        "PLATFORM": "SIMULATORARM64",
+        "DEPLOYMENT_TARGET": "17.0"
+      },
+      "condition": {
+        "lhs": "${hostSystemName}",
+        "type": "equals",
+        "rhs": "Darwin"
+      }
+    },
+    {
+      "name": "linux",
+      "displayName": "Build everything buildable on Linux",
+      "inherits": ["common"],
+      "cacheVariables": {
+        "CMAKE_SYSTEM_NAME": "Linux",
+        "EXECUTORCH_BUILD_PRESET_FILE": "${sourceDir}/tools/cmake/preset/linux.cmake"
+      },
+      "condition": {
+        "lhs": "${hostSystemName}",
+        "type": "equals",
+        "rhs": "Linux"
+      }
+    },
     {
       "name": "pybind",
       "displayName": "Build pybindings exported in the wheel",

backends/cadence/aot/compiler.py

Lines changed: 1 addition & 1 deletion

@@ -123,7 +123,7 @@ def prepare_and_convert_pt2(
     assert isinstance(model_gm, torch.fx.GraphModule)
 
     # Prepare
-    prepared_model = prepare_pt2e(model_gm, quantizer)  # pyre-ignore[6]
+    prepared_model = prepare_pt2e(model_gm, quantizer)
 
     # Calibrate
     # If no calibration data is provided, use the inputs

backends/cadence/aot/quantizer/TARGETS

Lines changed: 1 addition & 1 deletion

@@ -9,6 +9,7 @@ python_library(
     ],
     deps = [
         "//caffe2:torch",
+        "//pytorch/ao:torchao",
     ],
 )
 
@@ -34,7 +35,6 @@ python_library(
         ":patterns",
        ":utils",
        "//caffe2:torch",
-       "//executorch/backends/xnnpack/quantizer:xnnpack_quantizer_utils",
        ],
 )

backends/cadence/aot/quantizer/patterns.py

Lines changed: 1 addition & 1 deletion

@@ -15,7 +15,7 @@
 
 from torch import fx
 from torch._ops import OpOverload
-from torch.ao.quantization.quantizer import (
+from torchao.quantization.pt2e.quantizer import (
     DerivedQuantizationSpec,
     SharedQuantizationSpec,
 )

backends/cadence/aot/quantizer/quantizer.py

Lines changed: 8 additions & 7 deletions

@@ -29,19 +29,20 @@
     is_annotated,
     no_outside_users,
 )
-from executorch.backends.xnnpack.quantizer.xnnpack_quantizer_utils import (
+
+from torch import fx
+
+from torchao.quantization.pt2e import HistogramObserver, MinMaxObserver
+from torchao.quantization.pt2e.quantizer import (
+    ComposableQuantizer,
+    DerivedQuantizationSpec,
     OperatorConfig,
     QuantizationAnnotation,
     QuantizationConfig,
     QuantizationSpec,
+    Quantizer,
 )
 
-from torch import fx
-
-from torch.ao.quantization.observer import HistogramObserver, MinMaxObserver
-from torch.ao.quantization.quantizer import DerivedQuantizationSpec, Quantizer
-from torch.ao.quantization.quantizer.composable_quantizer import ComposableQuantizer
-
 
 act_qspec_asym8s = QuantizationSpec(
     dtype=torch.int8,

backends/cadence/aot/quantizer/utils.py

Lines changed: 1 addition & 1 deletion

@@ -14,13 +14,13 @@
 import torch
 from torch import fx
 from torch._ops import OpOverload
-from torch.ao.quantization import ObserverOrFakeQuantize
 
 from torch.fx import GraphModule
 from torch.fx.passes.utils.source_matcher_utils import (
     check_subgraphs_connected,
     SourcePartition,
 )
+from torchao.quantization.pt2e import ObserverOrFakeQuantize
 
 
 def quantize_tensor_multiplier(

backends/cadence/aot/remove_ops.py

Lines changed: 1 addition & 4 deletions

@@ -235,10 +235,7 @@ def call_operator(
         kwargs: dict[str, Argument],
         meta: NodeMetadata,
     ) -> ProxyValue:
-        if op not in {
-            exir_ops.edge.aten.linalg_vector_norm.default,
-            exir_ops.edge.cadence.linalg_vector_norm.default,
-        }:
+        if op is not exir_ops.edge.aten.linalg_vector_norm.default:
             return super().call_operator(op, args, kwargs, meta)
 
         # If the op has three args or less, it can't be a nop

backends/cadence/aot/tests/test_remove_ops_passes.py

Lines changed: 1 addition & 4 deletions

@@ -467,10 +467,7 @@ def forward(self, x: torch.Tensor):
 
         # Expect the linalg_vector_norm op to be removed by the pass
         self.assertEqual(
-            count_node(graph_module, exir_ops.edge.aten.linalg_vector_norm.default)
-            + count_node(
-                graph_module, exir_ops.edge.cadence.linalg_vector_norm.default
-            ),
+            count_node(graph_module, exir_ops.edge.aten.linalg_vector_norm.default),
             0,
         )
 
backends/qualcomm/_passes/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from .expand_broadcast_tensor_shape import ExpandBroadcastTensorShape
2121
from .fixed_linear_keep_dim import FixedLinearKeepDim
2222
from .fold_qdq import FoldQDQ
23+
from .fuse_consecutive_cast import FuseConsecutiveCast
2324
from .fuse_consecutive_transpose import FuseConsecutiveTranspose
2425
from .i64_to_i32 import I64toI32
2526
from .insert_io_qdq import InsertIOQDQ
@@ -54,6 +55,7 @@
5455
ExpandBroadcastTensorShape,
5556
FixedLinearKeepDim,
5657
FoldQDQ,
58+
FuseConsecutiveCast,
5759
FuseConsecutiveTranspose,
5860
I64toI32,
5961
InsertIOQDQ,

backends/qualcomm/_passes/fuse_consecutive_cast.py

Lines changed: 116 additions & 0 deletions

@@ -0,0 +1,116 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import torch
+
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import ExportPass, PassResult
+from executorch.exir.passes import dead_code_elimination_pass
+
+
+class FuseConsecutiveCast(ExportPass):
+    """
+    This pass fuses consecutive cast into one or none to reduce runtime
+    overhead.
+    To simplify the fuse logic, we ensure each cast node's output has at most 1 cast node
+    by cloning cast.
+    Example:
+        Before clone cast:
+            relu -> cast1 ─> cast2
+                      |──────> cast3
+
+        After clone cast:
+            relu ─> cast1 ──────> cast2
+               |───> cast4(new) ─> cast3
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.op_map = {
+            exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
+            exir_ops.edge.aten._to_copy.default,
+        }
+        self.visited = set()
+        self.nodes = []
+
+    def _canonicalize_cast(
+        self, graph_module: torch.fx.GraphModule
+    ) -> torch.fx.GraphModule:
+        # replace all i64 cast nodes with i32 version
+        graph = graph_module.graph
+        for n in graph_module.graph.nodes:
+            if n.target in self.op_map and n.meta["val"].dtype == torch.int64:
+                users = list(n.users)
+                for user in users:
+                    # bypass graph output node to meet original convention
+                    if user.op == "output":
+                        continue
+
+                    with graph.inserting_after(n):
+                        cast_node = graph.create_node(
+                            "call_function",
+                            exir_ops.edge.aten._to_copy.default,
+                            n.args,
+                            kwargs={"dtype": torch.int32},
+                        )
+                        cast_node.meta = n.meta
+                        cast_node.meta["val"] = cast_node.meta["val"].to(torch.int32)
+                        user.replace_input_with(n, cast_node)
+
+        graph.eliminate_dead_code()
+
+        # clone nodes for future fusion
+        for n in graph_module.graph.nodes:
+            # make sure we're handling cast node instead of convert node
+            if n.target in self.op_map and n.kwargs.get("dtype", None) is not None:
+                users = [user for user in list(n.users) if user.target in self.op_map]
+                if len(users) > 1:
+                    for i in range(1, len(users)):
+                        with graph.inserting_after(n):
+                            clone_cast_node = graph.create_node(
+                                "call_function",
+                                exir_ops.edge.aten._to_copy.default,
+                                n.args,
+                                kwargs=n.kwargs,
+                            )
+                        clone_cast_node.meta = n.meta
+                        users[i].replace_input_with(n, clone_cast_node)
+
+    def _traverse(self, node):
+        if node in self.visited or node.target not in self.op_map:
+            return
+
+        self.nodes.append(node)
+        self.visited.add(node)
+        next_users = [n for n in list(node.users) if n.target in self.op_map]
+
+        assert (
+            len(next_users) <= 1
+        ), "Each cast node should have at most 1 cast output node after _clone_cast"
+        if not next_users:
+            return
+        else:
+            self._traverse(list(node.users)[0])
+
+    def _fuse(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
+        for n in graph_module.graph.nodes:
+            self._traverse(n)
+            # TODO: how to handle following scenario (won't happen for quantized graph)
+            # fp -> to(i32) -> to(fp)
+            if len(self.nodes) > 1:
+                input_node, output_node = self.nodes[0], self.nodes[-1]
+                output_node.replace_input_with(output_node.args[0], input_node.args[0])
+
+            # clear current stack
+            self.nodes = []
+
+    def call(self, graph_module: torch.fx.GraphModule):
+        self._canonicalize_cast(graph_module)
+        self._fuse(graph_module)
+        graph_module.recompile()
+        dead_code_elimination_pass(graph_module)
+        return PassResult(graph_module, True)
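
The fusion above relies on consecutive casts being equivalent to a single cast to the final dtype, which holds for the integer cast chains produced by quantized graphs (the TODO in the pass flags fp -> i32 -> fp chains as the unhandled case). A standalone sketch of that equivalence, not taken from the commit:

import torch

# Two consecutive casts on an integer tensor collapse to one cast to the
# final dtype, which is what FuseConsecutiveCast exploits.
x = torch.randint(-128, 127, (4,), dtype=torch.int8)

chained = x.to(torch.int32).to(torch.int64)  # cast1 -> cast2
fused = x.to(torch.int64)                    # single fused cast

assert torch.equal(chained, fused)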

backends/qualcomm/_passes/i64_to_i32.py

Lines changed: 29 additions & 0 deletions

@@ -31,6 +31,14 @@ class I64toI32(ExportPass):
         exir_ops.edge.aten.full.default,
         exir_ops.edge.aten.scalar_tensor.default,
     }
+    # This dict is to ensure that the input of the OPs are int64 due to Pytorch restrictions.
+    # For example, scatter op can only accept args[2], the index, as int64.
+    # Key: Ops to cast input to i64
+    # Value: The args' indices to add casting op
+    I64_IN_OPS = {
+        exir_ops.edge.aten.gather.default: [2],
+        exir_ops.edge.aten.scatter.src: [2],
+    }
     copy_op = exir_ops.edge.aten._to_copy.default
 
     def __init__(
@@ -141,11 +149,32 @@ def _cast_constant_to_int32(self, graph_module: torch.fx.GraphModule):
                 n.replace_all_uses_with(to_dst_node)
                 to_dst_node.args = (n,)
 
+    def _cast_op_args_to_i64(self, graph_module: torch.fx.GraphModule):
+        # input will be cast to i32 during call_operator dtype propogation
+        # insert i64 cast node to prevent PyTorch's operator validation failure
+        for node in graph_module.graph.nodes:
+            if node.target in self.I64_IN_OPS:
+                with graph_module.graph.inserting_before(node):
+                    arg_indices = self.I64_IN_OPS[node.target]
+                    for arg_index in arg_indices:
+                        input_node = node.args[arg_index]
+                        cast_i64_node = graph_module.graph.create_node(
+                            "call_function",
+                            self.copy_op,
+                            (input_node,),
+                            {"dtype": torch.int64},
+                        )
+                        cast_i64_node.meta["val"] = node.meta["val"].to(torch.int64)
+                        args_list = list(node.args)
+                        args_list[arg_index] = cast_i64_node
+                        node.args = tuple(args_list)
+
     def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
         # Record original output dtype to ensure that if user expects int64 as output,
         # convert the output back to int64 if it is casted from int64->int32.
         self._record_original_output_dtype(graph_module)
         self._cast_constant_to_int32(graph_module)
+        self._cast_op_args_to_i64(graph_module)
         graph_module = super().call(graph_module).graph_module
         self._preserve_output_dtype(graph_module)
         graph_module.recompile()
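
The I64_IN_OPS table above exists because PyTorch validates index dtypes for ops like gather and scatter: the index argument (args[2]) must be int64, so the pass re-inserts an int64 cast even though the rest of the graph runs in int32. A small illustration of that restriction, assumed rather than taken from the commit:

import torch

src = torch.arange(6.0).reshape(2, 3)
idx32 = torch.tensor([[0, 2, 1]], dtype=torch.int32)

try:
    torch.gather(src, 1, idx32)  # PyTorch rejects non-int64 index tensors
except RuntimeError as err:
    print("int32 index rejected:", err)

# Casting the index back to int64 mirrors what _cast_op_args_to_i64 does by
# inserting a _to_copy(dtype=torch.int64) node in front of such ops.
print(torch.gather(src, 1, idx32.to(torch.int64)))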
