Commit 979ab42

correcting selu, hard_tanh ops, adding tests for sigmoid, selu, elu and hard_tanh
1 parent 36f9a3f commit 979ab42

File tree

7 files changed: +227 −35 lines changed

py/torch_tensorrt/fx/converters/activation.py
py/torch_tensorrt/fx/converters/aten_ops_converters.py
py/torch_tensorrt/fx/test/converters/aten_op/test_elu_aten.py
py/torch_tensorrt/fx/test/converters/aten_op/test_hardtanh_aten.py
py/torch_tensorrt/fx/test/converters/aten_op/test_leaky_relu_aten.py
py/torch_tensorrt/fx/test/converters/aten_op/test_selu_aten.py
py/torch_tensorrt/fx/test/converters/aten_op/test_sigmoid_aten.py


py/torch_tensorrt/fx/converters/activation.py

Lines changed: 5 additions & 16 deletions
@@ -184,25 +184,14 @@ def add_sigmoid(network, target, kwargs, name):
 
 def add_hard_tanh(network, target, kwargs, name):
     input_val = kwargs["input"]
-    operation_type = trt.ActivationType.TANH
-    return add_activation_layer(network, input_val, operation_type, target, name)
-
-
-def add_sigmoid(network, target, kwargs, name):
-    input_val = kwargs["input"]
-
+    alpha = kwargs["min_val"]
+    beta = kwargs["max_val"]
     if not isinstance(input_val, TRTTensor):
         raise RuntimeError(
-            f"Hard sigmoid received input {input_val} that is not part "
+            f"hardtanh received input {input_val} that is not part "
             "of the TensorRT region!"
         )
-
+    operation_type = trt.ActivationType.CLIP
     return add_activation_layer(
-        network,
-        input_val,
-        trt.ActivationType.HARD_SIGMOID,
-        target,
-        name,
-        alpha=1 / 6,
-        beta=0.5,
+        network, input_val, operation_type, target, name, alpha, beta
    )
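
What changed here: add_hard_tanh previously lowered aten.hardtanh to a plain TANH activation, and a stale duplicate of add_sigmoid below it targeted HARD_SIGMOID; the corrected version maps hardtanh onto TensorRT's CLIP activation, forwarding min_val and max_val as the layer's alpha and beta. A minimal sketch of the equivalent raw TensorRT call, assuming a network under construction and a TRT input tensor already exist (clip_like_hardtanh is an illustrative name, not part of this commit):

import tensorrt as trt

def clip_like_hardtanh(network, input_val, min_val=-1.0, max_val=1.0, name="hard_tanh"):
    # CLIP computes max(alpha, min(beta, x)), which is exactly
    # F.hardtanh(x, min_val, max_val).
    layer = network.add_activation(input_val, trt.ActivationType.CLIP)
    layer.alpha = min_val  # lower clamp bound
    layer.beta = max_val   # upper clamp bound
    layer.name = name
    return layer.get_output(0)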

py/torch_tensorrt/fx/converters/aten_ops_converters.py

Lines changed: 13 additions & 18 deletions
@@ -503,26 +503,19 @@ def aten_ops_elu(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    kwargs_new = {
-        "input": args[0],
-    }
+    if len(args) > 2:
+        kwargs_new = {
+            "input": args[0],
+        }
+        return add_selu(network, target, kwargs_new, name)
+    else:
+        kwargs_new = {
+            "input": args[0],
+            "alpha": args[1],
+        }
     return add_elu(network, target, kwargs_new, name)
 
 
-@tensorrt_converter(torch.ops.aten.selu.default)
-def aten_ops_selu(
-    network: TRTNetwork,
-    target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
-    name: str,
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    kwargs_new = {
-        "input": args[0],
-    }
-    return add_selu(network, target, kwargs_new, name)
-
-
 @tensorrt_converter(torch.ops.aten.gelu.default)
 def aten_ops_gelu(
     network: TRTNetwork,
@@ -551,7 +544,7 @@ def aten_ops_tanh(
     return add_tanh(network, target, kwargs_new, name)
 
 
-@tensorrt_converter(torch.ops.aten.sigmoid.default)
+@tensorrt_converter(torch.ops.aten.hardtanh.default)
 def aten_ops_hard_tanh(
     network: TRTNetwork,
     target: Target,
@@ -561,6 +554,8 @@ def aten_ops_hard_tanh(
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
     kwargs_new = {
         "input": args[0],
+        "min_val": args[1],
+        "max_val": args[2],
     }
     return add_hard_tanh(network, target, kwargs_new, name)
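
A note on the dispatch above: torch.ops.aten.elu.default has the schema elu(Tensor self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1), and PyTorch's default decompositions lower aten.selu into an elu call carrying SELU's fixed alpha and scale constants. An elu node with more than two arguments therefore signals a selu origin and is routed to add_selu, while a plain elu keeps the add_elu path. A rough sketch of the distinction the branch relies on (the SELU constants are PyTorch's; the tuples are illustrative stand-ins for the node's args):

import torch

x = torch.randn(1, 10)

# Plain elu as traced from F.elu(x, alpha): at most (input, alpha).
elu_args = (x, 1.0)

# selu decomposed into elu: a third, scale argument appears.
selu_alpha, selu_scale = 1.6732632423543772, 1.0507009873554805
selu_as_elu_args = (x, selu_alpha, selu_scale)

assert len(elu_args) <= 2          # -> add_elu branch
assert len(selu_as_elu_args) > 2   # -> add_selu branch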

py/torch_tensorrt/fx/test/converters/aten_op/test_elu_aten.py

Lines changed: 51 additions & 0 deletions
@@ -0,0 +1,51 @@
+import torch
+import torch.nn as nn
+from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec
+
+
+class TestELUConverter(DispatchTestCase):
+    def test_elu(self):
+        class TestModule(nn.Module):
+            def forward(self, x):
+                return nn.functional.elu(x)
+
+        inputs = [torch.randn(1, 10)]
+        self.run_test(TestModule(), inputs, expected_ops={torch.ops.aten.elu.default})
+
+    def test_elu_with_dynamic_shape(self):
+        class TestModule(nn.Module):
+            def forward(self, x):
+                return nn.functional.elu(x)
+
+        input_specs = [
+            InputTensorSpec(
+                shape=(-1, -1, -1),
+                dtype=torch.float32,
+                shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))],
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            TestModule(), input_specs, expected_ops={torch.ops.aten.elu.default}
+        )
+
+    def test_elu_with_dynamic_shape_four_dimensions(self):
+        class TestModule(nn.Module):
+            def forward(self, x):
+                return nn.functional.elu(x)
+
+        input_specs = [
+            InputTensorSpec(
+                shape=(-1, -1, -1, -1),
+                dtype=torch.float32,
+                shape_ranges=[((1, 1, 1, 5), (1, 2, 3, 5), (3, 3, 3, 5))],
+            ),
+        ]
+
+        self.run_test_with_dynamic_shape(
+            TestModule(), input_specs, expected_ops={torch.ops.aten.elu.default}
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
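
A note on the dynamic-shape tests: each shape_ranges entry supplies the (min, opt, max) shapes used to build TensorRT's optimization profile for the dimensions marked -1 in shape, while the fixed trailing 5 in the four-dimensional variant pins that axis to a static size. The same pattern repeats in the hardtanh, selu, and sigmoid tests below.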
py/torch_tensorrt/fx/test/converters/aten_op/test_hardtanh_aten.py

Lines changed: 53 additions & 0 deletions
@@ -0,0 +1,53 @@
+import torch
+import torch.nn as nn
+from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec
+
+
+class TestHardTanHConverter(DispatchTestCase):
+    def test_hardtanh(self):
+        class TestModule(nn.Module):
+            def forward(self, x):
+                return nn.functional.hardtanh(x)
+
+        inputs = [torch.randn(1, 10)]
+        self.run_test(
+            TestModule(), inputs, expected_ops={torch.ops.aten.hardtanh.default}
+        )
+
+    def test_hardtanh_with_dynamic_shape(self):
+        class TestModule(nn.Module):
+            def forward(self, x):
+                return nn.functional.hardtanh(x)
+
+        input_specs = [
+            InputTensorSpec(
+                shape=(-1, -1, -1),
+                dtype=torch.float32,
+                shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))],
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            TestModule(), input_specs, expected_ops={torch.ops.aten.hardtanh.default}
+        )
+
+    def test_hardtanh_with_dynamic_shape_four_dimensions(self):
+        class TestModule(nn.Module):
+            def forward(self, x):
+                return nn.functional.hardtanh(x)
+
+        input_specs = [
+            InputTensorSpec(
+                shape=(-1, -1, -1, -1),
+                dtype=torch.float32,
+                shape_ranges=[((1, 1, 1, 5), (1, 2, 3, 5), (3, 3, 3, 5))],
+            ),
+        ]
+
+        self.run_test_with_dynamic_shape(
+            TestModule(), input_specs, expected_ops={torch.ops.aten.hardtanh.default}
+        )
+
+
+if __name__ == "__main__":
+    run_tests()

py/torch_tensorrt/fx/test/converters/aten_op/test_leaky_relu_aten.py

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@
 from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec
 
 
-class TestReLUConverter(DispatchTestCase):
+class TestLeakyReLUConverter(DispatchTestCase):
     def test_leaky_relu(self):
         class TestModule(nn.Module):
             def forward(self, x):
py/torch_tensorrt/fx/test/converters/aten_op/test_selu_aten.py

Lines changed: 51 additions & 0 deletions
@@ -0,0 +1,51 @@
+import torch
+import torch.nn as nn
+from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec
+
+
+class TestSeLUConverter(DispatchTestCase):
+    def test_selu(self):
+        class TestModule(nn.Module):
+            def forward(self, x):
+                return nn.functional.selu(x)
+
+        inputs = [torch.randn(1, 10)]
+        self.run_test(TestModule(), inputs, expected_ops={torch.ops.aten.elu.default})
+
+    def test_selu_with_dynamic_shape(self):
+        class TestModule(nn.Module):
+            def forward(self, x):
+                return nn.functional.selu(x)
+
+        input_specs = [
+            InputTensorSpec(
+                shape=(-1, -1, -1),
+                dtype=torch.float32,
+                shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))],
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            TestModule(), input_specs, expected_ops={torch.ops.aten.elu.default}
+        )
+
+    def test_selu_with_dynamic_shape_four_dimensions(self):
+        class TestModule(nn.Module):
+            def forward(self, x):
+                return nn.functional.selu(x)
+
+        input_specs = [
+            InputTensorSpec(
+                shape=(-1, -1, -1, -1),
+                dtype=torch.float32,
+                shape_ranges=[((1, 1, 1, 5), (1, 2, 3, 5), (3, 3, 3, 5))],
+            ),
+        ]
+
+        self.run_test_with_dynamic_shape(
+            TestModule(), input_specs, expected_ops={torch.ops.aten.elu.default}
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
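
Note that the SELU tests assert expected_ops={torch.ops.aten.elu.default} rather than a selu op: with the dedicated aten_ops_selu converter removed above, F.selu is expected to reach the converter stack already decomposed into elu, where the len(args) > 2 branch in aten_ops_elu routes it to add_selu.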
py/torch_tensorrt/fx/test/converters/aten_op/test_sigmoid_aten.py

Lines changed: 53 additions & 0 deletions
@@ -0,0 +1,53 @@
+import torch
+import torch.nn as nn
+from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec
+
+
+class TestSigmoidConverter(DispatchTestCase):
+    def test_sigmoid(self):
+        class TestModule(nn.Module):
+            def forward(self, x):
+                return nn.functional.sigmoid(x)
+
+        inputs = [torch.randn(1, 10)]
+        self.run_test(
+            TestModule(), inputs, expected_ops={torch.ops.aten.sigmoid.default}
+        )
+
+    def test_sigmoid_with_dynamic_shape(self):
+        class TestModule(nn.Module):
+            def forward(self, x):
+                return nn.functional.sigmoid(x)
+
+        input_specs = [
+            InputTensorSpec(
+                shape=(-1, -1, -1),
+                dtype=torch.float32,
+                shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))],
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            TestModule(), input_specs, expected_ops={torch.ops.aten.sigmoid.default}
+        )
+
+    def test_sigmoid_with_dynamic_shape_four_dimensions(self):
+        class TestModule(nn.Module):
+            def forward(self, x):
+                return nn.functional.sigmoid(x)
+
+        input_specs = [
+            InputTensorSpec(
+                shape=(-1, -1, -1, -1),
+                dtype=torch.float32,
+                shape_ranges=[((1, 1, 1, 5), (1, 2, 3, 5), (3, 3, 3, 5))],
+            ),
+        ]
+
+        self.run_test_with_dynamic_shape(
+            TestModule(), input_specs, expected_ops={torch.ops.aten.sigmoid.default}
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
