Skip to content

Commit 31d30e2

Browse files
authored
feat: Exempt default softmax from decomposition (#2268)
1 parent a65c95c commit 31d30e2

File tree

1 file changed

+7
-5
lines changed

1 file changed

+7
-5
lines changed

py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
1-
from typing import Any, Callable, Dict, Set
1+
from typing import Any, Callable, Dict, Set, Union
22

33
import torch
44
from torch._decomp import core_aten_decompositions
55
from torch._decomp import get_decompositions as get_torch_decompositions
6-
from torch._ops import OpOverload
6+
from torch._ops import OpOverload, OpOverloadPacket
77

88
aten = torch.ops.aten
99

1010
_core_aten_decompositions: Dict[
1111
OpOverload, Callable[[Any], Any]
1212
] = core_aten_decompositions()
13-
torch_enabled_decompositions: Set[OpOverload] = {
13+
torch_enabled_decompositions: Set[Union[OpOverload, OpOverloadPacket]] = {
1414
aten._adaptive_avg_pool2d_backward,
1515
aten.addcdiv,
1616
aten.addcdiv_,
@@ -140,7 +140,7 @@
140140
aten.smooth_l1_loss_backward,
141141
aten.soft_margin_loss,
142142
aten.soft_margin_loss_backward,
143-
aten._softmax,
143+
aten._softmax.out,
144144
aten._softmax_backward_data,
145145
aten.softplus,
146146
aten.softplus_backward,
@@ -176,7 +176,9 @@
176176
aten.full,
177177
aten.repeat,
178178
}
179-
torch_disabled_decompositions: Set[OpOverload] = set()
179+
torch_disabled_decompositions: Set[Union[OpOverload, OpOverloadPacket]] = {
180+
aten._softmax.default,
181+
}
180182

181183

182184
ENABLED_TORCH_DECOMPOSITIONS: Dict[

0 commit comments

Comments
 (0)