
Commit 1ba6d13

apbosegs-olive authored and committed
Converter reorg and softmax operation
softmax linting error fix
1 parent e0b34b1 commit 1ba6d13

File tree: 4 files changed, +98 -32 lines


py/torch_tensorrt/fx/converters/acc_ops_converters.py

Lines changed: 2 additions & 31 deletions
@@ -31,6 +31,7 @@
 from torch_tensorrt.fx.converters.impl.elementwise import fmod
 from torch_tensorrt.fx.converters.impl.normalization import batch_norm
 from torch_tensorrt.fx.converters.impl.normalization import layer_norm
+from torch_tensorrt.fx.converters.impl.normalization import softmax
 from torch_tensorrt.fx.converters.impl.unary import sign
 from torch_tensorrt.fx.converters.impl.elementwise.base import (
     convert_binary_elementwise,
@@ -671,37 +672,7 @@ def acc_ops_softmax(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    input_val = kwargs["input"]
-    input_ranks = len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0)  # type: ignore[union-attr]
-
-    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"softmax received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
-
-    # Used to get dim when dim is None. Copied from PyTorch softmax implementation.
-    def get_softmax_dim(ndim: int) -> int:
-        if ndim == 0 or ndim == 1 or ndim == 3:
-            ret = 0
-        else:
-            ret = 1
-        return ret
-
-    if kwargs["dim"] is None:
-        dim = get_softmax_dim(input_ranks)
-    else:
-        dim = cast(int, kwargs["dim"])
-
-    dim = get_positive_dim(dim, input_ranks)
-    if network.has_implicit_batch_dimension:
-        assert dim != 0, "Can't apply softmax on batch dimension when it's implicit."
-        dim -= 1
-
-    layer = network.add_softmax(input_val)
-    layer.axes = 1 << dim
-    set_layer_name(layer, target, name)
-    return layer.get_output(0)
+    return softmax(network, target, SourceIR.ACC, name, kwargs["input"], kwargs["dim"])


 @tensorrt_converter(acc_ops.tile)
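
The reorg restated as a stand-alone sketch (names below are illustrative stand-ins, not the torch_tensorrt API): both front ends become thin adapters over one shared softmax implementation, differing only in how they extract operands: keyword arguments for acc ops, positional arguments for aten ops.

    # Illustrative sketch of the adapter pattern; "shared_softmax_impl" is a
    # stand-in for the softmax added to impl/normalization/ops.py, not a real API.
    def shared_softmax_impl(network, target, source_ir, name, input, dim=None):
        return f"softmax({name}, source_ir={source_ir}, dim={dim})"

    def acc_style_converter(network, target, args, kwargs, name):
        # acc converters receive operands by keyword
        return shared_softmax_impl(network, target, "acc", name, kwargs["input"], kwargs["dim"])

    def aten_style_converter(network, target, args, kwargs, name):
        # aten converters receive operands positionally
        return shared_softmax_impl(network, target, "aten", name, args[0], args[1])

    print(acc_style_converter(None, None, (), {"input": "x", "dim": 1}, "softmax_1"))
    print(aten_style_converter(None, None, ("x", 1), {}, "softmax_2"))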

py/torch_tensorrt/fx/converters/aten_ops_converters.py

Lines changed: 12 additions & 0 deletions
@@ -27,6 +27,7 @@
 from torch_tensorrt.fx.converters.impl.elementwise import rsub
 from torch_tensorrt.fx.converters.impl.normalization import batch_norm
 from torch_tensorrt.fx.converters.impl.normalization import layer_norm
+from torch_tensorrt.fx.converters.impl.normalization import softmax

 _LOGGER: logging.Logger = logging.getLogger(__name__)

@@ -490,6 +491,17 @@ def aten_ops_rsub(
     return rsub(network, target, SourceIR.ATEN, name, args[0], args[1], alpha)


+@tensorrt_converter(torch.ops.aten._softmax.default)
+def aten_ops_softmax(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return softmax(network, target, SourceIR.ATEN, name, args[0], args[1])
+
+
 @tensorrt_converter(torch.ops.aten.tanh.default)
 def aten_ops_tanh(
     network: TRTNetwork,
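
Background on the registered target: torch.nn.functional.softmax is a composite op that decomposes to aten._softmax when traced through the dispatcher, which is why this converter (and the new test's expected_ops) key on torch.ops.aten._softmax.default. A quick stand-alone check, assuming plain PyTorch with make_fx (a sketch, not part of this commit):

    import torch
    from torch.fx.experimental.proxy_tensor import make_fx

    def f(x):
        return torch.nn.functional.softmax(x, dim=1)

    # Tracing through the dispatcher decomposes the composite softmax call,
    # leaving the aten._softmax.default node the new converter registers for.
    gm = make_fx(f)(torch.randn(1, 3, 4, 4))
    print(any(n.target == torch.ops.aten._softmax.default for n in gm.graph.nodes))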

py/torch_tensorrt/fx/converters/impl/normalization/ops.py

Lines changed: 40 additions & 1 deletion
@@ -1,7 +1,6 @@
 import operator
 import warnings
 from typing import cast, Union, Callable, Any, Optional, Sequence
-import logging

 import numpy as np

@@ -273,3 +272,43 @@ def layer_norm_no_plugin(
         scale_layer,
         beta_tensor.get_output(0),
     )
+
+
+def softmax(
+    network: TRTNetwork,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input: TRTTensor,
+    dim: Optional[Any] = None,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    input_ranks = len(input.shape) + (1 if network.has_implicit_batch_dimension else 0)  # type: ignore[union-attr]
+
+    if not isinstance(input, TRTTensor):
+        raise RuntimeError(
+            f"softmax received input {input} that is not part "
+            "of the TensorRT region!"
+        )
+
+    # Used to get dim when dim is None. Copied from PyTorch softmax implementation.
+    def get_softmax_dim(ndim: int) -> int:
+        if ndim == 0 or ndim == 1 or ndim == 3:
+            ret = 0
+        else:
+            ret = 1
+        return ret
+
+    if dim is None:
+        dim = get_softmax_dim(input_ranks)
+    else:
+        dim = cast(int, dim)
+
+    dim = get_positive_dim(dim, input_ranks)
+    if network.has_implicit_batch_dimension:
+        assert dim != 0, "Can't apply softmax on batch dimension when it's implicit."
+        dim -= 1
+
+    layer = network.add_softmax(input)
+    layer.axes = 1 << dim
+    set_layer_name(layer, target, name)
+    return layer.get_output(0)
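
For context on the `layer.axes = 1 << dim` line: TensorRT's ISoftMaxLayer.axes is a bitmask with one bit per tensor dimension, and setting bit `dim` restricts the softmax reduction to that single axis. A minimal stand-alone sketch of the same dim normalization (the helper below is illustrative only, not part of this codebase):

    def softmax_axes_mask(dim: int, input_rank: int, implicit_batch: bool) -> int:
        # Hypothetical helper mirroring the converter's logic; not a torch_tensorrt API.
        if dim < 0:
            dim = dim % input_rank  # same effect as get_positive_dim for valid dims
        if implicit_batch:
            assert dim != 0, "Can't apply softmax on the implicit batch dimension."
            dim -= 1  # TensorRT network dims exclude the implicit batch dim
        return 1 << dim

    print(bin(softmax_axes_mask(-1, 4, implicit_batch=False)))  # 0b1000 -> last of 4 dims
    print(bin(softmax_axes_mask(2, 4, implicit_batch=True)))    # 0b10 -> dim 2, shifted past batch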
Lines changed: 44 additions & 0 deletions
@@ -0,0 +1,44 @@
+import torch
+from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec
+
+
+class TestSoftMaxConverter(DispatchTestCase):
+    def test_softmax(self):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.softmax = torch.nn.Softmax(1)
+
+            def forward(self, x):
+                return self.softmax(x)
+
+        inputs = [torch.randn(1, 3, 224, 224)]
+        self.run_test(
+            TestModule(), inputs, expected_ops={torch.ops.aten._softmax.default}
+        )
+
+    def test_softmax_with_dynamic_shape(self):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.softmax = torch.nn.Softmax(2)
+
+            def forward(self, x):
+                return self.softmax(x)
+
+        input_specs = [
+            InputTensorSpec(
+                shape=(-1, 3, -1, -1),
+                dtype=torch.float32,
+                shape_ranges=[((1, 3, 1, 1), (1, 3, 5, 5), (2, 3, 10, 10))],
+            ),
+        ]
+
+        self.run_test_with_dynamic_shape(
+            TestModule(), input_specs, expected_ops={torch.ops.aten._softmax.default}
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
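
What the dynamic-shape test is exercising, restated in plain PyTorch (a hedged sketch: the min/opt/max shapes are taken from shape_ranges above, and the assertion only checks the PyTorch reference behavior, not the TensorRT path):

    import torch

    # The -1 entries in shape=(-1, 3, -1, -1) are dynamic dimensions; shape_ranges
    # supplies the (min, opt, max) shapes that must cover them. The reference
    # behavior the test compares against: softmax over dim 2 sums to 1 along it.
    module = torch.nn.Softmax(2)
    for shape in ((1, 3, 1, 1), (1, 3, 5, 5), (2, 3, 10, 10)):
        out = module(torch.randn(*shape))
        sums = out.sum(dim=2)
        assert torch.allclose(sums, torch.ones_like(sums), atol=1e-6)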
