Add pass to convert split to many slice #4562

Merged
merged 2 commits into from Aug 19, 2024
1 change: 1 addition & 0 deletions backends/arm/arm_partitioner.py
@@ -43,6 +43,7 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
exir_ops.edge.aten.hardtanh.default,
exir_ops.edge.aten.convolution.default,
exir_ops.edge.aten.div.Tensor,
exir_ops.edge.aten.split_with_sizes_copy.default,
exir_ops.edge.aten.full.default,
exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
exir_ops.edge.aten.avg_pool2d.default,
2 changes: 2 additions & 0 deletions backends/arm/operators/op_slice.py
@@ -40,6 +40,8 @@ def define_node(
shape = input_node.shape
dim = dim.number
end = (shape[dim] + end.number) % shape[dim]
if end == 0:
end = shape[dim]
size = end - start.number
assert size > 0
assert size <= shape[dim]
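A minimal standalone sketch of the end-index normalization in the hunk above (a hypothetical helper; `shape`, `dim`, `start`, and `end` stand in for the values `define_node` reads from the node): slicing up to the full extent of a dimension previously wrapped `end` to 0 via the modulo, and the new guard restores the full extent.

```python
def normalize_slice(shape, dim, start, end):
    # Map a possibly negative end index into (0, shape[dim]].
    end = (shape[dim] + end) % shape[dim]
    if end == 0:
        # x[a:n] on a dim of size n wraps to 0 above; restore the full extent.
        end = shape[dim]
    size = end - start
    assert 0 < size <= shape[dim]
    return start, end, size

print(normalize_slice((10,), 0, 4, 10))  # (4, 10, 6): end == dim size now works
print(normalize_slice((10,), 0, 2, -1))  # (2, 9, 7): negative end still wraps
```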
4 changes: 4 additions & 0 deletions backends/arm/passes/arm_pass_manager.py
@@ -12,6 +12,9 @@
from executorch.backends.arm.passes.convert_expand_copy_to_repeat import (
ConvertExpandCopyToRepeatPass,
)
from executorch.backends.arm.passes.convert_split_to_slice import (
ConvertSplitToSlicePass,
)
from executorch.backends.arm.passes.remove_clone_pass import RemoveClonePass
from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.exir.pass_manager import PassManager
@@ -28,6 +31,7 @@ def transform_to_backend_pipeline(
"""Apply passes before transforming program to backend"""
self.add_pass(RemoveClonePass())
self.add_pass(ConvertExpandCopyToRepeatPass())
self.add_pass(ConvertSplitToSlicePass())
for spec in compile_spec:
if spec.key == "permute_memory_format":
memory_format = spec.value.decode()
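For context, a hedged sketch of applying the new pass on its own, outside the pipeline (assumes an edge-dialect `torch.fx.GraphModule` named `gm`, and that `ExportPass` instances follow the usual callable-pass convention; in the real flow `transform_to_backend_pipeline` runs this for you):

```python
import torch
from executorch.backends.arm.passes.convert_split_to_slice import (
    ConvertSplitToSlicePass,
)

def lower_splits(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    # Pass instances are callable and return a PassResult (assumption based
    # on the ExportPass/PassBase convention).
    result = ConvertSplitToSlicePass()(gm)
    return result.graph_module  # split ops rewritten into slice_copy nodes
```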
70 changes: 70 additions & 0 deletions backends/arm/passes/convert_split_to_slice.py
@@ -0,0 +1,70 @@
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch.fx
from executorch.backends.arm.tosa_mapping import extract_tensor_meta
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult


class ConvertSplitToSlicePass(ExportPass):
"""
Replace a split operation with many slice operations.
"""

split_ops = (
exir_ops.edge.aten.split_with_sizes_copy.default,
exir_ops.edge.aten.split_copy.Tensor,
)
slice = exir_ops.edge.aten.slice_copy.Tensor

def call(self, graph_module: torch.fx.GraphModule):
graph = graph_module.graph
for node in graph.nodes:
if node.target not in self.split_ops:
continue

# Get useful variables
split_node = node
input_node = split_node.all_input_nodes[0]
output_nodes = split_node.users.copy()
_, shape, _ = extract_tensor_meta(input_node.meta)
rank = len(shape)
split_lengths = split_node.args[1]
dim = split_node.args[2] if len(split_node.args) > 2 else 0
dim = (dim + rank) % rank

assert (
sum(split_lengths) == shape[dim]
), "Given split lengths don't sum up to the size of the dimension."

# Convert split argument 'split_lengths' to slice arguments start and end.
starts = [0] * len(split_lengths)
ends = [0] * len(split_lengths)
start = 0
end = 0
for i, split_length in enumerate(split_lengths):
end = start + split_length
starts[i] = start
ends[i] = end
start = end

            # Output nodes are of type getitem.
            # Create one slice node for each output node with matching arguments.
with graph_module.graph.inserting_before(split_node):
for output_node in output_nodes:
index = output_node.args[1]
slice_node = graph.create_node(
"call_function",
self.slice,
(input_node, dim, starts[index], ends[index]),
)
slice_node.meta = split_node.meta.copy()
slice_node.meta["val"] = slice_node.meta["val"][index]
output_node.replace_input_with(split_node, slice_node)
graph.eliminate_dead_code()
graph_module.recompile()
return PassResult(graph_module, True)
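A self-contained demonstration of the rewrite this pass performs, in plain PyTorch: a split is exactly a series of slices whose starts and ends are computed as in the loop above.

```python
import torch

x = torch.rand(10, 15, 10)
split_lengths, dim = [2, 2, 11], 1  # mirrors one of the test cases below

# Compute slice boundaries exactly as the pass does.
starts, ends, cursor = [], [], 0
for length in split_lengths:
    starts.append(cursor)
    cursor += length
    ends.append(cursor)

split_outs = torch.split(x, split_lengths, dim=dim)
slice_outs = [x.narrow(dim, s, e - s) for s, e in zip(starts, ends)]
for a, b in zip(split_outs, slice_outs):
    assert torch.equal(a, b)  # each split output equals its slice
```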
4 changes: 4 additions & 0 deletions backends/arm/quantizer/arm_quantizer_utils.py
@@ -9,6 +9,7 @@
# Utility functions for ArmQuantizer
#

import operator
from typing import Callable, cast, List

import torch
@@ -141,8 +142,11 @@ def is_share_obs_or_fq_op(op: Callable) -> bool:
torch.ops.aten.view_copy.default,
torch.ops.aten.view.default,
torch.ops.aten.slice.Tensor,
torch.ops.aten.split.Tensor,
torch.ops.aten.split_with_sizes.default,
torch.ops.aten.flatten.using_ints,
torch.ops.aten.dropout.default,
operator.getitem,
]


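The rationale for adding `split`, `split_with_sizes`, and `operator.getitem` to this shared-observer list (my reading, not a claim from the patch): these ops only partition or select values, so each output covers a subset of the input's value range and can safely reuse the input's quantization parameters. A quick illustration:

```python
import torch

x = torch.randn(8, 8)
parts = torch.split(x, [3, 5], dim=0)
selected = parts[1]  # what operator.getitem does on the split output

# No split/getitem output can exceed the input's observed range, so one
# set of (scale, zero_point) works for the input and all outputs alike.
assert selected.min() >= x.min() and selected.max() <= x.max()
```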
2 changes: 1 addition & 1 deletion backends/arm/test/ops/test_slice.py
@@ -33,7 +33,7 @@ def forward(self, x: torch.Tensor):
elif x.dim() == 3:
return x[0:7, 0:1, 0:8]
elif x.dim() == 4:
return x[:, 2:5, 3:5, 4:5]
return x[:, 2:5, 3:5, 4:10]

def _test_slice_tosa_MI_pipeline(
self, module: torch.nn.Module, test_data: torch.Tensor
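This test change makes the last slice run to the end of the dimension (assuming the 4-D test input's final dim has size 10), which is exactly the `end == 0` wrap case fixed in `op_slice.py` above:

```python
import torch

x = torch.rand(5, 5, 5, 10)  # hypothetical shape with last dim of size 10
# end index 10 == dim size: previously wrapped to 0, now slices 6 elements.
assert x[:, 2:5, 3:5, 4:10].shape[-1] == 6
```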
139 changes: 139 additions & 0 deletions backends/arm/test/ops/test_split.py
@@ -0,0 +1,139 @@
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from executorch.backends.arm.quantizer.arm_quantizer import (
ArmQuantizer,
get_symmetric_quantization_config,
)
from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.arm_tester import ArmTester
from executorch.backends.xnnpack.test.tester.tester import Quantize
from parameterized import parameterized

test_data_t = tuple[torch.Tensor, int | list[int], int]


class TestSimpleSplit(unittest.TestCase):
class Split(torch.nn.Module):

test_data: list[tuple[test_data_t]] = [
((torch.rand(10), 2, 0),),
((torch.rand(10, 10), 3, 1),),
((torch.rand(10, 10), 4, -1),),
((torch.rand(10, 15, 10), [2, 2, 11], 1),),
((torch.rand(4, 4, 4, 4), 2, 0),),
((torch.rand(4, 4, 4, 4), [1, 1, 1, 1], -2),),
]

def forward(
self, x: torch.Tensor, split_size_or_sections: int | list[int], dim: int
):
return x.split(split_size=split_size_or_sections, dim=dim)

class SplitWithSizes(torch.nn.Module):
def forward(self, x: torch.Tensor, split_sizes: list[int], dim: int):
return x.split_with_sizes(split_sizes=split_sizes, dim=dim)

class SplitSingleOut(torch.nn.Module):
def forward(
self, x: torch.Tensor, split_size_or_sections: int | list[int], dim: int
):
return x.split(split_size=split_size_or_sections, dim=dim)[1]

class SplitTwoOut(torch.nn.Module):
def forward(
self, x: torch.Tensor, split_size_or_sections: int | list[int], dim: int
):
return x.split(split_size=split_size_or_sections, dim=dim)[1:3]

def _test_split_tosa_MI_pipeline(
self, module: torch.nn.Module, test_data: test_data_t
):
(
ArmTester(
module,
example_inputs=test_data,
compile_spec=common.get_tosa_compile_spec(),
)
.export()
.to_edge()
.check(
[
"executorch_exir_dialects_edge__ops_aten_split_with_sizes_copy_default"
]
)
.partition()
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
.to_executorch()
.run_method_and_compare_outputs(inputs=test_data)
)

def _test_split_tosa_BI_pipeline(
self, module: torch.nn.Module, test_data: test_data_t
):

quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config())
(
ArmTester(
module,
example_inputs=test_data,
compile_spec=common.get_tosa_compile_spec(),
)
.quantize(Quantize(quantizer, get_symmetric_quantization_config()))
.export()
.to_edge()
.partition()
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
.to_executorch()
.run_method_and_compare_outputs(inputs=test_data, qtol=1)
)

def _test_split_u55_BI_pipeline(
self, module: torch.nn.Module, test_data: test_data_t
):
quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config())
(
ArmTester(
module,
example_inputs=test_data,
compile_spec=common.get_u55_compile_spec(),
)
.quantize(Quantize(quantizer, get_symmetric_quantization_config()))
.export()
.check(["torch.ops.aten.split.Tensor"])
.to_edge()
.partition()
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
.to_executorch()
)

@parameterized.expand(Split.test_data)
def test_split_tosa_MI(self, test_data: test_data_t):
self._test_split_tosa_MI_pipeline(self.Split(), test_data)

@parameterized.expand([Split.test_data[3], Split.test_data[5]])
def test_split_with_sizes_tosa_MI(self, test_data: test_data_t):
assert isinstance(test_data[1], list)
self._test_split_tosa_MI_pipeline(self.SplitWithSizes(), test_data)

@parameterized.expand(Split.test_data)
def test_split_n_out_tosa_MI(self, test_data: test_data_t):
self._test_split_tosa_MI_pipeline(self.SplitSingleOut(), test_data)
self._test_split_tosa_MI_pipeline(self.SplitTwoOut(), test_data)

@parameterized.expand(Split.test_data)
def test_split_tosa_BI(self, test_data: test_data_t):
self._test_split_tosa_BI_pipeline(self.Split(), test_data)

    # Fails during Vela compilation when trying to use a Tuple as a NamedTuple.
    # Could be a Vela issue; wait for Regor.
@parameterized.expand(Split.test_data)
@unittest.expectedFailure
def test_split_u55_BI(self, test_data: test_data_t):
self._test_split_u55_BI_pipeline(self.Split(), test_data)
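As a sanity note on the parameterized data: an integer `split_size` that does not evenly divide the dimension yields a shorter final chunk, which `Split.test_data[1]` exercises:

```python
import torch

outs = torch.rand(10, 10).split(split_size=3, dim=1)
assert [o.shape[1] for o in outs] == [3, 3, 3, 1]  # ragged last chunk
```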
35 changes: 19 additions & 16 deletions backends/arm/test/runner_utils.py
@@ -202,7 +202,7 @@ def set_timeout(self, timeout: int):
def run_corstone300(
self,
inputs: Tuple[torch.Tensor],
) -> torch.Tensor:
) -> list[torch.Tensor]:

assert (
self._has_init_run
@@ -268,12 +268,12 @@ def run_corstone300(

tosa_ref_output = np.fromfile(out_path_with_suffix, dtype=np.float32)
tosa_ref_output = torch.from_numpy(tosa_ref_output).reshape(inputs[0].shape)
return tosa_ref_output
return [tosa_ref_output]

def run_tosa_ref_model(
self,
inputs: Tuple[torch.Tensor],
) -> torch.Tensor:
) -> list[torch.Tensor]:
"""
        Run TOSA reference model using the tosa_reference_model program.

@@ -369,23 +369,26 @@ def run_tosa_ref_model(
# Load desc.json, just to get the name of the output file above
with open(desc_file_path) as f:
desc_json = json.load(f)
ofm_file_npy = os.path.join(self.intermediate_path, desc_json["ofm_file"][0])

# Load the output file (OFM) and return it as a numpy array
tosa_ref_output = np.load(ofm_file_npy)
tosa_ref_outputs = []
for ofm_file in desc_json["ofm_file"]:
ofm_file_npy = os.path.join(self.intermediate_path, ofm_file)

if self.is_quantized:
# Need to dequant back to FP32 for comparison with torch output
quant_param = self.qp_output
assert (
quant_param is not None
), "There are no quantization parameters, check output parameters"
tosa_ref_output = (tosa_ref_output - quant_param.zp) * quant_param.scale
# Load the output file (OFM) and return it as a numpy array
tosa_ref_output = np.load(ofm_file_npy)

# tosa_output is a numpy array, convert to torch tensor for comparison
tosa_ref_output = torch.from_numpy(tosa_ref_output.astype("float32"))
if self.is_quantized:
# Need to dequant back to FP32 for comparison with torch output
quant_param = self.qp_output
assert (
quant_param is not None
), "There are no quantization parameters, check output parameters"
tosa_ref_output = (tosa_ref_output - quant_param.zp) * quant_param.scale

return tosa_ref_output
            # tosa_ref_output is a numpy array; convert it to a torch tensor for comparison
tosa_ref_outputs.append(torch.from_numpy(tosa_ref_output.astype("float32")))

return tosa_ref_outputs


def prep_data_for_save(
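A hedged sketch of the per-output dequantization applied in the loop above (assuming affine parameters `zp` and `scale` as carried by `qp_output`):

```python
import numpy as np

def dequantize(q: np.ndarray, zp: int, scale: float) -> np.ndarray:
    # Map the reference model's integer output back to FP32 for comparison.
    return (q.astype(np.float32) - zp) * scale

print(dequantize(np.array([0, 128, 255]), zp=128, scale=0.02))
# [-2.56  0.    2.54]
```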
2 changes: 1 addition & 1 deletion backends/arm/test/tester/arm_tester.py
@@ -260,7 +260,7 @@ def run_method_and_compare_outputs(
print(f"Run {run_iteration} with input shapes: {input_shapes}")

reference_output = reference_stage.run_artifact(reference_input)
test_output = (test_stage.run_artifact(test_input),)
test_output = tuple(test_stage.run_artifact(test_input))
if is_nhwc:
test_output = self.transpose_data_format(test_output, "NCHW")

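Why the small-looking change above matters: `run_artifact` now returns a list of outputs, so wrapping it in a one-element tuple would nest the whole list one level too deep, whereas `tuple(...)` yields one entry per output. Illustrative:

```python
outputs = ["out0", "out1"]  # stand-in for a list of output tensors

assert (outputs,) == (["out0", "out1"],)   # old: single nested element
assert tuple(outputs) == ("out0", "out1")  # new: one element per output
```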