Skip to content

Commit 0d4af77

Browse files
authored
feat: dynamic shape support for squeeze ops (#2994)
1 parent 3a43fd2 commit 0d4af77

File tree

3 files changed

+55
-22
lines changed

3 files changed

+55
-22
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -622,8 +622,8 @@ def aten_ops_quantize_fp8(
622622
)
623623

624624

625-
@dynamo_tensorrt_converter(torch.ops.aten.squeeze.dim)
626-
@dynamo_tensorrt_converter(torch.ops.aten.squeeze.dims)
625+
@dynamo_tensorrt_converter(torch.ops.aten.squeeze.dim, supports_dynamic_shapes=True)
626+
@dynamo_tensorrt_converter(torch.ops.aten.squeeze.dims, supports_dynamic_shapes=True)
627627
def aten_ops_squeeze(
628628
ctx: ConversionContext,
629629
target: Target,

py/torch_tensorrt/dynamo/conversion/impl/squeeze.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,11 @@
33
from torch.fx.node import Target
44
from torch_tensorrt.dynamo._SourceIR import SourceIR
55
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
6-
from torch_tensorrt.dynamo.conversion.converter_utils import get_positive_dim
7-
from torch_tensorrt.fx.converters.converter_utils import set_layer_name
6+
from torch_tensorrt.dynamo.conversion.converter_utils import (
7+
get_positive_dim,
8+
set_layer_name,
9+
)
810
from torch_tensorrt.fx.types import TRTTensor
9-
from torch_tensorrt.fx.utils import get_dynamic_dims
1011

1112

1213
def squeeze(
@@ -25,8 +26,8 @@ def squeeze(
2526
if isinstance(dim, int):
2627
dims.append(dim)
2728
else:
28-
for dim in dim:
29-
dims.append(dim)
29+
for d in dim:
30+
dims.append(d)
3031

3132
new_dims = []
3233
for dim in dims:
@@ -36,17 +37,22 @@ def squeeze(
3637
)
3738

3839
assert input.shape[dim] != -1, "We don't support squeeze dynamic dim."
39-
assert (
40-
len(get_dynamic_dims(input.shape)) <= 1
41-
), "Currently more than one dynamic dim for input to squeeze is not supported."
4240
new_dims.append(dim)
4341

44-
output_shape = []
42+
dim_to_remove = []
43+
new_permutation = []
4544
for i, s in enumerate(input.shape):
4645
if (i in new_dims) and s == 1:
47-
continue
48-
output_shape.append(s)
46+
dim_to_remove.append(i)
47+
else:
48+
new_permutation.append(i)
49+
# If number of reshape dimensions is less than input, 0s are resolved by aligning
50+
# the most significant dimensions of input
51+
output_shape = tuple([0] * len(new_permutation))
52+
new_permutation += dim_to_remove
53+
4954
layer = ctx.net.add_shuffle(input)
50-
layer.reshape_dims = tuple(output_shape)
55+
layer.first_transpose = new_permutation
56+
layer.reshape_dims = output_shape
5157
set_layer_name(layer, target, name, source_ir)
5258
return layer.get_output(0)

tests/py/dynamo/conversion/test_squeeze_aten.py

Lines changed: 35 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -43,23 +43,50 @@ def forward(self, x):
4343
)
4444

4545

46-
class TestSqueezeConverter(DispatchTestCase):
46+
class TestSqueezeConverterDynamic(DispatchTestCase):
4747
@parameterized.expand(
4848
[
49-
("2d_dim", (1), (-1, 1), [((1, 1), (1, 1), (3, 1))]),
50-
("3d_one_dim", (1), (-1, 2, 1), [((1, 2, 1), (1, 2, 1), (3, 2, 1))]),
49+
(
50+
"5d_two_dynamic_shape_-1",
51+
(0,),
52+
(1, 1, 1, 1, 1),
53+
(1, 2, 1, 2, 1),
54+
(1, 4, 1, 3, 1),
55+
),
56+
(
57+
"5d_two_dynamic_shape_-2",
58+
(0, 2),
59+
(1, 1, 1, 1, 1),
60+
(1, 2, 1, 2, 1),
61+
(1, 4, 1, 3, 1),
62+
),
63+
(
64+
"5d_three_dynamic_shape_-2",
65+
(0, 4),
66+
(1, 1, 1, 1, 1),
67+
(1, 2, 4, 2, 1),
68+
(1, 4, 4, 3, 1),
69+
),
70+
(
71+
"4d_two_dynamic_shape_-2",
72+
(0, 2),
73+
(1, 1, 2, 1),
74+
(1, 2, 2, 2),
75+
(1, 4, 2, 3),
76+
),
5177
]
5278
)
53-
def test_squeeze(self, _, dim, init_size, shape_range):
79+
def test_squeeze(self, _, dim, min_shape, opt_shape, max_shape):
5480
class Squeeze(nn.Module):
5581
def forward(self, x):
56-
return torch.ops.aten.squeeze.dim(x, dim)
82+
return torch.ops.aten.squeeze.dims(x, dim)
5783

5884
input_specs = [
5985
Input(
60-
shape=init_size,
61-
dtype=torch.float32,
62-
shape_ranges=shape_range,
86+
min_shape=min_shape,
87+
opt_shape=opt_shape,
88+
max_shape=max_shape,
89+
dtype=torch.float,
6390
),
6491
]
6592
self.run_test_with_dynamic_shape(

0 commit comments

Comments
 (0)