Skip to content

Commit b981b2e

Browse files
Aleksei-grovetyfreddan80
authored andcommitted
Arm backend: Fix bug in ConvertExpandCopyToRepeatPass
In the ConvertExpandCopyToRepeatPass the arguments for the repeat operation are formed incorrectly. For the torch.Tensor.expand operation passing -1 as the size for a dimension means that the size of that dimension does not change. For the DeiT-tiny case, torch.ones(1, 1, 192).expand(1, -1, -1) the pass will prepare arguments to the repeat operation as [1, -1, 1] which will cause an error, in this case the arguments should be [1, 1, 1].
1 parent 745d47a commit b981b2e

File tree

2 files changed

+5
-2
lines changed

2 files changed

+5
-2
lines changed

backends/arm/_passes/convert_expand_copy_to_repeat.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,11 @@ def call_operator(self, op, args, kwargs, meta):
3636
]
3737

3838
# To convert expand arg to repeat arg, non-repeated dims should have
39-
# multiples[dim] = 1.
39+
# multiples[dim] = 1. Passing -1 to expand arg means
40+
# not changing the size of that dimension.
4041
multiples = [
41-
multiples[i] if extended_shape[i] == 1 else 1 for i in range(expanded_rank)
42+
multiples[i] if multiples[i] != -1 and extended_shape[i] == 1 else 1
43+
for i in range(expanded_rank)
4244
]
4345
return super().call_operator(
4446
op=self.repeat, args=(args[0], multiples), kwargs=kwargs, meta=meta

backends/arm/test/ops/test_expand.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ class Expand(torch.nn.Module):
3636
(torch.ones(1, 1, 2, 2), (4, 3, -1, 2)),
3737
(torch.ones(1), (2, 2, 4)),
3838
(torch.ones(3, 2, 4, 1), (-1, -1, -1, 3)),
39+
(torch.ones(1, 1, 192), (1, -1, -1)),
3940
]
4041

4142
def forward(self, x: torch.Tensor, multiples: Sequence):

0 commit comments

Comments
 (0)