Skip to content

Commit af87283

Browse files
authored
add simplify ops to oss, update fuse and simplify callsites
Differential Revision: D65980636 Pull Request resolved: #6881
1 parent f40daea commit af87283

File tree

5 files changed

+146
-1
lines changed

5 files changed

+146
-1
lines changed

backends/cadence/aot/TARGETS

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ python_library(
3838
deps = [
3939
":passes",
4040
":utils",
41+
":ops_registrations",
4142
"//caffe2:torch",
4243
"//executorch/backends/cadence/aot/quantizer:fusion_pass",
4344
"//executorch/backends/cadence/aot/quantizer:quantizer",
@@ -71,6 +72,8 @@ python_library(
7172
],
7273
deps = [
7374
":utils",
75+
":fuse_ops",
76+
":simplify_ops",
7477
"//caffe2:torch",
7578
"//executorch/exir:pass_base",
7679
"//executorch/exir/dialects:lib",
@@ -163,6 +166,20 @@ python_library(
163166
],
164167
)
165168

169+
python_library(
    name = "simplify_ops",
    srcs = [
        "simplify_ops.py",
    ],
    typing = True,
    deps = [
        # simplify_ops.py imports CadencePassAttribute and
        # register_cadence_pass from pass_utils. The target lives in this
        # package, so the short label is enough (it was previously listed a
        # second time under its fully qualified label).
        ":pass_utils",
        "//executorch/exir:pass_base",
        "//executorch/exir/dialects:lib",
    ],
)
182+
166183
python_unittest(
167184
name = "test_graph_builder",
168185
srcs = [

backends/cadence/aot/compiler.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from pathlib import Path
1111
from typing import Callable, cast, Optional
1212

13+
import executorch.backends.cadence.aot.ops_registrations # noqa
1314
import torch
1415

1516
from executorch.backends.cadence.aot.passes import ReplaceSafeSoftmaxWithSoftmax

backends/cadence/aot/fuse_ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1022,7 +1022,7 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
10221022
return PassResult(graph_module, True)
10231023

10241024

1025-
class FuseOpsInGraph:
1025+
class CadenceFuseOpsInGraph:
10261026
passes = [
10271027
FuseMMWithAdd,
10281028
FuseBatchNormWithConv,

backends/cadence/aot/passes.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,13 @@
1111
import torch
1212
import torch.fx
1313
import torch.utils._pytree as pytree
14+
from executorch.backends.cadence.aot.fuse_ops import CadenceFuseOpsInGraph
1415
from executorch.backends.cadence.aot.pass_utils import (
1516
CadencePassAttribute,
1617
create_cadence_pass_filter,
1718
register_cadence_pass,
1819
)
20+
from executorch.backends.cadence.aot.simplify_ops import CadenceSimplifyOpsInGraph
1921
from executorch.backends.cadence.aot.utils import get_edge_overload_packet
2022
from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform
2123
from executorch.exir.dialects._ops import ops as exir_ops
@@ -346,10 +348,23 @@ def get_passes_in_default_order() -> List[Type[PassType]]:
346348
ReplaceScalarTensorWithFullPass,
347349
RemoveCloneOpsTransformImported,
348350
RemoveNopExpandOpPass,
351+
CadenceFuseOpsInGraph.passes,
349352
ReplaceSqueezeAndUnsqueezeWithViewPass,
350353
ReplacePT2QuantWithCadenceQuantPass,
351354
ReplacePT2DequantWithCadenceDequantPass,
355+
CadenceSimplifyOpsInGraph.passes,
352356
# TODO: add the rest of the passes here.
357+
# InitializePipeline,
358+
# RemoveRedundantOps.passes,
359+
# ReorderOpsInGraph.passes,
360+
# RemoveJarvisNops.passes,
361+
# CadenceFuseOpsInGraph.passes,
362+
# ReplaceOpsInGraph.passes,
363+
# SimplifyOpsInGraph.passes,
364+
# FinalizePipeline,
365+
# FuseFullThenReshapePass,
366+
# FuseTransposeOpPairsPass,
367+
# RemoveNopSliceOrViewOpPass,
353368
]
354369
return pytree.tree_flatten(passes)[0]
355370

backends/cadence/aot/simplify_ops.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
2+
3+
# pyre-unsafe
4+
5+
6+
# This file contains all the functions that simplify args of an op
7+
8+
import sys
9+
from typing import Optional
10+
11+
from executorch.backends.cadence.aot.pass_utils import (
12+
CadencePassAttribute,
13+
register_cadence_pass,
14+
)
15+
16+
from executorch.exir.dialects._ops import ops as exir_ops
17+
from executorch.exir.pass_base import ExportPass, ProxyValue
18+
19+
20+
@register_cadence_pass(CadencePassAttribute(opt_level=0))
class SimplifySliceOpPass(ExportPass):
    """
    Simplify the start and end indices of slice and slice_scatter ops.

    For every slice_copy / slice_scatter node, the dim argument is made
    non-negative and the start/end indices are clamped into the range
    [0, length] of the sliced dimension. If the resulting range is empty,
    slice_copy is replaced by a `full` op producing an empty tensor, and
    slice_scatter is replaced by its (unchanged) input.
    """

    def adjust_slice_range(
        self,
        length: int,
        start: Optional[int] = None,
        end: Optional[int] = None,
        step: int = 1,
    ) -> tuple[int, int]:
        """
        Normalize a slice range against a dimension of size `length`.

        Args:
            length: size of the dimension being sliced.
            start: requested start index; None means 0.
            end: requested end index; None means "until the end".
            step: slice step. Accepted for signature parity with the slice
                op, but it does not influence the result.

        Returns:
            A (start, end) tuple with 0 <= start <= end <= length.
        """
        # Fill in the defaults for an open-ended slice.
        start_val = start if start is not None else 0
        end_val = end if end is not None else sys.maxsize

        # Negative indices count from the end of the dimension.
        if start_val < 0:
            start_val += length
        if end_val < 0:
            end_val += length

        # Clamp start into [0, length], then clamp end into [start, length]
        # so that an inverted range collapses to an empty one.
        start_val = max(0, min(start_val, length))
        end_val = max(start_val, min(end_val, length))

        return (start_val, end_val)

    def call_operator(self, op, args, kwargs, meta):
        """
        Rewrite slice_copy / slice_scatter ops with normalized arguments.

        Any other op is forwarded to the parent class unchanged.
        """
        # We are only interested in slice_copy or slice_scatter ops
        if op not in {
            exir_ops.edge.aten.slice_copy.Tensor,
            exir_ops.edge.aten.slice_scatter.default,
        }:
            return super().call_operator(op, args, kwargs, meta)

        # The slice_scatter op has an extra src argument at index 1, which
        # shifts every subsequent positional argument by one. The bool is
        # used directly as a 0/1 offset below.
        slice_scatter = op == exir_ops.edge.aten.slice_scatter.default

        # Extract the tensor to be sliced, and the slicing dimension
        in_tensor = args[0].to_tensor() if isinstance(args[0], ProxyValue) else args[0]
        dim = args[1 + slice_scatter] if len(args) > 1 + slice_scatter else 0
        # Make dim non-negative
        dim = dim if dim >= 0 else dim + in_tensor.dim()
        length = in_tensor.size(dim)

        # Get the adjusted start and end indices
        start_val = args[2 + slice_scatter] if len(args) > 2 + slice_scatter else None
        end_val = args[3 + slice_scatter] if len(args) > 3 + slice_scatter else None
        step = args[4 + slice_scatter] if len(args) > 4 + slice_scatter else 1
        (start_val, end_val) = self.adjust_slice_range(length, start_val, end_val, step)

        # An empty range selects nothing.
        if start_val >= end_val:
            # Scattering into an empty range leaves the input untouched.
            if slice_scatter:
                return args[0]
            # Slicing an empty range yields a tensor with the input's shape,
            # except that the sliced dimension has size 0. Every other
            # dimension must be kept (including ones that are already 0) so
            # that the rank is preserved and `dim` indexes the right axis.
            empty_shape = list(in_tensor.shape)
            empty_shape[dim] = 0
            return super().call_operator(
                exir_ops.edge.aten.full.default,
                (tuple(empty_shape), 0),
                {"dtype": in_tensor.dtype},
                meta,
            )

        # Re-emit the op with fully explicit, normalized arguments.
        new_args = (
            (args[0],)
            + ((args[1],) if slice_scatter else ())
            + (dim, start_val, end_val, step)
        )
        return super().call_operator(op, new_args, kwargs, meta)
106+
107+
108+
# Container for the passes in this file that simplify the arguments of an op.
# The `passes` list is what callers consume to schedule these passes.
class CadenceSimplifyOpsInGraph:
    passes = [SimplifySliceOpPass]

0 commit comments

Comments
 (0)