Skip to content

Add simplify ops to OSS, update fuse and simplify callsites #6881

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions backends/cadence/aot/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ python_library(
deps = [
":passes",
":utils",
":ops_registrations",
"//caffe2:torch",
"//executorch/backends/cadence/aot/quantizer:fusion_pass",
"//executorch/backends/cadence/aot/quantizer:quantizer",
Expand Down Expand Up @@ -71,6 +72,8 @@ python_library(
],
deps = [
":utils",
":fuse_ops",
":simplify_ops",
"//caffe2:torch",
"//executorch/exir:pass_base",
"//executorch/exir/dialects:lib",
Expand Down Expand Up @@ -163,6 +166,20 @@ python_library(
],
)

# Library target for the op-argument simplification passes (simplify_ops.py).
python_library(
    name = "simplify_ops",
    srcs = [
        "simplify_ops.py",
    ],
    typing = True,
    deps = [
        # ":pass_utils" is the same target as
        # "//executorch/backends/cadence/aot:pass_utils"; list it only once,
        # using the package-relative form used elsewhere in this file.
        ":pass_utils",
        "//executorch/exir:pass_base",
        "//executorch/exir/dialects:lib",
    ],
)

python_unittest(
name = "test_graph_builder",
srcs = [
Expand Down
1 change: 1 addition & 0 deletions backends/cadence/aot/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from pathlib import Path
from typing import Callable, cast, Optional

import executorch.backends.cadence.aot.ops_registrations # noqa
import torch

from executorch.backends.cadence.aot.passes import ReplaceSafeSoftmaxWithSoftmax
Expand Down
2 changes: 1 addition & 1 deletion backends/cadence/aot/fuse_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1022,7 +1022,7 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
return PassResult(graph_module, True)


class FuseOpsInGraph:
class CadenceFuseOpsInGraph:
passes = [
FuseMMWithAdd,
FuseBatchNormWithConv,
Expand Down
15 changes: 15 additions & 0 deletions backends/cadence/aot/passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,13 @@
import torch
import torch.fx
import torch.utils._pytree as pytree
from executorch.backends.cadence.aot.fuse_ops import CadenceFuseOpsInGraph
from executorch.backends.cadence.aot.pass_utils import (
CadencePassAttribute,
create_cadence_pass_filter,
register_cadence_pass,
)
from executorch.backends.cadence.aot.simplify_ops import CadenceSimplifyOpsInGraph
from executorch.backends.cadence.aot.utils import get_edge_overload_packet
from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform
from executorch.exir.dialects._ops import ops as exir_ops
Expand Down Expand Up @@ -346,10 +348,23 @@ def get_passes_in_default_order() -> List[Type[PassType]]:
ReplaceScalarTensorWithFullPass,
RemoveCloneOpsTransformImported,
RemoveNopExpandOpPass,
CadenceFuseOpsInGraph.passes,
ReplaceSqueezeAndUnsqueezeWithViewPass,
ReplacePT2QuantWithCadenceQuantPass,
ReplacePT2DequantWithCadenceDequantPass,
CadenceSimplifyOpsInGraph.passes,
# TODO: add the rest of the passes here.
# InitializePipeline,
# RemoveRedundantOps.passes,
# ReorderOpsInGraph.passes,
# RemoveJarvisNops.passes,
# CadenceFuseOpsInGraph.passes,
# ReplaceOpsInGraph.passes,
# SimplifyOpsInGraph.passes,
# FinalizePipeline,
# FuseFullThenReshapePass,
# FuseTransposeOpPairsPass,
# RemoveNopSliceOrViewOpPass,
]
return pytree.tree_flatten(passes)[0]

Expand Down
112 changes: 112 additions & 0 deletions backends/cadence/aot/simplify_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

# pyre-unsafe


# This file contains all the functions that simplify args of an op

import sys
from typing import Optional

from executorch.backends.cadence.aot.pass_utils import (
CadencePassAttribute,
register_cadence_pass,
)

from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, ProxyValue


@register_cadence_pass(CadencePassAttribute(opt_level=0))
class SimplifySliceOpPass(ExportPass):
    """
    Simplify the start and end indices of slice and slice_scatter ops.

    The start/end arguments are rewritten as explicit, non-negative, in-range
    indices. When the resulting range is empty, a slice_copy is replaced by a
    `full` op producing a tensor whose sliced dimension has size 0, and a
    slice_scatter is replaced by its (unmodified) input.
    """

    def adjust_slice_range(
        self,
        length: int,
        start: Optional[int] = None,
        end: Optional[int] = None,
        step: int = 1,
    ) -> tuple[int, int]:
        """
        Normalize a (start, end) slice range against a dimension of size `length`.

        Args:
            length: Size of the tensor along the sliced dimension.
            start: Requested start index; None means 0. May be negative.
            end: Requested end index; None means "to the end". May be negative.
            step: Slice step. Accepted for signature parity with the slice
                op, but not used when computing the range.

        Returns:
            A (start, end) tuple with 0 <= start <= end <= length.
        """
        # Get the start index and end index
        start_val = start if start is not None else 0
        end_val = end if end is not None else sys.maxsize  # "no upper bound"

        # If start_val and end_val are negative, add length to them
        if start_val < 0:
            start_val += length
        if end_val < 0:
            end_val += length

        # If the start val is still outside the tensor_size along the sliced
        # dimension, adjust it accordingly.
        if start_val < 0:
            start_val = 0
        elif start_val >= length:
            start_val = length

        # If the end val is still outside the tensor_size along the sliced
        # dimension, adjust it accordingly. An end before start is collapsed
        # onto start, which yields an empty range.
        if end_val < start_val:
            end_val = start_val
        elif end_val >= length:
            end_val = length

        # Return the adjusted start and end indices
        return (start_val, end_val)

    def call_operator(self, op, args, kwargs, meta):
        """
        Rewrite slice_copy / slice_scatter ops with normalized indices.

        Any other op is forwarded unchanged to the parent implementation.
        """
        # We are only interested in slice_copy or slice_scatter ops
        if op not in {
            exir_ops.edge.aten.slice_copy.Tensor,
            exir_ops.edge.aten.slice_scatter.default,
        }:
            return super().call_operator(op, args, kwargs, meta)

        # Check if it is a slice_scatter op or not. The slice_scatter op has
        # an extra src argument at index 1, which shifts every later
        # positional argument by one.
        slice_scatter = op == exir_ops.edge.aten.slice_scatter.default
        offset = 1 if slice_scatter else 0

        # Parse the arguments
        # Extract the tensor to be sliced, and the slicing dimension
        in_tensor = args[0].to_tensor() if isinstance(args[0], ProxyValue) else args[0]
        dim = args[1 + offset] if len(args) > 1 + offset else 0
        # Make dim non-negative
        dim = dim if dim >= 0 else dim + in_tensor.dim()
        length = in_tensor.size(dim)

        # Get the adjusted start and end indices
        start_val = args[2 + offset] if len(args) > 2 + offset else None
        end_val = args[3 + offset] if len(args) > 3 + offset else None
        step = args[4 + offset] if len(args) > 4 + offset else 1
        (start_val, end_val) = self.adjust_slice_range(length, start_val, end_val, step)

        # If the start_val is geq end_val, then we can return an empty tensor
        # for slice op, or input for slice_scatter op.
        if start_val >= end_val:
            if slice_scatter:
                return args[0]
            # Keep every dimension of the input, including ones that are
            # already zero-sized, so the result has the same rank as the
            # input. Dropping zero-sized dimensions here would shift or
            # invalidate `dim`. Only the sliced dimension collapses to 0.
            empty_shape = list(in_tensor.shape)
            empty_shape[dim] = 0
            return super().call_operator(
                exir_ops.edge.aten.full.default,
                (tuple(empty_shape), 0),
                {"dtype": in_tensor.dtype},
                meta,
            )

        # Create new args
        new_args = (
            (args[0],)
            + ((args[1],) if slice_scatter else ())
            + (dim, start_val, end_val, step)
        )
        return super().call_operator(op, new_args, kwargs, meta)


class CadenceSimplifyOpsInGraph:
    """Groups all the passes that simplify the args of an op."""

    # Pass classes belonging to this group.
    passes = [SimplifySliceOpPass]
Loading