Skip to content

Commit 6d7f2fb

Browse files
committed
feat: dynamic shapes support for neg ops
1 parent 70b5b12 commit 6d7f2fb

File tree

2 files changed

+41
-1
lines changed

2 files changed

+41
-1
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -563,7 +563,7 @@ def aten_ops_rsqrt(
563563
)
564564

565565

566-
@dynamo_tensorrt_converter(torch.ops.aten.neg.default)
566+
@dynamo_tensorrt_converter(torch.ops.aten.neg.default, supports_dynamic_shapes=True)
567567
def aten_ops_neg(
568568
ctx: ConversionContext,
569569
target: Target,

tests/py/dynamo/conversion/test_neg_aten.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,46 @@ def forward(self, input):
4242
check_dtype=False,
4343
)
4444

45+
@parameterized.expand(
46+
[
47+
(
48+
"2d_dim_dtype_half",
49+
(1, 1),
50+
(2, 2),
51+
(4, 4),
52+
torch.half,
53+
torch.half,
54+
),
55+
(
56+
"3d_dim_dtype_float",
57+
(1, 1, 1),
58+
(1, 2, 3),
59+
(3, 3, 3),
60+
torch.float,
61+
torch.float,
62+
),
63+
]
64+
)
65+
def test_dynamic_shape_neg(
66+
self, _, min_shape, opt_shape, max_shape, type, output_type
67+
):
68+
class neg(nn.Module):
69+
def forward(self, input):
70+
return torch.ops.aten.neg.default(input)
71+
72+
input_specs = [
73+
Input(
74+
min_shape=min_shape,
75+
opt_shape=opt_shape,
76+
max_shape=max_shape,
77+
dtype=type,
78+
),
79+
]
80+
81+
self.run_test_with_dynamic_shape(
82+
neg(), input_specs, output_dtypes=[output_type]
83+
)
84+
4585

4686
if __name__ == "__main__":
4787
run_tests()

0 commit comments

Comments
 (0)