migrate passes calls to pass utils #6721

Merged: 1 commit, Nov 12, 2024
9 changes: 5 additions & 4 deletions backends/cadence/aot/TARGETS
@@ -48,10 +48,11 @@ python_library(
],
)


python_library(
name = "passes",
name = "pass_utils",
srcs = [
"_passes.py",
"pass_utils.py",
Contributor review comment: @tarun292 do we need the _ here? I forgot the exact reason to add it

],
deps = [
":utils",
@@ -64,9 +65,9 @@ python_library(
)

python_library(
name = "pass_utils",
name = "passes",
srcs = [
"pass_utils.py",
"passes.py",
],
deps = [
":utils",
34 changes: 10 additions & 24 deletions backends/cadence/aot/compiler.py
@@ -8,41 +8,33 @@

import logging
from pathlib import Path
from typing import Optional
from typing import Callable, cast, Optional

import torch

from executorch.backends.cadence.aot._passes import (
InitializePipeline,
RemoveNopExpandOpPass,
RemoveZeroSizedCatArgsPass,
ReplaceLogicalNotBooleanWhereWithWherePass,
ReplacePT2DequantWithCadenceDequantPass,
ReplacePT2QuantWithCadenceQuantPass,
ReplaceSafeSoftmaxWithSoftmax,
ReplaceScalarTensorWithFullPass,
ReplaceSqueezeAndUnsqueezeWithViewPass,
)
from executorch.backends.cadence.aot.passes import ReplaceSafeSoftmaxWithSoftmax
from executorch.backends.cadence.aot.quantizer.fusion_pass import QuantFusion
from executorch.backends.cadence.aot.quantizer.quantizer import CadenceQuantizer
from executorch.backends.cadence.aot.utils import model_gm_has_SDPA, model_is_quantized
from executorch.backends.transforms.decompose_sdpa import (
DecomposeScaledDotProductAttention,
)
from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform
from executorch.devtools import generate_etrecord
from executorch.exir import (
EdgeCompileConfig,
EdgeProgramManager,
ExecutorchProgramManager,
to_edge,
)
from executorch.exir.pass_base import PassResult
from torch.ao.quantization.pt2e.export_utils import model_is_exported
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

from torch.export import export
from torch.export.exported_program import ExportedProgram

from .passes import get_cadence_passes

from .utils import print_ops_info


@@ -209,22 +201,16 @@ def export_to_cadence_edge_executorch(
inputs: tuple[object, ...],
dump_graphs: bool = False,
output_dir: Optional[str] = None,
opt_level: int = 1,
) -> ExecutorchProgramManager:
edge_prog_manager = export_to_edge(model, inputs)
cadence_passes = get_cadence_passes(opt_level)

# Run a couple required passes for quant/dequant ops
cadence_prog_manager = edge_prog_manager.transform(
[
InitializePipeline(),
RemoveZeroSizedCatArgsPass(),
ReplaceLogicalNotBooleanWhereWithWherePass(),
ReplaceScalarTensorWithFullPass(),
RemoveCloneOpsTransform(),
RemoveNopExpandOpPass(),
ReplaceSqueezeAndUnsqueezeWithViewPass(),
ReplacePT2QuantWithCadenceQuantPass(),
ReplacePT2DequantWithCadenceDequantPass(),
]
cast(
list[Callable[[torch.fx.GraphModule], Optional[PassResult]]], cadence_passes
)
)

# Print some information to terminal
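
Note (not part of the diff): a minimal sketch of how the reworked call site above fits together once the hard-coded pass list is replaced by `get_cadence_passes(opt_level)`. The wrapper name `lower_with_cadence_passes` is hypothetical, and the sketch assumes the `get_cadence_passes` / `EdgeProgramManager.transform` behavior shown in this file.

```python
# Minimal sketch (not part of this PR): the reworked call-site shape after this
# change. The wrapper name lower_with_cadence_passes is hypothetical.
from typing import Callable, Optional, cast

import torch

from executorch.backends.cadence.aot.passes import get_cadence_passes
from executorch.exir import EdgeProgramManager
from executorch.exir.pass_base import PassResult


def lower_with_cadence_passes(
    edge_prog_manager: EdgeProgramManager, opt_level: int = 1
) -> EdgeProgramManager:
    # The pass list is now assembled centrally and filtered by opt_level,
    # instead of being hard-coded at every call site.
    cadence_passes = get_cadence_passes(opt_level)
    return edge_prog_manager.transform(
        cast(
            list[Callable[[torch.fx.GraphModule], Optional[PassResult]]],
            cadence_passes,
        )
    )
```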
109 changes: 93 additions & 16 deletions backends/cadence/aot/_passes.py → backends/cadence/aot/passes.py
@@ -6,21 +6,74 @@

# pyre-strict

from typing import Any, cast, Dict, Sequence, Tuple
from typing import Any, cast, Dict, List, Optional, Sequence, Tuple, Type

import torch
import torch.fx
import torch.utils._pytree as pytree
from executorch.backends.cadence.aot.pass_utils import (
CadencePassAttribute,
create_cadence_pass_filter,
register_cadence_pass,
)
from executorch.backends.cadence.aot.utils import get_edge_overload_packet
from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, NodeMetadata, PassResult, ProxyValue
from executorch.exir.pass_manager import PassManager, PassType
from executorch.exir.passes import dead_code_elimination_pass
from executorch.exir.passes.scalar_to_tensor_pass import ScalarToTensorPass
from executorch.exir.passes.spec_prop_pass import SpecPropPass
from torch._subclasses import FakeTensor
from torch.utils._pytree import tree_map_only


@register_cadence_pass(CadencePassAttribute(opt_level=0))
class RemoveCloneOpsTransformImported(ExportPass):
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
finalize_passes: List[PassType] = [
RemoveCloneOpsTransform(),
]
result = PassManager(passes=finalize_passes)(graph_module)
dead_code_elimination_pass(result.graph_module)
return result


@register_cadence_pass(CadencePassAttribute(opt_level=0))
class InitializePipeline(ExportPass):
"""
Initialize the Jarvis pipeline. This should invariably be the first pass to
run.
"""

def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
dead_code_elimination_pass(graph_module)
result = SpecPropPass()(graph_module)
assert result is not None
return result


@register_cadence_pass(CadencePassAttribute(opt_level=0))
class FinalizePipeline(ExportPass):
"""
The final cleanup pass after running the Jarvis pipeline.
"""

def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
finalize_passes: List[PassType] = [
ScalarToTensorPass(),
SpecPropPass(),
]
result = PassManager(passes=finalize_passes)(graph_module)
dead_code_elimination_pass(result.graph_module)
return result


# Similar to what's done in executorch/exir/pass_base.py
Argument = Any # pyre-ignore


@register_cadence_pass(CadencePassAttribute(opt_level=0))
class ReplacePT2QuantWithCadenceQuantPass(ExportPass):
"""
Replace the pt2 quantization ops with custom cadence quantization ops.
@@ -44,6 +97,7 @@ def call_operator(
)


@register_cadence_pass(CadencePassAttribute(opt_level=0))
class ReplacePT2DequantWithCadenceDequantPass(ExportPass):
"""
Replace the pt2 dequantization ops with custom cadence dequantization ops.
@@ -67,6 +121,7 @@ def call_operator(
)


@register_cadence_pass(CadencePassAttribute(opt_level=0))
class ReplaceScalarTensorWithFullPass(ExportPass):
"""
aten.scalar_tensor can be replaced by aten.full with a shape of [1].
@@ -96,6 +151,7 @@ def call_operator(
)


@register_cadence_pass(CadencePassAttribute(opt_level=0))
class ReplaceSqueezeAndUnsqueezeWithViewPass(ExportPass):
"""
When the shape is static, replace squeeze_copy and unsqueeze_copy ops with
@@ -131,7 +187,8 @@ def call_operator(
)


class RemoveZeroSizedCatArgsPass(ExportPass):
@register_cadence_pass(CadencePassAttribute(opt_level=0))
class RemoveZeroSizedCatArgsPass(ExportPass): # is this the latest?
def call_operator(
self,
op, # pyre-ignore
@@ -176,6 +233,7 @@ def call_operator(
return super().call_operator(op, args, kwargs, meta)


@register_cadence_pass(CadencePassAttribute(opt_level=0))
class RemoveNopExpandOpPass(ExportPass):
"""
For an expand op, if the operator shape matches the expand shape, then the
@@ -205,6 +263,7 @@ def call_operator(
return super().call_operator(op, args, kwargs, meta)


@register_cadence_pass(CadencePassAttribute(opt_level=0))
class ReplaceLogicalNotBooleanWhereWithWherePass(ExportPass):
"""
A where op with a logical_not and a boolean tensor can be replaced
@@ -255,20 +314,8 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
return result


class InitializePipeline(ExportPass):
"""
Initialize the Jarvis pipeline. This should invariably be the first pass to
run.
"""

def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
dead_code_elimination_pass(graph_module)
result = SpecPropPass()(graph_module)
assert result is not None
return result


class ReplaceSafeSoftmaxWithSoftmax(ExportPass):
@register_cadence_pass(CadencePassAttribute(opt_level=0))
class ReplaceSafeSoftmaxWithSoftmax(ExportPass): # keep
"""
Replace _safe_softmax with _softmax
"""
@@ -292,3 +339,33 @@ def call_operator(
kwargs,
meta,
)


def get_passes_in_default_order() -> List[Type[PassType]]:
passes = [
InitializePipeline,
RemoveZeroSizedCatArgsPass,
ReplaceLogicalNotBooleanWhereWithWherePass,
ReplaceScalarTensorWithFullPass,
RemoveCloneOpsTransformImported,
RemoveNopExpandOpPass,
ReplaceSqueezeAndUnsqueezeWithViewPass,
ReplacePT2QuantWithCadenceQuantPass,
ReplacePT2DequantWithCadenceDequantPass,
# TODO: add the rest of the passes here.
]
return pytree.tree_flatten(passes)[0]


def get_cadence_passes(
opt_level: int,
) -> List[Optional[PassResult]]:
passes = get_passes_in_default_order()
pass_filter = create_cadence_pass_filter(opt_level)
filtered_passes = [
# pyre-fixme[20]: Call `torch.fx.passes.infra.pass_base.PassBase.__call__` expects argument `graph_module`.
filtered_pass()
# pyre-fixme[6]: In call `filter.__new__` ... got `List[Type[typing.Callable[[GraphModule], Optional[PassResult]]]]`.
for filtered_pass in list(filter(pass_filter, passes))
]
return filtered_passes
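
Note (not part of the diff): a hedged sketch of how the registration and filtering flow used above might be exercised by a pass author. `CustomFuseOpsPass` is an invented example, and the assumption that `create_cadence_pass_filter(opt_level)` keeps passes registered at or below that level is inferred from this diff rather than confirmed.

```python
# Hedged sketch only: exercises the pass_utils API used in this PR.
# CustomFuseOpsPass is hypothetical; the filter semantics commented below are
# an assumption inferred from the diff, not a confirmed contract.
import torch
import torch.fx

from executorch.backends.cadence.aot.pass_utils import (
    CadencePassAttribute,
    create_cadence_pass_filter,
    register_cadence_pass,
)
from executorch.exir.pass_base import ExportPass, PassResult


@register_cadence_pass(CadencePassAttribute(opt_level=2))
class CustomFuseOpsPass(ExportPass):
    """Hypothetical pass that should only be selected at opt_level >= 2."""

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        # No-op body; a real pass would rewrite the graph here.
        return PassResult(graph_module, False)


candidates = [CustomFuseOpsPass]
kept_at_o0 = list(filter(create_cadence_pass_filter(0), candidates))  # expected: []
kept_at_o2 = list(filter(create_cadence_pass_filter(2), candidates))  # expected: [CustomFuseOpsPass]
```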