Skip to content

Commit 2bba527

Browse files
add dynamic support for floor/logical_not/sign/round/isinf/isnan (#2963)
1 parent d4b5e40 commit 2bba527

File tree

7 files changed: +180 additions, -6 deletions

7 files changed: +180 additions, -6 deletions

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1721,7 +1721,7 @@ def aten_ops_ceil(
17211721
)
17221722

17231723

1724-
@dynamo_tensorrt_converter(torch.ops.aten.floor.default)
1724+
@dynamo_tensorrt_converter(torch.ops.aten.floor.default, supports_dynamic_shapes=True)
17251725
def aten_ops_floor(
17261726
ctx: ConversionContext,
17271727
target: Target,
@@ -1738,7 +1738,9 @@ def aten_ops_floor(
17381738
)
17391739

17401740

1741-
@dynamo_tensorrt_converter(torch.ops.aten.logical_not.default)
1741+
@dynamo_tensorrt_converter(
1742+
torch.ops.aten.logical_not.default, supports_dynamic_shapes=True
1743+
)
17421744
def aten_ops_logical_not(
17431745
ctx: ConversionContext,
17441746
target: Target,
@@ -1755,7 +1757,7 @@ def aten_ops_logical_not(
17551757
)
17561758

17571759

1758-
@dynamo_tensorrt_converter(torch.ops.aten.sign.default)
1760+
@dynamo_tensorrt_converter(torch.ops.aten.sign.default, supports_dynamic_shapes=True)
17591761
def aten_ops_sign(
17601762
ctx: ConversionContext,
17611763
target: Target,
@@ -1772,7 +1774,7 @@ def aten_ops_sign(
17721774
)
17731775

17741776

1775-
@dynamo_tensorrt_converter(torch.ops.aten.round.default)
1777+
@dynamo_tensorrt_converter(torch.ops.aten.round.default, supports_dynamic_shapes=True)
17761778
def aten_ops_round(
17771779
ctx: ConversionContext,
17781780
target: Target,
@@ -1789,7 +1791,7 @@ def aten_ops_round(
17891791
)
17901792

17911793

1792-
@dynamo_tensorrt_converter(torch.ops.aten.isinf.default)
1794+
@dynamo_tensorrt_converter(torch.ops.aten.isinf.default, supports_dynamic_shapes=True)
17931795
def aten_ops_isinf(
17941796
ctx: ConversionContext,
17951797
target: Target,
@@ -1806,7 +1808,7 @@ def aten_ops_isinf(
18061808
)
18071809

18081810

1809-
@dynamo_tensorrt_converter(torch.ops.aten.isnan.default)
1811+
@dynamo_tensorrt_converter(torch.ops.aten.isnan.default, supports_dynamic_shapes=True)
18101812
def aten_ops_isnan(
18111813
ctx: ConversionContext,
18121814
target: Target,

tests/py/dynamo/conversion/test_floor_aten.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import torch.nn as nn
33
from parameterized import parameterized
44
from torch.testing._internal.common_utils import run_tests
5+
from torch_tensorrt import Input
56

67
from .harness import DispatchTestCase
78

@@ -26,6 +27,31 @@ def forward(self, input):
2627
inputs,
2728
)
2829

30+
@parameterized.expand(
31+
[
32+
((10,), (11,), (12,)),
33+
((1, 3, 4), (2, 3, 5), (3, 4, 6)),
34+
((2, 3, 4, 5), (3, 5, 4, 5), (4, 6, 4, 5)),
35+
]
36+
)
37+
def test_floor_dynamic_shape(self, min_shape, opt_shape, max_shape):
38+
class floor(nn.Module):
39+
def forward(self, input):
40+
return torch.ops.aten.floor.default(input)
41+
42+
input_specs = [
43+
Input(
44+
dtype=torch.float32,
45+
min_shape=min_shape,
46+
opt_shape=opt_shape,
47+
max_shape=max_shape,
48+
),
49+
]
50+
self.run_test_with_dynamic_shape(
51+
floor(),
52+
input_specs,
53+
)
54+
2955
@parameterized.expand(
3056
[
3157
((10,), torch.int, 0, 5),

tests/py/dynamo/conversion/test_isinf_aten.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import torch.nn as nn
33
from parameterized import parameterized
44
from torch.testing._internal.common_utils import run_tests
5+
from torch_tensorrt import Input
56

67
from .harness import DispatchTestCase
78

@@ -37,6 +38,49 @@ def forward(self, input):
3738
inputs,
3839
)
3940

41+
def test_isinf_dynamic_shape_float(self):
42+
class isinf(nn.Module):
43+
def forward(self, input):
44+
return torch.ops.aten.isinf.default(input)
45+
46+
inputs = [
47+
Input(
48+
min_shape=(1, 2, 3),
49+
opt_shape=(3, 2, 3),
50+
max_shape=(5, 3, 3),
51+
dtype=torch.float32,
52+
torch_tensor=torch.tensor(
53+
([[[2.7, float("-inf"), 1.1], [4.7, -2.3, float("inf")]]]),
54+
dtype=torch.float32,
55+
).cuda(),
56+
)
57+
]
58+
self.run_test_with_dynamic_shape(
59+
isinf(),
60+
inputs,
61+
use_example_tensors=False,
62+
)
63+
64+
def test_isinf_dynamic_shape_int(self):
65+
class isinf(nn.Module):
66+
def forward(self, input):
67+
return torch.ops.aten.isinf.default(input)
68+
69+
inputs = [
70+
Input(
71+
min_shape=(1, 2),
72+
opt_shape=(3, 2),
73+
max_shape=(5, 3),
74+
dtype=torch.int,
75+
torch_tensor=torch.tensor(([[-3, 2]]), dtype=torch.int).cuda(),
76+
)
77+
]
78+
self.run_test_with_dynamic_shape(
79+
isinf(),
80+
inputs,
81+
use_example_tensors=False,
82+
)
83+
4084
@parameterized.expand(
4185
[
4286
((10,), torch.int, 0, 5),

tests/py/dynamo/conversion/test_isnan_aten.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import torch.nn as nn
44
from parameterized import parameterized
55
from torch.testing._internal.common_utils import run_tests
6+
from torch_tensorrt import Input
67

78
from .harness import DispatchTestCase
89

@@ -39,6 +40,29 @@ def forward(self, input):
3940
inputs,
4041
)
4142

43+
def test_isnan_dynamic_shape_float(self):
44+
class isnan(nn.Module):
45+
def forward(self, input):
46+
return torch.ops.aten.isnan.default(input)
47+
48+
inputs = [
49+
Input(
50+
min_shape=(1, 2, 3),
51+
opt_shape=(3, 2, 3),
52+
max_shape=(5, 3, 3),
53+
dtype=torch.float32,
54+
torch_tensor=torch.tensor(
55+
([[[3.2, float("nan"), 3.1], [float("inf"), 1.1, float("nan")]]]),
56+
dtype=torch.float32,
57+
).cuda(),
58+
)
59+
]
60+
self.run_test_with_dynamic_shape(
61+
isnan(),
62+
inputs,
63+
use_example_tensors=False,
64+
)
65+
4266
@parameterized.expand(
4367
[
4468
(torch.full((2, 2), float("nan"), dtype=torch.float32),),

tests/py/dynamo/conversion/test_logical_not_aten.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import torch.nn as nn
33
from parameterized import parameterized
44
from torch.testing._internal.common_utils import run_tests
5+
from torch_tensorrt import Input
56

67
from .harness import DispatchTestCase
78

@@ -60,6 +61,31 @@ def forward(self, input):
6061
inputs,
6162
)
6263

64+
@parameterized.expand(
65+
[
66+
((10,), (11,), (13,)),
67+
((1, 5), (2, 5), (3, 5)),
68+
((2, 3, 4), (2, 3, 5), (3, 4, 6)),
69+
]
70+
)
71+
def test_logical_not_dynamic_shape(self, min_shape, opt_shape, max_shape):
72+
class logical_not(nn.Module):
73+
def forward(self, input):
74+
return torch.ops.aten.logical_not.default(input)
75+
76+
input_specs = [
77+
Input(
78+
dtype=torch.float32,
79+
min_shape=min_shape,
80+
opt_shape=opt_shape,
81+
max_shape=max_shape,
82+
),
83+
]
84+
self.run_test_with_dynamic_shape(
85+
logical_not(),
86+
input_specs,
87+
)
88+
6389

6490
if __name__ == "__main__":
6591
run_tests()

tests/py/dynamo/conversion/test_round_aten.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import torch.nn as nn
33
from parameterized import parameterized
44
from torch.testing._internal.common_utils import run_tests
5+
from torch_tensorrt import Input
56

67
from .harness import DispatchTestCase
78

@@ -26,6 +27,31 @@ def forward(self, input):
2627
inputs,
2728
)
2829

30+
@parameterized.expand(
31+
[
32+
((10,), (11,), (12,)),
33+
((1, 3, 4), (2, 3, 5), (3, 4, 6)),
34+
((2, 3, 4, 5), (3, 5, 4, 5), (4, 6, 4, 5)),
35+
]
36+
)
37+
def test_round_dynamic_shape(self, min_shape, opt_shape, max_shape):
38+
class round(nn.Module):
39+
def forward(self, input):
40+
return torch.ops.aten.round.default(input)
41+
42+
input_specs = [
43+
Input(
44+
dtype=torch.float32,
45+
min_shape=min_shape,
46+
opt_shape=opt_shape,
47+
max_shape=max_shape,
48+
),
49+
]
50+
self.run_test_with_dynamic_shape(
51+
round(),
52+
input_specs,
53+
)
54+
2955
@parameterized.expand(
3056
[
3157
((10,), torch.int, 0, 5),

tests/py/dynamo/conversion/test_sign_aten.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import torch.nn as nn
33
from parameterized import parameterized
44
from torch.testing._internal.common_utils import run_tests
5+
from torch_tensorrt import Input
56

67
from .harness import DispatchTestCase
78

@@ -26,6 +27,31 @@ def forward(self, input):
2627
inputs,
2728
)
2829

30+
@parameterized.expand(
31+
[
32+
((10,), (11,), (12,)),
33+
((1, 3, 4), (2, 3, 5), (3, 4, 6)),
34+
((2, 3, 4, 5), (3, 5, 4, 5), (4, 6, 4, 5)),
35+
]
36+
)
37+
def test_sign_dynamic_shape(self, min_shape, opt_shape, max_shape):
38+
class sign(nn.Module):
39+
def forward(self, input):
40+
return torch.ops.aten.sign.default(input)
41+
42+
input_specs = [
43+
Input(
44+
dtype=torch.float32,
45+
min_shape=min_shape,
46+
opt_shape=opt_shape,
47+
max_shape=max_shape,
48+
),
49+
]
50+
self.run_test_with_dynamic_shape(
51+
sign(),
52+
input_specs,
53+
)
54+
2955
@parameterized.expand(
3056
[
3157
((10,), torch.int, -2, 2),

0 commit comments

Comments (0)