
Commit a0d8245

apbosegs-olive authored and committed

Adding selu converter

1 parent 77156ec · commit a0d8245

File tree: 5 files changed, +104 -7 lines

py/torch_tensorrt/fx/converters/acc_ops_converters.py

Lines changed: 4 additions & 5 deletions
@@ -1064,16 +1064,15 @@ def acc_ops_selu(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    input_val = kwargs["input"]
-    operation_type = trt.ActivationType.SELU
-    return activation.convert_activation(
+
+    return activation.selu(
         network,
         target,
         SourceIR.ACC,
         name,
-        operation_type,
-        input_val,
+        kwargs["input"],
     )
+


 @tensorrt_converter(acc_ops.softsign)
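The acc path now goes through the shared `activation.selu` helper (see impl/activation.py below) instead of calling `convert_activation` directly with `trt.ActivationType.SELU`. For orientation, here is a minimal sketch of the TensorRT layer such a conversion ultimately reduces to; `add_selu_layer` is a hypothetical name for illustration, not torch_tensorrt API, and it assumes a `trt.INetworkDefinition` plus an input `trt.ITensor`:

import tensorrt as trt

def add_selu_layer(network, input_val, name):
    # SELU is a built-in TensorRT activation type; alpha/beta carry the
    # canonical SELU constants (alpha ~ 1.6733, scale ~ 1.0507).
    layer = network.add_activation(input_val, trt.ActivationType.SELU)
    layer.alpha = 1.6732632423543772
    layer.beta = 1.0507009873554805
    layer.name = name
    return layer.get_output(0)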

py/torch_tensorrt/fx/converters/aten_ops_converters.py

Lines changed: 8 additions & 0 deletions
@@ -179,6 +179,14 @@ def aten_ops_elu(
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
 
+    if len(args) > 2:
+        return activation.selu(
+            network,
+            target,
+            SourceIR.ATEN,
+            name,
+            args[0],
+        )
     return activation.elu(
         network,
         target,
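Why `len(args) > 2` works: `torch.ops.aten.elu.default` has the signature elu(Tensor self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1), and `F.selu` decomposes into an `aten.elu` call with SELU's alpha and scale passed explicitly, so an elu node carrying more than the input and alpha is treated as SELU. A small sketch to make the decomposition visible (the exact traced graph may vary across PyTorch versions):

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    return torch.nn.functional.selu(x)

# Tracing through composite ops shows selu lowered to aten.elu with
# explicit alpha/scale scalars, e.g. elu(x, 1.6732..., 1.0507...).
gm = make_fx(f)(torch.randn(4))
print(gm.graph)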

py/torch_tensorrt/fx/converters/impl/activation.py

Lines changed: 24 additions & 0 deletions
@@ -114,4 +114,28 @@ def elu_dyn_range_fn(dyn_range):
         input_val,
         alpha,
         dyn_range_fn=elu_dyn_range_fn,
+    )
+
+
+def selu(
+    network: TRTNetwork,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input_val: TRTTensor,
+    alpha: Optional[Any] = None,
+):
+    operation_type = trt.ActivationType.SELU
+
+    def selu_dyn_range_fn(dyn_range):
+        return (torch.nn.functional.selu(dyn_range[0]), torch.nn.functional.selu(dyn_range[1]))
+
+    return convert_activation(
+        network,
+        target,
+        source_ir,
+        name,
+        operation_type,
+        input_val,
+        dyn_range_fn=selu_dyn_range_fn,
     )
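For INT8 calibration, `dyn_range_fn` maps the input dynamic-range endpoints through the activation; because SELU is monotonically increasing, the mapped endpoints bound the output range exactly (this assumes the endpoints arrive as values `F.selu` accepts, e.g. tensors). A quick numeric check of the SELU closed form:

import torch

ALPHA = 1.6732632423543772  # canonical SELU alpha
SCALE = 1.0507009873554805  # canonical SELU scale

def selu_ref(x):
    # selu(x) = scale * (max(0, x) + min(0, alpha * (exp(x) - 1)))
    return SCALE * (torch.clamp(x, min=0) + torch.clamp(ALPHA * (torch.exp(x) - 1), max=0))

x = torch.linspace(-4, 4, steps=9)
assert torch.allclose(torch.nn.functional.selu(x), selu_ref(x), atol=1e-6)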

py/torch_tensorrt/fx/converters/nn_ops_converters.py

Lines changed: 17 additions & 2 deletions
@@ -26,15 +26,30 @@ def relu(network, submod, args, kwargs, layer_name):
 
 @tensorrt_converter(torch.nn.functional.elu)
 @tensorrt_converter(torch.nn.modules.activation.ELU)
-def relu(network, submod, args, kwargs, layer_name):
+def elu(network, submod, args, kwargs, layer_name):
     # args/kwargs should have already been normalized to kwargs
     assert len(args) == 0
 
-    return activation.relu(
+    return activation.elu(
         network=network,
         target="torch.nn.functional.elu",
         source_ir=SourceIR.NN,
         name=layer_name,
         input_val=kwargs["input"],
     )
 
+
+@tensorrt_converter(torch.nn.functional.selu)
+@tensorrt_converter(torch.nn.modules.activation.SELU)
+def selu(network, submod, args, kwargs, layer_name):
+    # args/kwargs should have already been normalized to kwargs
+    assert len(args) == 0
+
+    return activation.selu(
+        network=network,
+        target="torch.nn.functional.selu",
+        source_ir=SourceIR.NN,
+        name=layer_name,
+        input_val=kwargs["input"],
+        alpha=kwargs.get("alpha"),
+    )
New file: SELU converter test suite

Lines changed: 51 additions & 0 deletions
@@ -0,0 +1,51 @@
+import torch
+import torch.nn as nn
+from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec
+
+
+class TestSeLUConverter(DispatchTestCase):
+    def test_selu(self):
+        class TestModule(nn.Module):
+            def forward(self, x):
+                return nn.functional.selu(x)
+
+        inputs = [torch.randn(1, 10)]
+        self.run_test(TestModule(), inputs, expected_ops={torch.ops.aten.elu.default})
+
+    def test_selu_with_dynamic_shape(self):
+        class TestModule(nn.Module):
+            def forward(self, x):
+                return nn.functional.selu(x)
+
+        input_specs = [
+            InputTensorSpec(
+                shape=(-1, -1, -1),
+                dtype=torch.float32,
+                shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))],
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            TestModule(), input_specs, expected_ops={torch.ops.aten.elu.default}
+        )
+
+    def test_selu_with_dynamic_shape_four_dimensions(self):
+        class TestModule(nn.Module):
+            def forward(self, x):
+                return nn.functional.selu(x)
+
+        input_specs = [
+            InputTensorSpec(
+                shape=(-1, -1, -1, -1),
+                dtype=torch.float32,
+                shape_ranges=[((1, 1, 1, 5), (1, 2, 3, 5), (3, 3, 3, 5))],
+            ),
+        ]
+
+        self.run_test_with_dynamic_shape(
+            TestModule(), input_specs, expected_ops={torch.ops.aten.elu.default}
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
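Note that expected_ops is {torch.ops.aten.elu.default} because, as shown earlier, F.selu decomposes to aten.elu before conversion. In the dynamic-shape tests, each shape_ranges entry is a (min, opt, max) shape triple that ultimately becomes a TensorRT optimization profile; a hedged sketch of the equivalent raw TensorRT setup, where `builder`, `config`, and the input name "x" are assumptions for illustration:

import tensorrt as trt

def add_selu_test_profile(builder, config, input_name="x"):
    # One optimization profile per shape_ranges entry: min/opt/max shapes
    # matching ((1, 1, 1), (1, 2, 3), (3, 3, 3)) from the first dynamic test.
    profile = builder.create_optimization_profile()
    profile.set_shape(input_name, min=(1, 1, 1), opt=(1, 2, 3), max=(3, 3, 3))
    config.add_optimization_profile(profile)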
