Skip to content

Commit 0f7977e

Browse files
committed
feat: dynamic shapes support for neg ops
1 parent dc9948d commit 0f7977e

File tree

2 files changed

+44
-1
lines changed

2 files changed

+44
-1
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -563,7 +563,7 @@ def aten_ops_rsqrt(
563563
)
564564

565565

566-
@dynamo_tensorrt_converter(torch.ops.aten.neg.default)
566+
@dynamo_tensorrt_converter(torch.ops.aten.neg.default, supports_dynamic_shapes=True)
567567
def aten_ops_neg(
568568
ctx: ConversionContext,
569569
target: Target,

tests/py/dynamo/conversion/test_neg_aten.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,49 @@ def forward(self, input):
4242
check_dtype=False,
4343
)
4444

45+
@parameterized.expand(
46+
[
47+
(
48+
"3d_dim_dtype_int32",
49+
(3, 2, 1),
50+
(3, 2, 3),
51+
(3, 3, 4),
52+
torch.int32,
53+
),
54+
(
55+
"2d_dim_dtype_half",
56+
(1, 1),
57+
(2, 2),
58+
(4, 4),
59+
torch.half,
60+
),
61+
(
62+
"3d_dim_dtype_float",
63+
(1, 1, 1),
64+
(1, 2, 3),
65+
(3, 3, 3),
66+
torch.float,
67+
),
68+
]
69+
)
70+
def test_dynamic_shape_neg(self, _, min_shape, opt_shape, max_shape, type):
71+
class neg(nn.Module):
72+
def forward(self, input):
73+
return torch.ops.aten.neg.default(input)
74+
75+
input_specs = [
76+
Input(
77+
min_shape=min_shape,
78+
opt_shape=opt_shape,
79+
max_shape=max_shape,
80+
dtype=type,
81+
),
82+
]
83+
84+
self.run_test_with_dynamic_shape(
85+
neg(), input_specs, check_dtype=(type != torch.int32)
86+
)
87+
4588

4689
if __name__ == "__main__":
4790
run_tests()

0 commit comments

Comments
 (0)