Skip to content

TOSA specification in Arm partitioner #6851

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Nov 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 12 additions & 71 deletions backends/arm/arm_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,27 +6,27 @@
# pyre-unsafe

import logging
import operator
import os
from typing import Callable, cast, final, List, Optional, Tuple
from typing import Callable, final, List, Optional, Tuple

import torch
from executorch.backends.arm.arm_backend import ArmBackend # usort: skip
from executorch.backends.arm._passes.tag_io_quant_pass import TagIOQuantPass
from executorch.backends.arm.operator_support.tosa_supported_operators import (
TOSASupportedOperators,
)
from executorch.backends.arm.tosa_specification import TosaSpecification
from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.exir.backend.partitioner import (
DelegationSpec,
Partitioner,
PartitionResult,
)
from executorch.exir.backend.utils import tag_constant_data
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.passes import PassManager
from torch.export.exported_program import ExportedProgram
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner

from torch.fx.passes.operator_support import OperatorSupportBase

logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)
TOSA_DBG_VERBOSE = os.environ.get("TOSA_DBG_VERBOSE") == "1"
Expand All @@ -35,71 +35,6 @@
logger.setLevel(logging.INFO)


class TOSASupportedOperators(OperatorSupportBase):
    """Operator-support predicate handed to the capability-based partitioner."""

    def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
        """Decide whether ``node`` may be placed in a delegated partition.

        A node is accepted when it is a ``call_function`` whose target is one
        of the operators enumerated below, the per-operator check in
        ``is_node_supported_custom`` agrees, and no pre-partition pass vetoed
        it through ``node.meta["arm_override_partition"]``.
        """
        is_listed = node.op == "call_function" and node.target in (
            exir_ops.edge.aten.add.Tensor,
            exir_ops.edge.aten.expand_copy.default,
            exir_ops.edge.aten.cat.default,
            exir_ops.edge.aten.bmm.default,
            exir_ops.edge.aten.permute_copy.default,
            exir_ops.edge.aten.hardtanh.default,
            exir_ops.edge.aten.convolution.default,
            exir_ops.edge.aten.div.Tensor,
            exir_ops.edge.aten.exp.default,
            exir_ops.edge.aten.log.default,
            exir_ops.edge.aten.linear.default,
            exir_ops.edge.aten.split_with_sizes_copy.default,
            exir_ops.edge.aten.full.default,
            exir_ops.edge.aten.mul.Tensor,
            exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
            exir_ops.edge.aten.native_layer_norm.default,
            exir_ops.edge.aten.avg_pool2d.default,
            exir_ops.edge.aten.max_pool2d_with_indices.default,
            exir_ops.edge.aten.sigmoid.default,
            exir_ops.edge.aten.mm.default,
            exir_ops.edge.aten.repeat.default,
            exir_ops.edge.aten.reciprocal.default,
            exir_ops.edge.aten.relu.default,
            exir_ops.edge.aten.rsqrt.default,
            exir_ops.edge.aten._softmax.default,
            exir_ops.edge.aten.select_copy.int,
            exir_ops.edge.aten._log_softmax.default,
            exir_ops.edge.aten.slice_copy.Tensor,
            exir_ops.edge.aten.sub.Tensor,
            exir_ops.edge.aten.sum.dim_IntList,
            exir_ops.edge.aten.tanh.default,
            exir_ops.edge.aten.upsample_nearest2d.vec,
            exir_ops.edge.aten.view_copy.default,
            exir_ops.edge.aten.clone.default,
            exir_ops.edge.aten.mean.dim,
            exir_ops.edge.aten.var.correction,
            exir_ops.edge.aten.unsqueeze_copy.default,
            exir_ops.edge.aten.squeeze_copy.dims,
            operator.getitem,
            exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
            exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
        )

        # The per-operator check is evaluated for every node, listed or not.
        accepted = is_listed & self.is_node_supported_custom(node)

        # Override partitioning based on pre partition passes; the flag is
        # removed from the node metadata once it has been applied.
        if "arm_override_partition" in node.meta:
            override = node.meta["arm_override_partition"]
            accepted = accepted & override
            node.meta.pop("arm_override_partition")

        return accepted

    def is_node_supported_custom(self, node: torch.fx.Node) -> bool:
        """Apply operator-specific restrictions; returns True for all other targets.

        ``mean.dim`` is accepted only when its third positional argument
        (keep-dim) is truthy, ``var.correction`` only when its ``keepdim``
        keyword argument is truthy. Both default to ``False`` when absent.
        """
        target = node.target
        if target == exir_ops.edge.aten.mean.dim:
            if len(node.args) > 2:
                return cast(bool, node.args[2])
            return cast(bool, False)
        if target == exir_ops.edge.aten.var.correction:
            return cast(bool, node.kwargs.get("keepdim", False))
        return True


@final
class ArmPartitioner(Partitioner):
def __init__(self, compile_spec: List[CompileSpec]) -> None:
Expand All @@ -111,6 +46,12 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
logger.info("ArmPartitioner::partition")
partition_tags = {}

tosa_spec = TosaSpecification.create_from_compilespecs(
self.delegation_spec.compile_specs
)

logger.info(f"Partitioning for {tosa_spec}")

for spec in self.delegation_spec.compile_specs:
if spec.key == "quantize_io" and spec.value.decode() == "True":
# Exclude IO quantization from the partition
Expand All @@ -123,7 +64,7 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:

capability_partitioner = CapabilityBasedPartitioner(
exported_program.graph_module,
TOSASupportedOperators(),
TOSASupportedOperators(tosa_spec),
allows_single_node_partition=True,
)
partition_list = capability_partitioner.propose_partitions()
Expand Down
8 changes: 8 additions & 0 deletions backends/arm/operator_support/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# Copyright 2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

from . import mean_dim_support, tosa_supported_operators, var_correction_support # noqa
33 changes: 33 additions & 0 deletions backends/arm/operator_support/mean_dim_support.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Copyright 2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

from typing import cast

import torch.fx as fx

from executorch.backends.arm.operator_support.tosa_supported_operators import (
register_tosa_support_check,
SupportedTOSAOperatorCheck,
)
from executorch.backends.arm.tosa_specification import TosaSpecification
from executorch.exir.dialects._ops import ops as exir_ops


@register_tosa_support_check
class MeanDimSupported(SupportedTOSAOperatorCheck):
    """Support check for ``aten.mean.dim`` under the TOSA 0.80.0 BI and MI profiles."""

    targets = [exir_ops.edge.aten.mean.dim]

    tosa_specs = [
        TosaSpecification.create_from_string(spec_string)
        for spec_string in ("TOSA-0.80.0+BI", "TOSA-0.80.0+MI")
    ]

    def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification) -> bool:
        """Return the node's keep-dim flag; the node is supported only when it is truthy.

        The flag is read from the third positional argument and treated as
        ``False`` when fewer than three positional arguments are present.
        """
        assert node.target in self.targets

        # NOTE(review): a keep-dim flag passed as a keyword argument is not
        # inspected here -- confirm it is always supplied positionally.
        if len(node.args) > 2:
            return cast(bool, node.args[2])
        return cast(bool, False)
128 changes: 128 additions & 0 deletions backends/arm/operator_support/tosa_supported_operators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
# Copyright 2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import operator

import torch.fx as fx
from executorch.backends.arm.tosa_specification import TosaSpecification
from executorch.exir.dialects._ops import ops as exir_ops
from torch.fx.passes.operator_support import OperatorSupportBase


class SupportedTOSAOperatorCheck:
    """
    Base class for per-operator checks deciding whether a node can be lowered to TOSA.

    Subclasses populate ``tosa_specs`` and ``targets`` and override
    ``is_node_supported``.
    """

    # Should be populated by subclass implementation
    # TOSA specifications for which the check applies.
    tosa_specs: list[TosaSpecification] = []
    # Operator targets handled by the check.
    # NOTE(review): subclasses visible in this change store edge-dialect
    # operator objects here, not strings -- confirm the annotation.
    targets: list[str] = []

    def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification) -> bool:
        """
        Checks if the fx.Node node is lowerable using the TOSA specification defined by tosa_spec.
        To be implemented by subclasses.

        Raises:
            NotImplementedError: always, in this base class.
        """
        # The previous message named "NodeVisitor", which is not this class.
        raise NotImplementedError("SupportedTOSAOperatorCheck must be extended.")


# Registry of SupportedTOSAOperatorCheck classes: TOSA specification -> {target: check class}
_tosa_spec_dicts: dict[TosaSpecification, dict[str, SupportedTOSAOperatorCheck]] = {
TosaSpecification.create_from_string("TOSA-0.80.0+BI"): {},
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: Add a URL as a comment on how to find the exact version spec

TosaSpecification.create_from_string("TOSA-0.80.0+MI"): {},
}


def register_tosa_support_check(checker):
    """
    Class decorator that registers a SupportedTOSAOperatorCheck subclass.

    The class is stored in the module-level registry under every combination
    of its ``tosa_specs`` and ``targets``, so it can later be consulted when
    deciding whether a torch.fx.Node is lowerable for a given TOSA
    specification. The class itself is returned unchanged.
    """
    for spec in checker.tosa_specs:
        for op_target in checker.targets:
            # A later registration for the same (spec, target) pair replaces
            # the earlier one.
            per_spec_registry = _tosa_spec_dicts[spec]
            per_spec_registry[op_target] = checker
    return checker


def get_registered_tosa_support_checks(
    tosa_spec: TosaSpecification,
) -> dict[str, SupportedTOSAOperatorCheck]:
    """
    Instantiate the support checks registered for ``tosa_spec``.

    Returns:
        A new dict mapping each registered target to a fresh instance of its
        check class.

    Raises:
        RuntimeError: if ``tosa_spec`` has no entry in the registry.
    """
    if tosa_spec not in _tosa_spec_dicts:
        # Name the offending specification and the known ones; a bare
        # RuntimeError gave the caller nothing to act on.
        known_specs = ", ".join(str(spec) for spec in _tosa_spec_dicts)
        raise RuntimeError(
            f"No TOSA support checks are registered for {tosa_spec}; "
            f"known specifications: {known_specs}"
        )

    return {
        target: check_cls()
        for target, check_cls in _tosa_spec_dicts[tosa_spec].items()
    }


class TOSASupportedOperators(OperatorSupportBase):
    """Operator-support predicate parameterised by a TOSA specification."""

    def __init__(self, tosa_spec: TosaSpecification):
        """Remember the TOSA specification that custom checks are looked up for."""
        super().__init__()
        self.tosa_spec = tosa_spec

    def is_node_supported(self, submodules, node: fx.Node) -> bool:
        """Decide whether ``node`` may be placed in a delegated partition.

        A ``call_function`` node whose target is enumerated below is accepted
        directly. Any other node is accepted only if a registered check for
        its target approves it. In both cases a pre-partition pass can veto
        the result through ``node.meta["arm_override_partition"]``.
        """
        verdict = node.op == "call_function" and node.target in (
            exir_ops.edge.aten.add.Tensor,
            exir_ops.edge.aten.expand_copy.default,
            exir_ops.edge.aten.cat.default,
            exir_ops.edge.aten.bmm.default,
            exir_ops.edge.aten.permute_copy.default,
            exir_ops.edge.aten.hardtanh.default,
            exir_ops.edge.aten.convolution.default,
            exir_ops.edge.aten.div.Tensor,
            exir_ops.edge.aten.exp.default,
            exir_ops.edge.aten.log.default,
            exir_ops.edge.aten.linear.default,
            exir_ops.edge.aten.split_with_sizes_copy.default,
            exir_ops.edge.aten.full.default,
            exir_ops.edge.aten.mul.Tensor,
            exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
            exir_ops.edge.aten.native_layer_norm.default,
            exir_ops.edge.aten.avg_pool2d.default,
            exir_ops.edge.aten.max_pool2d_with_indices.default,
            exir_ops.edge.aten.sigmoid.default,
            exir_ops.edge.aten.mm.default,
            exir_ops.edge.aten.repeat.default,
            exir_ops.edge.aten.reciprocal.default,
            exir_ops.edge.aten.relu.default,
            exir_ops.edge.aten.rsqrt.default,
            exir_ops.edge.aten._softmax.default,
            exir_ops.edge.aten.select_copy.int,
            exir_ops.edge.aten._log_softmax.default,
            exir_ops.edge.aten.slice_copy.Tensor,
            exir_ops.edge.aten.sub.Tensor,
            exir_ops.edge.aten.sum.dim_IntList,
            exir_ops.edge.aten.tanh.default,
            exir_ops.edge.aten.upsample_nearest2d.vec,
            exir_ops.edge.aten.view_copy.default,
            exir_ops.edge.aten.clone.default,
            exir_ops.edge.aten.unsqueeze_copy.default,
            exir_ops.edge.aten.squeeze_copy.dims,
            operator.getitem,
            exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
            exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
        )

        # Registered checks are consulted only for nodes the list rejected.
        verdict = verdict or self.is_node_supported_custom(node)

        # Override partitioning based on pre partition passes; the flag is
        # removed from the node metadata once it has been applied.
        if "arm_override_partition" in node.meta:
            override = node.meta["arm_override_partition"]
            verdict = verdict & override
            node.meta.pop("arm_override_partition")

        return verdict

    def is_node_supported_custom(self, node: fx.Node) -> bool:
        """Ask the check registered for ``node.target``, if any.

        Returns ``False`` when no check is registered for the target under
        ``self.tosa_spec``.
        """
        registered = get_registered_tosa_support_checks(self.tosa_spec)
        check = registered.get(node.target)
        if check is None:
            return False
        return check.is_node_supported(node, self.tosa_spec)
33 changes: 33 additions & 0 deletions backends/arm/operator_support/var_correction_support.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Copyright 2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

from typing import cast

import torch.fx as fx

from executorch.backends.arm.operator_support.tosa_supported_operators import (
register_tosa_support_check,
SupportedTOSAOperatorCheck,
)
from executorch.backends.arm.tosa_specification import TosaSpecification
from executorch.exir.dialects._ops import ops as exir_ops


@register_tosa_support_check
class VarCorrectionSupported(SupportedTOSAOperatorCheck):
    """Support check for ``aten.var.correction`` under the TOSA 0.80.0 BI and MI profiles."""

    targets = [exir_ops.edge.aten.var.correction]

    tosa_specs = [
        TosaSpecification.create_from_string(spec_string)
        for spec_string in ("TOSA-0.80.0+BI", "TOSA-0.80.0+MI")
    ]

    def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification) -> bool:
        """Return the node's ``keepdim`` keyword argument; supported only when truthy.

        A missing ``keepdim`` keyword is treated as ``False``.
        """
        assert node.target in self.targets

        keepdim_flag = node.kwargs.get("keepdim", False)
        return cast(bool, keepdim_flag)
Loading