Insert transposes around view_copy ops #6435


Merged (2 commits) on Oct 25, 2024
backends/arm/_passes/annotate_channels_last_dim_order_pass.py (79 additions, 26 deletions)
@@ -12,8 +12,9 @@
 from executorch.backends.arm._passes.arm_pass_utils import (
     create_node,
     get_first_fake_tensor,
+    insert_q_dq_pair,
 )
-from executorch.backends.arm.tosa_quant_utils import dq_op
+from executorch.backends.arm.tosa_quant_utils import dq_op, q_op
 from executorch.backends.arm.tosa_utils import is_consumer_node_depthwise_conv2d
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass, PassResult
@@ -79,37 +80,89 @@ def is_weight_node_for_depthwise_conv2d(self, node: torch.fx.Node):

         return False
 
+    def insert_input_transpose(self, node, input_node, graph_module):
+        quantize = input_node.target == dq_op
+        q_params = input_node.args[1:] if quantize else None
+        with graph_module.graph.inserting_before(node):
+            permute_node = create_node(
+                graph_module.graph,
+                torch.ops.passthrough_to_tosa._transpose,
+                args=(input_node, list(self.NHWC_inverse_order)),
+                quantize=quantize,
+                q_params=q_params,
+            )
+            node.replace_input_with(input_node, permute_node)
+
+            permute_node.meta["tosa_dim_order"] = tuple(
+                range(len(input_node.meta["val"].size()))
+            )
+
+    def insert_output_transpose(self, node, graph_module):
+        with graph_module.graph.inserting_after(node):
+            permute_node = create_node(
+                graph_module.graph,
+                torch.ops.passthrough_to_tosa._transpose,
+                args=(node, list(self.NHWC_order)),
+            )
+            permute_node.meta["tosa_dim_order"] = self.NHWC_order
+            node.meta["tosa_dim_order"] = (0, 1, 2, 3)
+            users = [user for user in node.users if user != permute_node]
+            for user in users:
+                user.replace_input_with(node, permute_node)
+
+            quantize = node.args[0].target == q_op
+            if quantize:
+                q_params = node.args[0].args[1:]
+                insert_q_dq_pair(graph_module.graph, node, q_params)
+
     def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule):
+        """
+        Reshape operations are not equivalent in NCHW and NHWC.
+        To get around this, transposes need to be added if the previous or new shape
+        fulfil the following condition:
+            C > 1 and (H or W > 1)
+
+        This is relevant for the following operations;
+            squeeze: 4D -> 3D
+            unsqueeze: <4D -> 4D
+            view: <4D -> 4D
+            view: 4D -> <4D
+            view: 4D -> 4D
+        """
+
+        def transpose_condition(shape):
+            if len(shape) != 4:
+                return False
+            C = shape[1]
+            H = shape[2]
+            W = shape[3]
+            return C > 1 and (H > 1 or W > 1)
+
         for node in graph_module.graph.nodes:
             if node.op != "call_function":
                 continue
             if node.target == exir_ops.edge.aten.squeeze_copy.dims:
                 input_node = node.args[0]
-                if input_node.meta["val"].dim() == 4:
-                    with graph_module.graph.inserting_before(node):
-                        permute_node = create_node(
-                            graph_module.graph,
-                            torch.ops.passthrough_to_tosa._transpose,
-                            args=(input_node, list(self.NHWC_inverse_order)),
-                        )
-                        permute_node.meta["tosa_dim_order"] = tuple(
-                            range(len(input_node.meta["val"].size()))
-                        )
-                        node.replace_input_with(input_node, permute_node)
-
-            if node.target == exir_ops.edge.aten.unsqueeze_copy.default:
-                if node.meta["val"].dim() == 4:
-                    with graph_module.graph.inserting_after(node):
-                        permute_node = create_node(
-                            graph_module.graph,
-                            torch.ops.passthrough_to_tosa._transpose,
-                            args=(node, list(self.NHWC_order)),
-                        )
-                        permute_node.meta["tosa_dim_order"] = self.NHWC_order
-                        node.meta["tosa_dim_order"] = (0, 1, 2, 3)
-                        users = [user for user in node.users if user != permute_node]
-                        for user in users:
-                            user.replace_input_with(node, permute_node)
+                input_shape = input_node.meta["val"].shape
+                if transpose_condition(input_shape):
+                    self.insert_input_transpose(node, input_node, graph_module)
+
+            elif node.target == exir_ops.edge.aten.unsqueeze_copy.default:
+                output_shape = node.meta["val"].shape
+                if transpose_condition(output_shape):
+                    self.insert_output_transpose(node, graph_module)
+
+            elif node.target == exir_ops.edge.aten.view_copy.default:
+                input_node = node.args[0]
+
+                old_shape = input_node.meta["val"].shape
+                new_shape = node.meta["val"].shape
+
+                if transpose_condition(old_shape):
+                    self.insert_input_transpose(node, input_node, graph_module)
+
+                if transpose_condition(new_shape):
+                    self.insert_output_transpose(node, graph_module)
 
     def call(self, graph_module: torch.fx.GraphModule):
         for node in graph_module.graph.nodes:
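
Why the condition matters: a view reinterprets the underlying element order, and that order differs between NCHW and NHWC exactly when C > 1 and (H > 1 or W > 1). A minimal, self-contained PyTorch sketch (illustration only, not part of the PR) of the failure mode the inserted transposes prevent:

import torch

# Shape (1, 2, 2, 2): C = 2 > 1 and H = W = 2 > 1, so transpose_condition holds.
x = torch.arange(8).reshape(1, 2, 2, 2)      # reference data in NCHW

nchw_view = x.view(1, 2, 4)                  # the result the original graph expects

nhwc = x.permute(0, 2, 3, 1).contiguous()    # same values in NHWC element order
wrong = nhwc.view(1, 2, 4)                   # a naive view applied to NHWC data

print(torch.equal(nchw_view, wrong))         # False: element orders differ

# Transposing back to NCHW before the view, which is what insert_input_transpose
# does via passthrough_to_tosa._transpose, restores the expected result.
fixed = nhwc.permute(0, 3, 1, 2).contiguous().view(1, 2, 4)
print(torch.equal(nchw_view, fixed))         # True
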
backends/arm/test/misc/test_debug_feats.py (1 addition, 2 deletions)
@@ -107,9 +107,8 @@ def test_numerical_diff_prints(self):
             ArmTester(
                 model,
                 example_inputs=model.get_inputs(),
-                compile_spec=common.get_tosa_compile_spec(),
+                compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=False),
             )
-            .quantize()
             .export()
             .to_edge()
             .partition()
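
A note on the quantization handling in insert_input_transpose and insert_output_transpose above: a transpose is pure data movement, so it commutes with quantize/dequantize, which is why the pass can wrap the inserted _transpose nodes in a q/dq pair (via insert_q_dq_pair) without changing numerics. A small sketch of that property in plain PyTorch (illustration only, not executorch internals):

import torch

x = torch.randn(1, 3, 4, 4)
q = torch.quantize_per_tensor(x, scale=0.05, zero_point=0, dtype=torch.qint8)

a = torch.dequantize(q).permute(0, 2, 3, 1)  # dequantize, then transpose
b = torch.dequantize(q.permute(0, 2, 3, 1))  # transpose the quantized tensor first

print(torch.equal(a, b))                     # True: the order does not matter
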
backends/arm/test/ops/test_layer_norm.py (4 additions, 3 deletions)
@@ -74,7 +74,7 @@ def _test_layernorm_tosa_MI_pipeline(
             ArmTester(
                 model=module,
                 example_inputs=test_data,
-                compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=False),
+                compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=True),
             )
             .export()
             .check(["torch.ops.aten.layer_norm.default"])
@@ -93,7 +93,7 @@ def _test_layernorm_tosa_BI_pipeline(
             ArmTester(
                 model=module,
                 example_inputs=test_data,
-                compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=False),
+                compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=True),
             )
             .quantize()
             .check_not(["torch.ops.aten.layer_norm.default"])
@@ -148,7 +148,8 @@ def test_layer_norm_tosa_BI(
             self.LayerNorm(*model_params), (test_data,)
         )
 
-    @parameterized.expand(test_data_suite)
+    # Skip tests that require transposes.
+    @parameterized.expand(test_data_suite[:-2])
     def test_layer_norm_u55_BI(
         self,
         test_name: str,
backends/arm/test/ops/test_squeeze.py (1 addition, 2 deletions)
@@ -95,7 +95,6 @@ def _test_squeeze_tosa_BI_pipeline(
             .check_count({export_target: 1})
             .to_edge()
             .partition()
-            .dump_artifact()
             .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
             .to_executorch()
             .run_method_and_compare_outputs(inputs=test_data, qtol=1)
@@ -156,7 +155,7 @@ def test_squeeze_u85_BI(
         test_tensor: torch.Tensor,
     ):
         self._test_squeeze_ethosu_BI_pipeline(
-            common.get_u85_compile_spec(permute_memory_to_nhwc=False),
+            common.get_u85_compile_spec(permute_memory_to_nhwc=True),
             self.Squeeze(),
             (test_tensor,),
             "torch.ops.aten.squeeze.default",
backends/arm/test/ops/test_view.py (41 additions, 23 deletions)
@@ -25,16 +25,33 @@
 from parameterized import parameterized
 
 
-class TestSimpleView(unittest.TestCase):
+class TestView(unittest.TestCase):
     """Tests the view operation."""
 
     class View(torch.nn.Module):
 
-        sizes = [10, 15, 50, 100]
-        test_parameters = [(torch.ones(n),) for n in sizes]
-
-        def forward(self, x: torch.Tensor):
-            return x.view(-1, 5)
+        needs_transpose_tests = [
+            (torch.rand(100), (1, -1, 5, 2)),
+            (torch.rand(10, 2, 1, 5), (1, -1, 5, 2)),
+            (torch.rand(1, 2, 1, 9), (3, 1, 3, 2)),
+            (torch.rand(2, 1, 1, 9), (3, 2, 3, 1)),
+            (torch.rand(2, 50, 2, 1), (1, 200)),
+            (torch.rand(2, 5, 2, 3), (1, 15, 4)),
+        ]
+
+        no_transpose_tests = [
+            (torch.rand(2, 1, 1, 9), (3, 1, 3, 2)),
+            (torch.rand(5, 10, 1, 1), (25, 2, 1, 1)),
+            (torch.rand(10, 2), (1, 1, 5, 4)),
+            (torch.rand(10, 10), (5, 1, 5, 4)),
+            (torch.rand(1, 1, 1, 10), (1, 1, 10, 1)),
+            (torch.rand(1, 1, 5, 10), (1, 1, 50, 1)),
+            (torch.rand(5, 10, 1, 1), (1, 25, 2)),
+            (torch.rand(2, 50, 1, 1), (1, 100)),
+        ]
+
+        def forward(self, x: torch.Tensor, new_shape):
+            return x.view(new_shape)
 
     def _test_view_tosa_MI_pipeline(
         self, module: torch.nn.Module, test_data: torch.Tensor
@@ -82,11 +99,7 @@ def _test_view_ethos_BI_pipeline(
     ):
         quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config())
         (
-            ArmTester(
-                module,
-                example_inputs=test_data,
-                compile_spec=common.get_u55_compile_spec(),
-            )
+            ArmTester(module, example_inputs=test_data, compile_spec=compile_spec)
             .quantize(Quantize(quantizer, get_symmetric_quantization_config()))
             .export()
             .check_count({"torch.ops.aten.view.default": 1})
@@ -110,18 +123,23 @@ def _test_view_u85_BI_pipeline(
             common.get_u85_compile_spec(), module, test_data
         )
 
-    @parameterized.expand(View.test_parameters)
-    def test_view_tosa_MI(self, test_tensor: torch.Tensor):
-        self._test_view_tosa_MI_pipeline(self.View(), (test_tensor,))
+    @parameterized.expand(View.needs_transpose_tests + View.no_transpose_tests)
+    def test_view_tosa_MI(self, test_tensor: torch.Tensor, new_shape):
+        self._test_view_tosa_MI_pipeline(self.View(), (test_tensor, new_shape))
+
+    @parameterized.expand(View.needs_transpose_tests + View.no_transpose_tests)
+    def test_view_tosa_BI(self, test_tensor: torch.Tensor, new_shape):
+        self._test_view_tosa_BI_pipeline(self.View(), (test_tensor, new_shape))
 
-    @parameterized.expand(View.test_parameters)
-    def test_view_tosa_BI(self, test_tensor: torch.Tensor):
-        self._test_view_tosa_BI_pipeline(self.View(), (test_tensor,))
+    @parameterized.expand(View.no_transpose_tests)
+    def test_view_u55_BI(self, test_tensor: torch.Tensor, new_shape):
+        self._test_view_u55_BI_pipeline(self.View(), (test_tensor, new_shape))
 
-    @parameterized.expand(View.test_parameters)
-    def test_view_u55_BI(self, test_tensor: torch.Tensor):
-        self._test_view_u55_BI_pipeline(self.View(), (test_tensor,))
+    @parameterized.expand(View.needs_transpose_tests)
+    @unittest.expectedFailure
+    def test_view_transpose_u55_BI(self, test_tensor: torch.Tensor, new_shape):
+        self._test_view_u55_BI_pipeline(self.View(), (test_tensor, new_shape))
 
-    @parameterized.expand(View.test_parameters)
-    def test_view_u85_BI(self, test_tensor: torch.Tensor):
-        self._test_view_u85_BI_pipeline(self.View(), (test_tensor,))
+    @parameterized.expand(View.needs_transpose_tests + View.no_transpose_tests)
+    def test_view_u85_BI(self, test_tensor: torch.Tensor, new_shape):
+        self._test_view_u85_BI_pipeline(self.View(), (test_tensor, new_shape))
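
As a sanity check (illustration only, not part of the test file), applying the pass's transpose_condition, which for view ops is checked on both the old and the new shape, to a few of the shapes above shows how the test data splits into needs_transpose_tests and no_transpose_tests:

def transpose_condition(shape):
    # Same condition as in annotate_channels_last_dim_order_pass.py.
    if len(shape) != 4:
        return False
    C, H, W = shape[1], shape[2], shape[3]
    return C > 1 and (H > 1 or W > 1)

print(transpose_condition((10, 2, 1, 5)))  # True:  C=2, W=5 -> transpose needed
print(transpose_condition((3, 1, 3, 2)))   # False: C=1 -> layout-agnostic
print(transpose_condition((10, 2)))        # False: not 4D, condition never triggers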