
Commit 5ee1e50

feat: dynamic shape support for pow/mod/eq operator
1 parent feb4d84 commit 5ee1e50

4 files changed: +173 additions, -2 deletions

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 9 additions & 2 deletions
@@ -2002,6 +2002,7 @@ def aten_ops_div(
 @dynamo_tensorrt_converter(
     torch.ops.aten.pow.Tensor_Scalar, supports_dynamic_shapes=True
 )
+@dynamo_tensorrt_converter(operator.pow, supports_dynamic_shapes=True)
 def aten_ops_pow(
     ctx: ConversionContext,
     target: Target,
@@ -2278,6 +2279,7 @@ def aten_ops_bitwise_not(

 @dynamo_tensorrt_converter(torch.ops.aten.eq.Tensor)
 @dynamo_tensorrt_converter(torch.ops.aten.eq.Scalar)
+@dynamo_tensorrt_converter(operator.eq, supports_dynamic_shapes=True)
 @enforce_tensor_types(
     {
         0: (TRTTensor,),
@@ -3149,8 +3151,13 @@ def aten_ops_copy(
 )


-@dynamo_tensorrt_converter(torch.ops.aten.remainder.Scalar)
-@dynamo_tensorrt_converter(torch.ops.aten.remainder.Tensor)
+@dynamo_tensorrt_converter(
+    torch.ops.aten.remainder.Scalar, supports_dynamic_shapes=True
+)
+@dynamo_tensorrt_converter(
+    torch.ops.aten.remainder.Tensor, supports_dynamic_shapes=True
+)
+@dynamo_tensorrt_converter(operator.mod, supports_dynamic_shapes=True)
 @enforce_tensor_types(
     {
         0: (TRTTensor,),

tests/py/dynamo/conversion/test_eq_aten.py

Lines changed: 46 additions & 0 deletions
@@ -2,6 +2,7 @@
 import torch.nn as nn
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt import Input

 from .harness import DispatchTestCase

@@ -61,6 +62,51 @@ def forward(self, lhs_val):
             inputs,
         )

+    @parameterized.expand(
+        [
+            (
+                "2d_dim_dtype_half",
+                (1, 1),
+                (2, 2),
+                (4, 4),
+                torch.half,
+                torch.bool,
+            ),
+            (
+                "3d_dim_dtype_float",
+                (1, 1, 1),
+                (1, 2, 3),
+                (3, 3, 3),
+                torch.float,
+                torch.bool,
+            ),
+        ]
+    )
+    def test_eq_operator_dynamic_shape(
+        self, _, min_shape, opt_shape, max_shape, type, output_type
+    ):
+        class eq_operator(nn.Module):
+            def forward(self, lhs_val, rhs_val):
+                return lhs_val == rhs_val
+
+        input_specs = [
+            Input(
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+                dtype=type,
+            ),
+            Input(
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+                dtype=type,
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            eq_operator(), input_specs, output_dtypes=[output_type]
+        )
+

 if __name__ == "__main__":
     run_tests()

tests/py/dynamo/conversion/test_pow_aten.py

Lines changed: 59 additions & 0 deletions
@@ -59,6 +59,65 @@ def forward(self, lhs_val):
             inputs,
         )

+    @parameterized.expand(
+        [
+            (
+                "2d_dim_dtype_half",
+                (1, 1),
+                (2, 2),
+                (4, 4),
+                torch.half,
+                torch.half,
+            ),
+            (
+                "3d_dim_dtype_float",
+                (1, 1, 1),
+                (1, 2, 3),
+                (3, 3, 3),
+                torch.float,
+                torch.float,
+            ),
+        ]
+    )
+    def test_pow_dynamic_shape(
+        self, _, min_shape, opt_shape, max_shape, type, output_type
+    ):
+        class pow(nn.Module):
+            def forward(self, lhs_val, rhs_val):
+                return torch.ops.aten.pow.Tensor_Tensor(lhs_val, rhs_val)
+
+        class pow_scalar(nn.Module):
+            def forward(self, lhs_val, rhs_val):
+                return torch.ops.aten.pow.Tensor_Scalar(lhs_val, 2.0)
+
+        class pow_operator(nn.Module):
+            def forward(self, lhs_val, rhs_val):
+                return lhs_val**rhs_val
+
+        input_specs = [
+            Input(
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+                dtype=type,
+            ),
+            Input(
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+                dtype=type,
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            pow(), input_specs, output_dtypes=[output_type]
+        )
+        self.run_test_with_dynamic_shape(
+            pow_scalar(), input_specs, output_dtypes=[output_type]
+        )
+        self.run_test_with_dynamic_shape(
+            pow_operator(), input_specs, output_dtypes=[output_type]
+        )
+

 if __name__ == "__main__":
     run_tests()

tests/py/dynamo/conversion/test_remainder_aten.py

Lines changed: 59 additions & 0 deletions
@@ -55,6 +55,65 @@ def forward(self, lhs_val, rhs_val):
             inputs,
         )

+    @parameterized.expand(
+        [
+            (
+                "2d_dim_dtype_half",
+                (1, 1),
+                (2, 2),
+                (4, 4),
+                torch.half,
+                torch.half,
+            ),
+            (
+                "3d_dim_dtype_float",
+                (1, 1, 1),
+                (1, 2, 3),
+                (3, 3, 3),
+                torch.float,
+                torch.float,
+            ),
+        ]
+    )
+    def test_remainder_dynamic_shape(
+        self, _, min_shape, opt_shape, max_shape, type, output_type
+    ):
+        class remainder(nn.Module):
+            def forward(self, lhs_val, rhs_val):
+                return torch.ops.aten.remainder.Tensor(lhs_val, rhs_val)
+
+        class remainder_scalar(nn.Module):
+            def forward(self, lhs_val, rhs_val):
+                return torch.ops.aten.remainder.Scalar(lhs_val, 2)
+
+        class mod_operator(nn.Module):
+            def forward(self, lhs_val, rhs_val):
+                return lhs_val % rhs_val
+
+        input_specs = [
+            Input(
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+                dtype=type,
+            ),
+            Input(
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+                dtype=type,
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            remainder(), input_specs, output_dtypes=[output_type]
+        )
+        self.run_test_with_dynamic_shape(
+            remainder_scalar(), input_specs, output_dtypes=[output_type]
+        )
+        self.run_test_with_dynamic_shape(
+            mod_operator(), input_specs, output_dtypes=[output_type]
+        )
+

 if __name__ == "__main__":
     run_tests()
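The check these tests delegate to run_test_with_dynamic_shape can be approximated by hand: compile once against the (min, opt, max) profile, then feed several in-range shapes and compare against eager PyTorch. A rough sketch, reusing the hypothetical trt_mod from the earlier example (shapes illustrative, not from the commit):

import torch

for shape in [(1, 1), (2, 2), (3, 4)]:  # all within the (1, 1)..(4, 4) profile
    x = torch.rand(shape, device="cuda") + 1.0  # positive base keeps ** well-defined
    y = torch.rand(shape, device="cuda") + 1.0  # nonzero divisor for %
    out_pow, out_mod, out_eq = trt_mod(x, y)
    torch.testing.assert_close(out_pow, x**y)
    torch.testing.assert_close(out_mod, x % y)
    assert torch.equal(out_eq.bool(), x == y)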
