Skip to content

Commit 7da316e

Browse files
Fix incorrect output shapes for var and expand
Decompositions of var and expand_copy produced different output shapes from those of the original ops. Signed-off-by: Oscar Andersson <[email protected]> Change-Id: Ide51f147e3c2a8b794bc60660b6053eb5d47ecff
1 parent 5190106 commit 7da316e

File tree

4 files changed

+13
-13
lines changed

4 files changed

+13
-13
lines changed

backends/arm/_passes/convert_expand_copy_to_repeat.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88

99
from typing import cast
1010

11-
from executorch.backends.arm.tosa_mapping import extract_tensor_meta
1211
from executorch.exir.dialects._ops import ops as exir_ops
1312
from executorch.exir.pass_base import ExportPass
1413

@@ -25,14 +24,14 @@ def call_operator(self, op, args, kwargs, meta):
2524
if op != self.expand_copy:
2625
return super().call_operator(op, args, kwargs, meta)
2726

28-
_, shape, _ = extract_tensor_meta(meta.data)
27+
input_shape = args[0].data.shape
2928
multiples = cast(list[int], args[1])
3029
expanded_rank = len(multiples)
3130

32-
# Expanded shape is 'shape' front-padded with ones.
33-
padding = expanded_rank - len(shape)
31+
# Expanded shape is 'input_shape' front-padded with ones.
32+
padding = expanded_rank - len(input_shape)
3433
extended_shape = [
35-
shape[i] if i >= 0 else 1 for i in range(-padding, len(shape))
34+
input_shape[i] if i >= 0 else 1 for i in range(-padding, len(input_shape))
3635
]
3736

3837
# To convert expand arg to repeat arg, non-repeated dims should have

backends/arm/_passes/decompose_var_pass.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def call_operator(self, op, args, kwargs, meta):
8383
sum = super().call_operator(sum_op, (squared_diff, dim, keepdim), {}, meta)
8484
full = super().call_operator(
8585
full_op,
86-
([1 for _ in shape], 1 / max(0, N - correction)),
86+
([], 1 / max(0, N - correction)),
8787
{"dtype": dtype},
8888
meta,
8989
)

backends/arm/_passes/match_arg_ranks_pass.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
9090
continue
9191

9292
# Calculate max rank of all inputs to node
93-
max_rank = 1
93+
max_rank = 0
9494
for arg in node.args:
9595
if isinstance(arg, Node):
9696
shape = get_first_fake_tensor(arg).shape

backends/arm/test/ops/test_expand.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,13 @@ class TestSimpleExpand(unittest.TestCase):
3434
class Expand(torch.nn.Module):
3535
# (input tensor, multiples)
3636
test_parameters = [
37-
(torch.ones(1), (2,)),
38-
(torch.ones(1, 4), (1, -1)),
39-
(torch.ones(1, 1, 2, 2), (4, 3, -1, 2)),
40-
(torch.ones(1), (2, 2, 4)),
41-
(torch.ones(3, 2, 4, 1), (-1, -1, -1, 3)),
42-
(torch.ones(1, 1, 192), (1, -1, -1)),
37+
(torch.rand(1), (2,)),
38+
(torch.randn(1, 4), (1, -1)),
39+
(torch.rand(1, 1, 2, 2), (4, 3, -1, 2)),
40+
(torch.randn(1), (2, 2, 4)),
41+
(torch.rand(3, 2, 4, 1), (-1, -1, -1, 3)),
42+
(torch.randn(1, 1, 192), (1, -1, -1)),
43+
(torch.randn(10, 1, 1, 97), (-1, 4, -1, -1)),
4344
]
4445

4546
def forward(self, x: torch.Tensor, multiples: Sequence):

0 commit comments

Comments (0)