Skip to content

Commit 4f47cc9

Browse files
Erik-Lundellfreddan80
authored andcommitted
Handle new dims for repeat in pass
In addition to moving logic from node visitor, this also fixes repeating a rank 3 tensor to make a rank 3 tensor. Signed-off-by: Erik Lundell <[email protected]> Change-Id: I7090159bce47b6aa4d6613bbeb2d681d5cfcb193
1 parent 44bcfc3 commit 4f47cc9

File tree

5 files changed

+151
-31
lines changed

5 files changed

+151
-31
lines changed

backends/arm/_passes/arm_pass_manager.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,9 @@
4141
ScalarsToAttributePass,
4242
)
4343
from executorch.backends.arm._passes.size_adjust_conv2d_pass import SizeAdjustConv2DPass
44+
from executorch.backends.arm._passes.unsqueeze_before_repeat_pass import (
45+
UnsqueezeBeforeRepeatPass,
46+
)
4447
from executorch.backends.arm._passes.unsqueeze_scalar_placeholders_pass import (
4548
UnsqueezeScalarPlaceholdersPass,
4649
)
@@ -66,6 +69,7 @@ def transform_to_backend_pipeline(
6669
self.add_pass(RemoveClonePass())
6770
self.add_pass(ConvertExpandCopyToRepeatPass())
6871
self.add_pass(DecomposeLayerNormPass())
72+
self.add_pass(UnsqueezeBeforeRepeatPass())
6973
self.add_pass(DecomposeVarPass())
7074
self.add_pass(ConvertMeanDimToAveragePool())
7175
self.add_pass(DecomposeMeanDimPass())
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# Copyright 2024 Arm Limited and/or its affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
# pyre-unsafe
7+
import torch
8+
import torch.fx
9+
from executorch.backends.arm._passes.arm_pass_utils import (
10+
create_node,
11+
get_first_fake_tensor,
12+
)
13+
from executorch.exir.dialects._ops import ops as exir_ops
14+
from executorch.exir.pass_base import ExportPass, PassResult
15+
16+
17+
class UnsqueezeBeforeRepeatPass(ExportPass):
18+
"""
19+
A TOSA TILE op only supports rank(in) == rank(out).
20+
To support Pytorch's repeat which can also add dimensions,
21+
we add an explicit view op before which adds the new dimensions.
22+
New dimensions are appendend at the front, see
23+
https://pytorch.org/docs/stable/generated/torch.Tensor.expand.html
24+
25+
Original:
26+
repeat(multiples)
27+
After pass:
28+
view(shape = [1]*num_new_dims + old_shape)
29+
repeat(multiples)
30+
"""
31+
32+
def call(self, graph_module: torch.fx.GraphModule):
33+
modified_graph = False
34+
for node in graph_module.graph.nodes:
35+
if node.op != "call_function":
36+
continue
37+
if node.target != exir_ops.edge.aten.repeat.default:
38+
continue
39+
40+
old_shape = list(get_first_fake_tensor(node.all_input_nodes[0]).shape)
41+
old_rank = len(old_shape)
42+
multiples = node.args[1]
43+
new_rank = len(multiples)
44+
if old_rank == new_rank:
45+
continue
46+
47+
num_new_dims = new_rank - old_rank
48+
new_shape = [1] * num_new_dims + old_shape
49+
50+
with graph_module.graph.inserting_before(node):
51+
view_node = create_node(
52+
graph_module.graph,
53+
exir_ops.edge.aten.view_copy.default,
54+
(node.all_input_nodes[0], new_shape),
55+
)
56+
node.replace_input_with(node.all_input_nodes[0], view_node)
57+
modified_graph = True
58+
59+
if modified_graph:
60+
graph_module.recompile()
61+
graph_module = super().call(graph_module).graph_module
62+
return PassResult(graph_module, modified_graph)

backends/arm/operators/op_repeat.py

Lines changed: 1 addition & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -32,37 +32,8 @@ def define_node(
3232
is_quant_node: bool,
3333
) -> None:
3434

35-
item_name = inputs[0].name
36-
shape = inputs[0].shape
37-
rank = len(shape)
3835
multiples = inputs[1].special
39-
new_rank = len(multiples)
40-
41-
assert new_rank >= rank
42-
43-
# TILE only supports rank(in) == rank(out). To add more dims, we need a reshape first.
44-
if new_rank > rank:
45-
# Add length 1 dimensions to shape to match multiples
46-
num_new_dims = new_rank - rank
47-
expanded_shape = tuple(
48-
1 if i < num_new_dims else shape[i - num_new_dims]
49-
for i in range(new_rank)
50-
)
51-
expanded_shape = tosa_shape(expanded_shape, output.dim_order)
52-
dtype = (
53-
ts.dtype_str_to_val("INT8")
54-
if is_quant_node
55-
else ts.dtype_str_to_val("FP32")
56-
)
57-
58-
rescale_out = tosa_graph.addIntermediate(expanded_shape, dtype)
59-
rescale_attr = ts.TosaSerializerAttribute()
60-
rescale_attr.ReshapeAttribute(expanded_shape)
61-
tosa_graph.addOperator(
62-
TosaOp.Op().RESHAPE, [item_name], [rescale_out.name], rescale_attr
63-
)
64-
item_name = rescale_out.name
6536

6637
attr = ts.TosaSerializerAttribute()
6738
attr.TileAttribute(tosa_shape(multiples, output.dim_order))
68-
tosa_graph.addOperator(TosaOp.Op().TILE, [item_name], [output.name], attr)
39+
tosa_graph.addOperator(TosaOp.Op().TILE, [inputs[0].name], [output.name], attr)

backends/arm/test/ops/test_repeat.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ class Repeat(torch.nn.Module):
3737
(torch.randn(3), (2, 2)),
3838
(torch.randn(3), (1, 2, 3)),
3939
(torch.randn((3, 3)), (2, 2, 2)),
40+
(torch.randn((3, 3, 3)), (2, 1, 2, 4)),
4041
]
4142

4243
def forward(self, x: torch.Tensor, multiples: Sequence):
@@ -106,12 +107,20 @@ def test_repeat_tosa_MI(self, test_input, multiples):
106107
def test_repeat_tosa_BI(self, test_input, multiples):
107108
self._test_repeat_tosa_BI_pipeline(self.Repeat(), (test_input, multiples))
108109

109-
@parameterized.expand(Repeat.test_parameters)
110+
@parameterized.expand(Repeat.test_parameters[:-1])
110111
def test_repeat_u55_BI(self, test_input, multiples):
111112
self._test_repeat_ethosu_pipeline(
112113
common.get_u55_compile_spec(), self.Repeat(), (test_input, multiples)
113114
)
114115

116+
# Final test requires transpose which is not supported on u55.
117+
@parameterized.expand(Repeat.test_parameters[-1:])
118+
@unittest.expectedFailure
119+
def test_repeat_u55_BI_xfails(self, test_input, multiples):
120+
self._test_repeat_ethosu_pipeline(
121+
common.get_u55_compile_spec(), self.Repeat(), (test_input, multiples)
122+
)
123+
115124
@parameterized.expand(Repeat.test_parameters)
116125
def test_repeat_u85_BI(self, test_input, multiples):
117126
self._test_repeat_ethosu_pipeline(
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
# Copyright 2024 Arm Limited and/or its affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
import unittest
7+
8+
import torch
9+
from executorch.backends.arm._passes.unsqueeze_before_repeat_pass import (
10+
UnsqueezeBeforeRepeatPass,
11+
)
12+
from executorch.backends.arm.test import common
13+
from executorch.backends.arm.test.tester.arm_tester import ArmTester
14+
from executorch.backends.xnnpack.test.tester.tester import RunPasses
15+
16+
17+
class Repeat(torch.nn.Module):
18+
"""
19+
Basic repeat model.
20+
"""
21+
22+
def forward(self, x: torch.Tensor):
23+
return x.repeat(2, 2, 2, 2)
24+
25+
26+
class TestUnsqueezeBeforeRepeatPass(unittest.TestCase):
27+
def test_tosa_MI_insert_view(self):
28+
"""
29+
When rank(input) != number of repeated dimensions (=4 in Repeat module),
30+
insert view.
31+
"""
32+
module = Repeat()
33+
inputs = (torch.rand((2, 3, 4)),)
34+
test_pass_stage = RunPasses([UnsqueezeBeforeRepeatPass])
35+
(
36+
(
37+
ArmTester(
38+
module,
39+
example_inputs=inputs,
40+
compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"),
41+
)
42+
.export()
43+
.to_edge()
44+
.check(["aten_repeat_default"])
45+
.check_not(["aten_view_copy_default"])
46+
.run_passes(test_pass_stage)
47+
.check(["aten_repeat_default", "aten_view_copy_default"])
48+
)
49+
)
50+
51+
def test_tosa_MI_dont_insert_view(self):
52+
"""
53+
When rank(input) == number of repeated dimensions (=4 in Repeat module),
54+
DON'T insert view.
55+
"""
56+
module = Repeat()
57+
inputs = (torch.rand((2, 3, 4, 1)),)
58+
test_pass_stage = RunPasses([UnsqueezeBeforeRepeatPass])
59+
(
60+
(
61+
ArmTester(
62+
module,
63+
example_inputs=inputs,
64+
compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"),
65+
)
66+
.export()
67+
.to_edge()
68+
.check(["aten_repeat_default"])
69+
.check_not(["aten_view_copy_default"])
70+
.run_passes(test_pass_stage)
71+
.check(["aten_repeat_default"])
72+
.check_not(["aten_view_copy_default"])
73+
)
74+
)

0 commit comments

Comments
 (0)