Commit d695f15

Add pass for decomposing (log)softmax
Differential Revision: D64472857
Pull Request resolved: #6287
1 parent 53a94af

8 files changed, +251 −121 lines changed
backends/arm/_passes/arm_pass_manager.py

Lines changed: 6 additions & 1 deletion
@@ -23,6 +23,9 @@
     DecomposeLayerNormPass,
 )
 from executorch.backends.arm._passes.decompose_meandim_pass import DecomposeMeanDimPass
+from executorch.backends.arm._passes.decompose_softmaxes_pass import (
+    DecomposeSoftmaxesPass,
+)
 from executorch.backends.arm._passes.decompose_var_pass import DecomposeVarPass
 from executorch.backends.arm._passes.insert_squeeze_after_sum_pass import (
     InsertSqueezeAfterSumPass,

@@ -66,6 +69,7 @@ def transform_to_backend_pipeline(
         self.add_pass(DecomposeDivPass())
         self.add_pass(InsertSqueezeAfterSumPass())
         self.add_pass(ConvertSplitToSlicePass())
+        self.add_pass(DecomposeSoftmaxesPass())
         for spec in compile_spec:
             if spec.key == "permute_memory_format":
                 memory_format = spec.value.decode()

@@ -75,9 +79,10 @@ def transform_to_backend_pipeline(
         return self._transform(exported_program.graph_module)

     def transform_for_annotation_pipeline(self, graph_module: torch.fx.GraphModule):
+        self.add_pass(ScalarsToAttributePass())
         self.add_pass(DecomposeLayerNormPass())
         self.add_pass(DecomposeVarPass())
         self.add_pass(DecomposeMeanDimPass())
-        self.add_pass(ScalarsToAttributePass())
         self.add_pass(DecomposeDivPass())
+        self.add_pass(DecomposeSoftmaxesPass())
         return self._transform(graph_module)
backends/arm/_passes/decompose_softmaxes_pass.py

Lines changed: 74 additions & 0 deletions

@@ -0,0 +1,74 @@
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass

# For BI case
torch_softmax = (torch.ops.aten.softmax.int, torch.ops.aten.log_softmax.int)

# For MI case
edge_softmax = (
    exir_ops.edge.aten._softmax.default,
    exir_ops.edge.aten._log_softmax.default,
)

log_softmax = (torch.ops.aten.log_softmax.int, exir_ops.edge.aten._log_softmax.default)


def get_logsoftmax_ops(op) -> tuple:
    """
    Returns the (log_op, exp_op, sum_op, reciprocal_op, mul_op) tuple, where
    the ops depend on whether the (log)softmax op is an exir_ops edge op or a
    torch.ops.aten op.
    """
    if op in edge_softmax:
        return (
            exir_ops.edge.aten.log.default,
            exir_ops.edge.aten.exp.default,
            exir_ops.edge.aten.sum.dim_IntList,
            exir_ops.edge.aten.reciprocal.default,
            exir_ops.edge.aten.mul.Tensor,
        )
    if op in torch_softmax:
        return (
            torch.ops.aten.log.default,
            torch.ops.aten.exp.default,
            torch.ops.aten.sum.dim_IntList,
            torch.ops.aten.reciprocal.default,
            torch.ops.aten.mul.Tensor,
        )
    raise RuntimeError(f"Can't get softmax decomposition ops for op {op}")


class DecomposeSoftmaxesPass(ExportPass):
    """
    This pass decomposes log softmax or softmax into more primitive ops.

    Example:
        %op1 = exp(x)
        %op2 = sum(%op1, dim)
        %op3 = reciprocal(%op2)
        %op4 = mul(%op1, %op3)
        (in logsoftmax case: %op5 = log(%op4))
    """

    def call_operator(self, op, args, kwargs, meta):
        if op not in torch_softmax + edge_softmax:
            return super().call_operator(op, args, kwargs, meta)

        log_op, exp_op, sum_op, reciprocal_op, mul_op = get_logsoftmax_ops(op)

        _input = args[0]
        dim = [args[1]]

        op1 = super().call_operator(exp_op, (_input,), {}, meta)
        op2 = super().call_operator(sum_op, (op1, dim, True), {}, meta)
        op3 = super().call_operator(reciprocal_op, (op2,), {}, meta)
        op4 = super().call_operator(mul_op, (op1, op3), {}, meta)
        if op in log_softmax:
            op4 = super().call_operator(log_op, (op4,), {}, meta)
        return op4
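
As a quick sanity check on the decomposition described in the docstring (not part of the commit; plain eager-mode PyTorch standing in for the aten/edge ops the pass emits), the op chain reproduces torch.softmax, and the trailing log reproduces torch.log_softmax:

import torch

x = torch.randn(4, 10)
dim = -1

op1 = torch.exp(x)                           # exp_op
op2 = torch.sum(op1, dim=dim, keepdim=True)  # sum_op with keepdim=True, as in the pass
op3 = torch.reciprocal(op2)                  # reciprocal_op
op4 = op1 * op3                              # mul_op -> softmax
assert torch.allclose(op4, torch.softmax(x, dim=dim), atol=1e-6)

log_out = torch.log(op4)                     # extra log_op in the logsoftmax case
assert torch.allclose(log_out, torch.log_softmax(x, dim=dim), atol=1e-6)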

backends/arm/arm_partitioner.py

Lines changed: 1 addition & 0 deletions
@@ -62,6 +62,7 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
             exir_ops.edge.aten.relu.default,
             exir_ops.edge.aten.rsqrt.default,
             exir_ops.edge.aten._softmax.default,
+            exir_ops.edge.aten._log_softmax.default,
             exir_ops.edge.aten.slice_copy.Tensor,
             exir_ops.edge.aten.sub.Tensor,
             exir_ops.edge.aten.sum.dim_IntList,

backends/arm/operators/__init__.py

Lines changed: 0 additions & 1 deletion
@@ -30,7 +30,6 @@
     op_rsqrt,
     op_sigmoid,
     op_slice,
-    op_softmax,
     op_squeeze,
     op_sub,
     op_sum,

backends/arm/operators/op_exp.py

Lines changed: 0 additions & 1 deletion
@@ -42,7 +42,6 @@ def define_node(
 ) -> None:

     assert len(node.all_input_nodes) == 1
-    assert len(node.users) == 1

     if is_quant_node:
         # Assume quantized input is 8 bit.

backends/arm/operators/op_softmax.py

Lines changed: 0 additions & 99 deletions
This file was deleted.
backends/arm/test/ops/test_logsoftmax.py

Lines changed: 158 additions & 0 deletions

@@ -0,0 +1,158 @@
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

from typing import Tuple

import torch
from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.arm_tester import ArmTester
from executorch.exir.backend.compile_spec_schema import CompileSpec
from parameterized import parameterized


test_data_suite = [
    # (test_name, test_data, dim)
    ("zeros", torch.zeros(10, 10, 10, 10), 0),
    ("zeros_neg_dim", torch.zeros(10, 10, 10, 10), -4),
    ("ones", torch.ones(10, 10), 1),
    ("rand_neg_dim", torch.rand(10, 10, 10), -1),
    ("rand", torch.rand(10, 10, 10, 10), 2),
    ("rand_neg_dim", torch.rand(10, 10, 2, 3), -2),
    ("randn", torch.randn(10, 10, 5, 10), 3),
    ("randn_neg_dim", torch.randn(1, 10, 10, 10), -3),
]


class TestLogSoftmax(unittest.TestCase):
    """Tests logsoftmax."""

    class LogSoftmax(torch.nn.Module):
        def __init__(self, dim: int = -1):
            super().__init__()
            self.logsoftmax = torch.nn.LogSoftmax(dim=dim)

        def forward(self, x):
            return self.logsoftmax(x)

    def _test_logsoftmax_tosa_MI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.tensor]
    ):
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec(),
            )
            .export()
            .check(["torch.ops.aten.log_softmax.int"])
            .check_not(["torch.ops.quantized_decomposed"])
            .to_edge()
            .partition()
            .check_not(["executorch_exir_dialects_edge__ops_aten__log_softmax_default"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data)
        )

    def _test_logsoftmax_tosa_BI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.tensor]
    ):
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec(),
            )
            .quantize()
            .export()
            .check_not(["torch.ops.aten.log_softmax.int"])
            .check(["torch.ops.quantized_decomposed", "torch.ops.aten.mul.Tensor"])
            .to_edge()
            .partition()
            .check_not(["executorch_exir_dialects_edge__ops_aten__log_softmax_default"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data, qtol=1)
        )

    def _test_logsoftmax_tosa_ethos_BI_pipeline(
        self,
        compile_spec: list[CompileSpec],
        module: torch.nn.Module,
        test_data: Tuple[torch.tensor],
    ):
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=compile_spec,
            )
            .quantize()
            .export()
            .check_not(["torch.ops.aten.log_softmax.int"])
            .check(["torch.ops.quantized_decomposed", "torch.ops.aten.mul.Tensor"])
            .to_edge()
            .partition()
            .check_not(["executorch_exir_dialects_edge__ops_aten__log_softmax_default"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
        )

    def _test_logsoftmax_tosa_u55_BI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.tensor]
    ):
        self._test_logsoftmax_tosa_ethos_BI_pipeline(
            common.get_u55_compile_spec(), module, test_data
        )

    def _test_logsoftmax_tosa_u85_BI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.tensor]
    ):
        self._test_logsoftmax_tosa_ethos_BI_pipeline(
            common.get_u85_compile_spec(), module, test_data
        )

    @parameterized.expand(test_data_suite)
    def test_logsoftmax_tosa_MI(
        self,
        test_name: str,
        test_data: torch.Tensor,
        dim: int,
    ):
        self._test_logsoftmax_tosa_MI_pipeline(self.LogSoftmax(dim=dim), (test_data,))

    @parameterized.expand(test_data_suite)
    def test_logsoftmax_tosa_BI(
        self,
        test_name: str,
        test_data: torch.Tensor,
        dim: int,
    ):
        self._test_logsoftmax_tosa_BI_pipeline(self.LogSoftmax(dim=dim), (test_data,))

    @parameterized.expand(test_data_suite)
    def test_logsoftmax_tosa_u55_BI(
        self,
        test_name: str,
        test_data: torch.Tensor,
        dim: int,
    ):
        self._test_logsoftmax_tosa_u55_BI_pipeline(
            self.LogSoftmax(dim=dim), (test_data,)
        )

    @parameterized.expand(test_data_suite)
    def test_logsoftmax_tosa_u85_BI(
        self,
        test_name: str,
        test_data: torch.Tensor,
        dim: int,
    ):
        self._test_logsoftmax_tosa_u85_BI_pipeline(
            self.LogSoftmax(dim=dim), (test_data,)
        )
