
Commit 21605a3: Converter reorg elu
Parent: 1696cd2

5 files changed: +119 additions, -5 deletions

py/torch_tensorrt/fx/converters/acc_ops_converters.py

Lines changed: 8 additions & 5 deletions
@@ -1045,11 +1045,14 @@ def acc_ops_elu(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    input_val = kwargs["input"]
-    alpha = kwargs["alpha"]
-    operation_type = trt.ActivationType.ELU
-    return activation.convert_activation(
-        network, target, SourceIR.ACC, name, operation_type, input_val, alpha=alpha
+
+    return activation.elu(
+        network,
+        target,
+        SourceIR.ACC,
+        name,
+        kwargs["input"],
+        kwargs["alpha"],
     )

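The reorg pattern, applied here to ELU: the acc converter no longer touches trt.ActivationType directly and instead forwards its normalized kwargs to a named helper in impl/activation.py. For illustration, a further activation would follow the same shape; the selu converter below is a hypothetical sketch, not part of this commit, and assumes a matching activation.selu helper exists:

# Hypothetical sketch (in acc_ops_converters.py, alongside the converter above);
# assumes an activation.selu helper mirroring activation.elu.
@tensorrt_converter(acc_ops.selu)
def acc_ops_selu(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    return activation.selu(
        network, target, SourceIR.ACC, name, kwargs["input"]
    )
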
py/torch_tensorrt/fx/converters/aten_ops_converters.py

Lines changed: 19 additions & 0 deletions
@@ -170,6 +170,25 @@ def aten_ops_div(
     )


+@tensorrt_converter(torch.ops.aten.elu.default)
+def aten_ops_elu(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+
+    return activation.elu(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args[1] if len(args) > 1 else 1.0,
+    )
+
+
 @tensorrt_converter(torch.ops.aten.floor_divide.default)
 def aten_ops_floor_div(
     network: TRTNetwork,

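Unlike the acc converter, the ATen converter reads its operands positionally, with a guard because alpha is optional in the op's schema. A quick way to confirm why args[0] is the input and args[1], when present, is alpha is to print the operator schema; the string in the comment below is approximate and may vary across PyTorch versions:

import torch

# Prints roughly:
#   elu(Tensor self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1) -> Tensor
print(torch.ops.aten.elu.default._schema)
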
py/torch_tensorrt/fx/converters/impl/activation.py

Lines changed: 25 additions & 0 deletions
@@ -90,3 +90,28 @@ def relu_dyn_range_fn(dyn_range):
         input_val,
         dyn_range_fn=relu_dyn_range_fn,
     )
+
+
+def elu(
+    network: TRTNetwork,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input_val: TRTTensor,
+    alpha: Optional[Any],
+):
+    operation_type = trt.ActivationType.ELU
+
+    def elu_dyn_range_fn(dyn_range):
+        return tuple(torch.nn.functional.elu(torch.tensor(x), alpha).item() for x in dyn_range)
+
+    return convert_activation(
+        network,
+        target,
+        source_ir,
+        name,
+        operation_type,
+        input_val,
+        alpha,
+        dyn_range_fn=elu_dyn_range_fn,
+    )
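The new elu helper delegates to convert_activation, which is defined earlier in impl/activation.py and not shown in this diff. A minimal sketch of what that helper plausibly does with these arguments, written against the public TensorRT Python API; the repo's actual implementation also carries validation and source_ir-aware layer naming:

import tensorrt as trt

def convert_activation_sketch(network, name, operation_type, input_val, alpha=None, dyn_range_fn=None):
    # Add a TensorRT activation layer of the requested type (e.g. trt.ActivationType.ELU).
    layer = network.add_activation(input_val, operation_type)
    if alpha is not None:
        layer.alpha = alpha  # ELU's alpha maps onto IActivationLayer.alpha
    layer.name = name
    output = layer.get_output(0)
    # For INT8 flows, derive the output dynamic range from the input's range.
    if dyn_range_fn is not None and input_val.dynamic_range is not None:
        output.set_dynamic_range(*dyn_range_fn(input_val.dynamic_range))
    return output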

py/torch_tensorrt/fx/converters/nn_ops_converters.py

Lines changed: 16 additions & 0 deletions
@@ -22,3 +22,19 @@ def relu(network, submod, args, kwargs, layer_name):
         name=layer_name,
         input_val=kwargs["input"],
     )
+
+
+@tensorrt_converter(torch.nn.functional.elu)
+@tensorrt_converter(torch.nn.modules.activation.ELU)
+def elu(network, submod, args, kwargs, layer_name):
+    # args/kwargs should have already been normalized to kwargs
+    assert len(args) == 0
+
+    return activation.elu(
+        network=network,
+        target="torch.nn.functional.elu",
+        source_ir=SourceIR.NN,
+        name=layer_name,
+        input_val=kwargs["input"],
+        alpha=kwargs["alpha"],
+    )

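The assert encodes this path's contract that positional args have already been folded into kwargs by the fx normalization pass. Illustratively (values invented, not taken from the commit), a call like torch.nn.functional.elu(x, alpha=0.5) reaches the converter as:

import torch

x = torch.randn(2, 3)
args = ()  # empty after normalization, hence the assert
kwargs = {"input": x, "alpha": 0.5}
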
py/torch_tensorrt/fx/test/converters/aten_op/test_elu_aten.py (new file)

Lines changed: 51 additions & 0 deletions
@@ -0,0 +1,51 @@
+import torch
+import torch.nn as nn
+from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec
+
+
+class TestELUConverter(DispatchTestCase):
+    def test_elu(self):
+        class TestModule(nn.Module):
+            def forward(self, x):
+                return nn.functional.elu(x)
+
+        inputs = [torch.randn(1, 10)]
+        self.run_test(TestModule(), inputs, expected_ops={torch.ops.aten.elu.default})
+
+    def test_elu_with_dynamic_shape(self):
+        class TestModule(nn.Module):
+            def forward(self, x):
+                return nn.functional.elu(x)
+
+        input_specs = [
+            InputTensorSpec(
+                shape=(-1, -1, -1),
+                dtype=torch.float32,
+                shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))],
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            TestModule(), input_specs, expected_ops={torch.ops.aten.elu.default}
+        )
+
+    def test_elu_with_dynamic_shape_four_dimensions(self):
+        class TestModule(nn.Module):
+            def forward(self, x):
+                return nn.functional.elu(x)
+
+        input_specs = [
+            InputTensorSpec(
+                shape=(-1, -1, -1, -1),
+                dtype=torch.float32,
+                shape_ranges=[((1, 1, 1, 5), (1, 2, 3, 5), (3, 3, 3, 5))],
+            ),
+        ]
+
+        self.run_test_with_dynamic_shape(
+            TestModule(), input_specs, expected_ops={torch.ops.aten.elu.default}
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
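Because the module ends in run_tests(), the file can also be executed directly with python. For reference, the numerics these tests hold the generated TensorRT engine to: ELU(x) = x for x > 0 and alpha * (exp(x) - 1) otherwise. A quick eager-mode spot check of that identity:

import torch

x = torch.randn(1, 10)
alpha = 1.0
ref = torch.where(x > 0, x, alpha * (torch.exp(x) - 1))
assert torch.allclose(torch.nn.functional.elu(x, alpha), ref, atol=1e-6)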
