Commit 2cf4eaa

dynamic shape for squeeze ops

No support for DDS (data-dependent shape) inputs.

1 parent c0a2bea · commit 2cf4eaa

File tree

2 files changed: +117 −19 lines


py/torch_tensorrt/dynamo/conversion/impl/squeeze.py

Lines changed: 84 additions & 13 deletions
@@ -1,12 +1,17 @@
 from typing import Optional, Sequence, Union
 
+import numpy as np
+import tensorrt as trt
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
-from torch_tensorrt.dynamo.conversion.converter_utils import get_positive_dim
-from torch_tensorrt.fx.converters.converter_utils import set_layer_name
+from torch_tensorrt.dynamo.conversion.converter_utils import (
+    cast_trt_tensor,
+    get_positive_dim,
+    set_layer_name,
+)
+from torch_tensorrt.dynamo.conversion.impl.elementwise import ne
 from torch_tensorrt.fx.types import TRTTensor
-from torch_tensorrt.fx.utils import get_dynamic_dims
 
 
 def squeeze(
@@ -29,24 +34,90 @@ def squeeze(
             dims.append(dim)
 
     new_dims = []
+    dim_has_dynamic_shape = False
     for dim in dims:
         dim = get_positive_dim(
             dim,
             len(input.shape),
         )
 
-        assert input.shape[dim] != -1, "We don't support squeeze dynamic dim."
-        assert (
-            len(get_dynamic_dims(input.shape)) <= 1
-        ), "Currently more than one dynamic dim for input to squeeze is not supported."
+        if input.shape[dim] == -1:
+            dim_has_dynamic_shape = True
         new_dims.append(dim)
 
-    output_shape = []
-    for i, s in enumerate(input.shape):
-        if (i in new_dims) and s == 1:
-            continue
-        output_shape.append(s)
     layer = ctx.net.add_shuffle(input)
-    layer.reshape_dims = tuple(output_shape)
     set_layer_name(layer, target, name, source_ir)
+    if dim_has_dynamic_shape:
+        num_shape = len(input.shape)
+
+        tensor_shape_layer = ctx.net.add_shape(input)
+        tensor_shape = tensor_shape_layer.get_output(0)
+        tensor_shape = cast_trt_tensor(
+            ctx, tensor_shape, trt.int32, name + "shape_casted", "shape"
+        )
+
+        # TODO: change this to get_trt_tensor
+        one_layer = ctx.net.add_constant(
+            (num_shape,),
+            np.ascontiguousarray([1] * num_shape, np.int32),
+        )
+        set_layer_name(one_layer, target, name + "_one", source_ir)
+
+        zero_layer = ctx.net.add_constant(
+            (num_shape,),
+            np.zeros((num_shape,), dtype=np.int32),
+        )
+        set_layer_name(zero_layer, target, name + "_zero", source_ir)
+
+        # pad the index tensor to rank length by repeating the last squeeze dim
+        num_append = num_shape - len(new_dims)
+        if num_append > 0:
+            new_dims += [new_dims[-1]] * num_append
+
+        index_value = np.array(new_dims, dtype=np.int32)
+        index_layer = ctx.net.add_constant(index_value.shape, index_value)
+        set_layer_name(index_layer, target, name + "_index", source_ir)
+
+        scatter_layer = ctx.net.add_scatter(
+            zero_layer.get_output(0),
+            index_layer.get_output(0),
+            one_layer.get_output(0),
+            trt.ScatterMode.ELEMENT,
+        )
+        set_layer_name(scatter_layer, target, name + "_scatter", source_ir)
+
+        # e.g. runtime shape (1, 2, 1, 3, 1) with squeeze dim 2:
+        # scattered mask    [0, 0, 1, 0, 0]
+        # shape != mask ->  [t, t, f, t, t]
+        ne_tensor = ne(
+            ctx,
+            target,
+            source_ir,
+            name + "_ne",
+            tensor_shape,
+            scatter_layer.get_output(0),
+        )
+
+        # [t, t, f, t, t] -> indices [0, 1, 3, 4]
+        non_zero_layer = ctx.net.add_non_zero(ne_tensor)
+        set_layer_name(non_zero_layer, target, name + "_non_zero", source_ir)
+
+        non_zero_shuffle_layer = ctx.net.add_shuffle(non_zero_layer.get_output(0))
+        set_layer_name(non_zero_shuffle_layer, target, name + "_shuffle", source_ir)
+        non_zero_shuffle_layer.second_transpose = (1, 0)
+
+        # gather (1, 2, 1, 3, 1) at [0, 1, 3, 4] -> (1, 2, 3, 1)
+        gather_layer = ctx.net.add_gather_v2(
+            tensor_shape, non_zero_shuffle_layer.get_output(0), mode=trt.GatherMode.ND
+        )
+        set_layer_name(gather_layer, target, name + "_gather", source_ir)
+
+        layer.set_input(1, gather_layer.get_output(0))
+    else:
+        output_shape = []
+        for i, s in enumerate(input.shape):
+            if (i in new_dims) and s == 1:
+                continue
+            output_shape.append(s)
+        layer.reshape_dims = tuple(output_shape)
     return layer.get_output(0)
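
The dynamic branch above derives the squeezed shape entirely in-graph: scatter 1s at the squeeze dims, compare the result against the runtime shape with ne, collect the surviving positions with non_zero, and gather those entries from the shape tensor. A minimal host-side NumPy sketch of the same sequence, for intuition only (the helper squeeze_shape is illustrative, not part of this commit):

import numpy as np

def squeeze_shape(shape, dims):
    shape = np.asarray(shape, dtype=np.int32)
    mask = np.zeros(len(shape), dtype=np.int32)
    mask[list(dims)] = 1            # scatter: mark every squeeze dim with a 1
    keep = shape != mask            # ne: false only where a squeeze dim has size 1
    return shape[np.nonzero(keep)]  # non_zero + gather the surviving dims

print(squeeze_shape((1, 2, 1, 3, 1), [2]))  # [1 2 3 1]

Note that scattering a 1 at the same index more than once is idempotent, which is why the converter can safely pad new_dims to rank length by repeating its last entry before building the index constant.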

tests/py/dynamo/conversion/test_squeeze_aten.py

Lines changed: 33 additions & 6 deletions
@@ -46,25 +46,52 @@ def forward(self, x):
 class TestSqueezeConverter(DispatchTestCase):
     @parameterized.expand(
         [
-            ("2d_dim", (1), (-1, 1), [((1, 1), (1, 1), (3, 1))]),
-            ("3d_one_dim", (1), (-1, 2, 1), [((1, 2, 1), (1, 2, 1), (3, 2, 1))]),
+            (
+                "2d_dim",
+                (1),
+                (1, 1),
+                (1, 1),
+                (3, 1),
+                torch.half,
+                torch.half,
+            ),
+            (
+                "3d_one_dim",
+                (1),
+                (1, 2, 1),
+                (1, 2, 1),
+                (3, 2, 1),
+                torch.float,
+                torch.float,
+            ),
+            (
+                "3d_one_dim_dynamic",
+                (0),
+                (1, 2, 1),
+                (1, 2, 1),
+                (3, 2, 1),
+                torch.float,
+                torch.float,
+            ),
         ]
     )
-    def test_squeeze(self, _, dim, init_size, shape_range):
+    def test_squeeze(self, _, dim, min_shape, opt_shape, max_shape, type, output_type):
         class Squeeze(nn.Module):
             def forward(self, x):
                 return torch.ops.aten.squeeze.dim(x, dim)
 
         input_specs = [
             Input(
-                shape=init_size,
-                dtype=torch.float32,
-                shape_ranges=shape_range,
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+                dtype=type,
             ),
         ]
         self.run_test_with_dynamic_shape(
             Squeeze(),
             input_specs,
+            output_dtypes=[output_type],
         )
7097
