Skip to content

Arm backend: Refactor misc tests for TOSA V1.0 #10851

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions backends/arm/test/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,16 +47,15 @@ def maybe_get_tosa_collate_path() -> str | None:
tosa_test_base = os.environ.get("TOSA_TESTCASES_BASE_PATH")
if tosa_test_base:
current_test = os.environ.get("PYTEST_CURRENT_TEST")
#'backends/arm/test/ops/test_mean_dim.py::TestMeanDim::test_meandim_tosa_BI_0_zeros (call)'
test_class = current_test.split("::")[1] # type: ignore[union-attr]
test_name = current_test.split("::")[-1].split(" ")[0] # type: ignore[union-attr]
# '::test_collate_tosa_BI_tests[randn] (call)'
test_name = current_test.split("::")[1].split(" ")[0] # type: ignore[union-attr]
if "BI" in test_name:
tosa_test_base = os.path.join(tosa_test_base, "tosa-bi")
elif "MI" in test_name:
tosa_test_base = os.path.join(tosa_test_base, "tosa-mi")
else:
tosa_test_base = os.path.join(tosa_test_base, "other")
return os.path.join(tosa_test_base, test_class, test_name)
return os.path.join(tosa_test_base, test_name)

return None

Expand Down
219 changes: 86 additions & 133 deletions backends/arm/test/misc/test_custom_partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,25 @@
# LICENSE file in the root directory of this source tree.

import logging
from typing import Tuple

import torch
from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.arm_tester import ArmTester
from executorch.backends.arm.tosa_partitioner import TOSAPartitioner
from executorch.backends.arm.test.tester.test_pipeline import TosaPipelineMI
from executorch.exir.backend.operator_support import (
DontPartition,
DontPartitionModule,
DontPartitionName,
)
from executorch.exir.dialects._ops import ops as exir_ops

input_t1 = Tuple[torch.Tensor, torch.Tensor] # Input x, y


class CustomPartitioning(torch.nn.Module):
inputs = (torch.randn(10, 4, 5), torch.randn(10, 4, 5))
inputs = {
"randn": (torch.randn(10, 4, 5), torch.randn(10, 4, 5)),
}

def forward(self, x: torch.Tensor, y: torch.Tensor):
z = x + y
Expand All @@ -27,7 +31,9 @@ def forward(self, x: torch.Tensor, y: torch.Tensor):


class NestedModule(torch.nn.Module):
inputs = (torch.randn(10, 4, 5), torch.randn(10, 4, 5))
inputs = {
"randn": (torch.randn(10, 4, 5), torch.randn(10, 4, 5)),
}

def __init__(self):
super().__init__()
Expand All @@ -39,192 +45,139 @@ def forward(self, x: torch.Tensor, y: torch.Tensor):
return self.nested(a, b)


def test_single_reject(caplog):
@common.parametrize("test_data", CustomPartitioning.inputs)
def test_single_reject(caplog, test_data: input_t1):
caplog.set_level(logging.INFO)

module = CustomPartitioning()
inputs = module.inputs
compile_spec = common.get_tosa_compile_spec("TOSA-0.80+MI")
pipeline = TosaPipelineMI[input_t1](module, test_data, [], exir_op=[])
check = DontPartition(exir_ops.edge.aten.sigmoid.default)
partitioner = TOSAPartitioner(compile_spec, additional_checks=[check])
(
ArmTester(
module,
example_inputs=inputs,
compile_spec=compile_spec,
)
.export()
.to_edge_transform_and_lower(partitioners=[partitioner])
.check(["executorch_exir_dialects_edge__ops_aten_sigmoid_default"])
.check_count({"torch.ops.higher_order.executorch_call_delegate": 2})
.to_executorch()
.run_method_and_compare_outputs(inputs=inputs)
pipeline.change_args("to_edge_transform_and_lower", additional_checks=[check])
pipeline.change_args(
"check_count.exir", {"torch.ops.higher_order.executorch_call_delegate": 2}
)
pipeline.change_args(
"check_count.exir",
{"executorch_exir_dialects_edge__ops_aten_sigmoid_default": 1},
)
pipeline.run()
assert check.has_rejected_node()
assert "Rejected by DontPartition" in caplog.text


def test_multiple_reject():
@common.parametrize("test_data", CustomPartitioning.inputs)
def test_multiple_reject(test_data: input_t1):
module = CustomPartitioning()
inputs = module.inputs
compile_spec = common.get_tosa_compile_spec("TOSA-0.80+MI")
pipeline = TosaPipelineMI[input_t1](module, test_data, [], exir_op=[])
check = DontPartition(
exir_ops.edge.aten.sigmoid.default, exir_ops.edge.aten.mul.Tensor
)
partitioner = TOSAPartitioner(compile_spec, additional_checks=[check])
(
ArmTester(
module,
example_inputs=inputs,
compile_spec=compile_spec,
)
.export()
.to_edge_transform_and_lower(partitioners=[partitioner])
.check(["executorch_exir_dialects_edge__ops_aten_sigmoid_default"])
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
.to_executorch()
.run_method_and_compare_outputs(inputs=inputs)
pipeline.change_args("to_edge_transform_and_lower", additional_checks=[check])
pipeline.change_args(
"check_count.exir", {"torch.ops.higher_order.executorch_call_delegate": 2}
)
pipeline.change_args(
"check_count.exir",
{"executorch_exir_dialects_edge__ops_aten_sigmoid_default": 1},
)
pipeline.run()
assert check.has_rejected_node()


def test_torch_op_reject(caplog):
@common.parametrize("test_data", CustomPartitioning.inputs)
def test_torch_op_reject(caplog, test_data: input_t1):
caplog.set_level(logging.INFO)

module = CustomPartitioning()
inputs = module.inputs
compile_spec = common.get_tosa_compile_spec("TOSA-0.80+MI")
check = DontPartition(torch.ops.aten.sigmoid.default)
partitioner = TOSAPartitioner(compile_spec, additional_checks=[check])
(
ArmTester(
module,
example_inputs=inputs,
compile_spec=compile_spec,
)
.export()
.to_edge_transform_and_lower(partitioners=[partitioner])
.check(["executorch_exir_dialects_edge__ops_aten_sigmoid_default"])
.check_count({"torch.ops.higher_order.executorch_call_delegate": 2})
.to_executorch()
.run_method_and_compare_outputs(inputs=inputs)
pipeline = TosaPipelineMI[input_t1](module, test_data, [], exir_op=[])
pipeline.change_args("to_edge_transform_and_lower", additional_checks=[check])
pipeline.change_args(
"check_count.exir", {"torch.ops.higher_order.executorch_call_delegate": 2}
)
pipeline.change_args(
"check_count.exir",
{"executorch_exir_dialects_edge__ops_aten_sigmoid_default": 1},
)
pipeline.run()
assert check.has_rejected_node()
assert "Rejected by DontPartition" in caplog.text


def test_string_op_reject():
@common.parametrize("test_data", CustomPartitioning.inputs)
def test_string_op_reject(test_data: input_t1):
module = CustomPartitioning()
inputs = module.inputs
compile_spec = common.get_tosa_compile_spec("TOSA-0.80+MI")
check = DontPartition("aten.sigmoid.default")
partitioner = TOSAPartitioner(compile_spec, additional_checks=[check])
(
ArmTester(
module,
example_inputs=inputs,
compile_spec=compile_spec,
)
.export()
.to_edge_transform_and_lower(partitioners=[partitioner])
.check(["executorch_exir_dialects_edge__ops_aten_sigmoid_default"])
.check_count({"torch.ops.higher_order.executorch_call_delegate": 2})
.to_executorch()
.run_method_and_compare_outputs(inputs=inputs)
pipeline = TosaPipelineMI[input_t1](module, test_data, [], exir_op=[])
pipeline.change_args("to_edge_transform_and_lower", additional_checks=[check])
pipeline.change_args(
"check_count.exir", {"torch.ops.higher_order.executorch_call_delegate": 2}
)

pipeline.change_args(
"check_count.exir",
{"executorch_exir_dialects_edge__ops_aten_sigmoid_default": 1},
)
pipeline.run()
assert check.has_rejected_node()


def test_name_reject(caplog):
@common.parametrize("test_data", CustomPartitioning.inputs)
def test_name_reject(caplog, test_data: input_t1):
caplog.set_level(logging.INFO)

module = CustomPartitioning()
inputs = module.inputs
compile_spec = common.get_tosa_compile_spec("TOSA-0.80+MI")
check = DontPartitionName("mul", "sigmoid", exact=False)
partitioner = TOSAPartitioner(compile_spec, additional_checks=[check])
(
ArmTester(
module,
example_inputs=inputs,
compile_spec=compile_spec,
)
.export()
.to_edge_transform_and_lower(partitioners=[partitioner])
.check(["executorch_exir_dialects_edge__ops_aten_sigmoid_default"])
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
.to_executorch()
.run_method_and_compare_outputs(inputs=inputs)
pipeline = TosaPipelineMI[input_t1](module, test_data, [], exir_op=[])
pipeline.change_args("to_edge_transform_and_lower", additional_checks=[check])
pipeline.change_args(
"check_count.exir",
{"executorch_exir_dialects_edge__ops_aten_sigmoid_default": 1},
)
pipeline.run()
assert check.has_rejected_node()
assert "Rejected by DontPartitionName" in caplog.text


def test_module_reject():
@common.parametrize("test_data", CustomPartitioning.inputs)
def test_module_reject(test_data: input_t1):
module = NestedModule()
inputs = module.inputs
compile_spec = common.get_tosa_compile_spec("TOSA-0.80+MI")
check = DontPartitionModule(module_name="CustomPartitioning")
partitioner = TOSAPartitioner(compile_spec, additional_checks=[check])
(
ArmTester(
module,
example_inputs=inputs,
compile_spec=compile_spec,
)
.export()
.to_edge_transform_and_lower(partitioners=[partitioner])
.check(["executorch_exir_dialects_edge__ops_aten_sigmoid_default"])
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
.to_executorch()
.run_method_and_compare_outputs(inputs=inputs)
pipeline = TosaPipelineMI[input_t1](module, test_data, [], exir_op=[])
pipeline.change_args("to_edge_transform_and_lower", additional_checks=[check])
pipeline.change_args(
"check_count.exir",
{"executorch_exir_dialects_edge__ops_aten_sigmoid_default": 1},
)
pipeline.run()
assert check.has_rejected_node()


def test_inexact_module_reject(caplog):
@common.parametrize("test_data", CustomPartitioning.inputs)
def test_inexact_module_reject(caplog, test_data: input_t1):
caplog.set_level(logging.INFO)

module = NestedModule()
inputs = module.inputs
compile_spec = common.get_tosa_compile_spec("TOSA-0.80+MI")
check = DontPartitionModule(module_name="Custom", exact=False)
partitioner = TOSAPartitioner(compile_spec, additional_checks=[check])
(
ArmTester(
module,
example_inputs=inputs,
compile_spec=compile_spec,
)
.export()
.to_edge_transform_and_lower(partitioners=[partitioner])
.check(["executorch_exir_dialects_edge__ops_aten_sigmoid_default"])
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
.to_executorch()
.run_method_and_compare_outputs(inputs=inputs)
pipeline = TosaPipelineMI[input_t1](module, test_data, [], exir_op=[])
pipeline.change_args("to_edge_transform_and_lower", additional_checks=[check])
pipeline.change_args(
"check_count.exir",
{"executorch_exir_dialects_edge__ops_aten_sigmoid_default": 1},
)
pipeline.run()
assert check.has_rejected_node()
assert "Rejected by DontPartitionModule" in caplog.text


def test_module_instance_reject():
@common.parametrize("test_data", CustomPartitioning.inputs)
def test_module_instance_reject(test_data: input_t1):
module = NestedModule()
inputs = module.inputs
compile_spec = common.get_tosa_compile_spec("TOSA-0.80+MI")
check = DontPartitionModule(instance_name="nested")
partitioner = TOSAPartitioner(compile_spec, additional_checks=[check])
(
ArmTester(
module,
example_inputs=inputs,
compile_spec=compile_spec,
)
.export()
.to_edge_transform_and_lower(partitioners=[partitioner])
.check(["executorch_exir_dialects_edge__ops_aten_sigmoid_default"])
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
.to_executorch()
.run_method_and_compare_outputs(inputs=inputs)
pipeline = TosaPipelineMI[input_t1](module, test_data, [], exir_op=[])
pipeline.change_args("to_edge_transform_and_lower", additional_checks=[check])
pipeline.change_args(
"check_count.exir",
{"executorch_exir_dialects_edge__ops_aten_sigmoid_default": 1},
)
pipeline.run()
assert check.has_rejected_node()
Loading
Loading