Skip to content

Commit 2553b85

Browse files
Arm backend: extend Softmax to handle dim < 0
Differential Revision: D61852817 Pull Request resolved: #4819
1 parent 66b2f73 commit 2553b85

File tree

2 files changed

+6
-5
lines changed

2 files changed

+6
-5
lines changed

backends/arm/operators/op_softmax.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def define_node(
3333
input_name = inputs[0].name
3434
dim_order = inputs[0].dim_order
3535
input_shape = tosa_shape(inputs[0].shape, dim_order)
36-
dim_value = dim_order.index(inputs[1].number)
36+
dim_value = dim_order.index(inputs[1].number % len(dim_order))
3737

3838
## softmax = exp(logits - max(logits)) / reduce_sum(exp(logits - max(logits)), -1)
3939
# FP32

backends/arm/test/ops/test_softmax.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
# This source code is licensed under the BSD-style license found in the
66
# LICENSE file in the root directory of this source tree.
77

8-
import logging
98
import unittest
109

1110
from typing import Tuple
@@ -15,15 +14,17 @@
1514
from executorch.backends.arm.test.tester.arm_tester import ArmTester
1615
from parameterized import parameterized
1716

18-
logger = logging.getLogger(__name__)
19-
logger.setLevel(logging.INFO)
2017

2118
test_data_suite = [
2219
# (test_name, test_data, dim)
23-
("zeros", torch.zeros(10, 10, 10, 10), 1),
20+
("zeros", torch.zeros(10, 10, 10, 10), 0),
21+
("zeros_neg_dim", torch.zeros(10, 10, 10, 10), -4),
2422
("ones", torch.ones(10, 10, 10, 10), 1),
23+
("ones_neg_dim", torch.ones(10, 10, 10, 10), -1),
2524
("rand", torch.rand(10, 10, 10, 10), 2),
25+
("rand_neg_dim", torch.rand(10, 10, 10, 10), -2),
2626
("randn", torch.randn(10, 10, 10, 10), 3),
27+
("randn_neg_dim", torch.randn(10, 10, 10, 10), -3),
2728
]
2829

2930

0 commit comments

Comments
 (0)