
Commit 0ef880d

dynamic shape for slice converter (#2901)

1 parent: c0a2bea
File tree: 2 files changed, +315 -15 lines
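What the change enables, as a rough sketch: the dynamo slice converter can now handle inputs that are dynamic along (or alongside) the sliced dimension, instead of asserting out. The module and shape bounds below are illustrative, not part of the diff; they mirror the first test case added here ((1, 10, 1) min to (1, 10, 10) max, slicing dim 1 with start=0, stop=7, step=2).

# Illustrative sketch only: compile a slice over an input with a dynamic last dim.
import torch
import torch_tensorrt

class SliceModel(torch.nn.Module):
    def forward(self, x):
        # lowers to aten.slice(dim=1, start=0, stop=7, step=2)
        return x[:, 0:7:2]

inputs = [
    torch_tensorrt.Input(
        min_shape=(1, 10, 1),
        opt_shape=(1, 10, 10),
        max_shape=(1, 10, 10),
        dtype=torch.float32,
    )
]
trt_model = torch_tensorrt.compile(SliceModel().eval().cuda(), ir="dynamo", inputs=inputs)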

py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py

Lines changed: 161 additions & 9 deletions
@@ -1,4 +1,5 @@
 import math
+import sys
 from typing import Optional, Sequence
 
 import numpy as np
@@ -14,6 +15,11 @@
     get_trt_tensor,
 )
 from torch_tensorrt.dynamo.conversion.impl.cat import cat
+from torch_tensorrt.dynamo.conversion.impl.elementwise import floor_divide
+from torch_tensorrt.dynamo.conversion.impl.elementwise.ops import (
+    convert_binary_elementwise,
+)
+from torch_tensorrt.dynamo.conversion.impl.shape import get_shape_with_dynamic_shape
 from torch_tensorrt.dynamo.conversion.impl.shape import shape as get_shape
 from torch_tensorrt.dynamo.conversion.impl.slice.base import slice
 from torch_tensorrt.dynamo.utils import DYNAMIC_DIM
@@ -36,29 +42,175 @@ def slice_op(  # TODO: This should be slice not whatever is in base
     stop: Optional[int],
     step: int,
 ) -> TRTTensor:
+    # check whether dim is the dynamic shape dimension of the input;
+    # this is required when stop is an ITensor
+    dynamic_input_dim_equal = False
+    for i in range(len(input.shape)):
+        if input.shape[i] == DYNAMIC_DIM and i == dim:
+            dynamic_input_dim_equal = True
+
     # Special case for start being None
     if start is None:
         start = 0
 
     # Special case for stop being None
+    stop_dynamic_None = False
     if stop is None:
-        stop = input.shape[dim]
+        stop_dynamic_None = input.shape[dim] == -1
+        stop = 0 if input.shape[dim] == -1 else input.shape[dim]
 
     dim = get_positive_dim(dim, len(input.shape))
-    start = get_positive_dim(start, input.shape[dim])
-    stop = get_positive_dim(stop, input.shape[dim])
 
-    if has_dynamic_shape(input.shape):
-        # Check whether slice target dim is dynamic shape dim
-        assert input.shape[dim] != -1, "Can't slice on dynamic shape dimension!"
+    # Assign the initial start tensor
+    start_slice = []
+    # the add_slice will take care of dynamic input shape cases here
+    if isinstance(start, int):
+        start_slice = [0] * len(input.shape)
+        start_slice[dim] = start
+    else:
+        for i in range(len(input.shape)):
+            start_slice.append(start if i == dim else 0)
+
+    # Assign the initial stop tensor
+    stop_slice = []
+    if isinstance(stop, int) and dynamic_input_dim_equal:
+        stop_slice = input.shape
+        stop_slice[dim] = stop
+    else:
+        # required for cases where stop is an ITensor and dim != dynamic dim of input;
+        # not required for cases where stop is negative and dim != dynamic dim of input
+        for i in range(len(input.shape)):
+            if input.shape[i] == DYNAMIC_DIM and i != dim:
+                stop_slice.append(
+                    get_shape(
+                        ctx, target, source_ir, name + f"_shape_dim_stop_{i}", input, i
+                    )
+                )
+            elif i == dim:
+                stop_slice.append(stop)
+            else:
+                stop_slice.append(input.shape[i])
 
-    start_slice = [0] * len(input.shape)
-    start_slice[dim] = start
     stride_slice = [1] * len(input.shape)
     stride_slice[dim] = step
     output_shape = list(input.shape)
-    output_shape[dim] = math.ceil((stop - start) / step)
 
+    if input.shape[dim] != -1 and isinstance(start, int) and isinstance(stop, int):
+        start = get_positive_dim(start, input.shape[dim])
+        stop = get_positive_dim(stop, input.shape[dim])
+        start_slice[dim] = start
+    else:
+        # start or stop is None, negative, an ITensor, or the input is dynamic along dim
+        if (
+            not (isinstance(start, int))
+            or not (isinstance(stop, int))
+            or start < 0
+            or stop < 0
+            or stop_dynamic_None
+            or stop == sys.maxsize
+        ):
+            # special assignments for dynamic cases
+            if isinstance(start, int) and start < 0:
+                start_slice = input.shape
+                start_slice[dim] = -1 * start
+            if (isinstance(stop, int) and stop < 0) or stop_dynamic_None:
+                stop_slice = [0] * len(input.shape)
+                stop_slice[dim] = -1 * stop
+            if stop == sys.maxsize:
+                stop_slice = [0] * len(input.shape)
+            start_slice_tensor = cat(
+                ctx,
+                target,
+                source_ir,
+                name + "_start_slice_concat",
+                tuple(start_slice),
+                0,
+                cast_dtype=trt.int32,
+            )
+            stop_slice_tensor = cat(
+                ctx,
+                target,
+                source_ir,
+                name + "_stop_slice_concat",
+                tuple(stop_slice),
+                0,
+                cast_dtype=trt.int32,
+            )
+            stride_slice_tensor = cat(
+                ctx,
+                target,
+                source_ir,
+                name + "_stride_slice_concat",
+                tuple(stride_slice),
+                0,
+                cast_dtype=trt.int32,
+            )
+
+            if isinstance(start, int) and start < 0:
+                shape = get_shape_with_dynamic_shape(
+                    ctx, target, source_ir, name, output_shape, input
+                )
+                start_slice_tensor = convert_binary_elementwise(
+                    ctx,
+                    target,
+                    source_ir,
+                    name + "_sub_start",
+                    trt.ElementWiseOperation.SUB,
+                    shape,
+                    start_slice_tensor,
+                )
+            if isinstance(stop, int) and (
+                (stop < 0) or stop_dynamic_None or stop == sys.maxsize
+            ):
+                shape = get_shape_with_dynamic_shape(
+                    ctx, target, source_ir, name, output_shape, input
+                )
+                stop_slice_tensor = convert_binary_elementwise(
+                    ctx,
+                    target,
+                    source_ir,
+                    name + "_sub_stop",
+                    trt.ElementWiseOperation.SUB,
+                    shape,
+                    stop_slice_tensor,
+                )
+
+            # this is required for the ceil operation
+            output_shape_tensor_num = convert_binary_elementwise(
+                ctx,
+                target,
+                source_ir,
+                name + "_sub_num",
+                trt.ElementWiseOperation.SUB,
+                start_slice_tensor,
+                stop_slice_tensor,
+            )
+            output_shape_tensor_neg = floor_divide(
+                ctx,
+                target,
+                source_ir,
+                name + "_div",
+                output_shape_tensor_num,
+                stride_slice_tensor,
+            )
+            output_shape_tensor = convert_binary_elementwise(
+                ctx,
+                target,
+                source_ir,
+                name + "_prod",
+                trt.ElementWiseOperation.PROD,
+                output_shape_tensor_neg,
+                -1,
+            )
+            layer = ctx.net.add_slice(
+                input, start=trt.Dims(), shape=trt.Dims(), stride=trt.Dims()
+            )
+            layer.set_input(1, start_slice_tensor)
+            layer.set_input(2, output_shape_tensor)
+            layer.set_input(3, stride_slice_tensor)
+            return layer.get_output(0)
+
+    output_shape[dim] = math.ceil((stop - start) / step)
     return slice(
         ctx, target, source_ir, name, input, start_slice, output_shape, stride_slice
     )
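A note on the arithmetic in the dynamic branch above: the output length is built from three elementwise layers, SUB ("_sub_num": start - stop), floor_divide ("_div"), and PROD with -1 ("_prod"). That is the standard floor-division formulation of a ceiling, since -((start - stop) // step) == math.ceil((stop - start) / step) for positive step. Negative start/stop values are first rewritten to shape - |index| by the "_sub_start"/"_sub_stop" layers, and stop == sys.maxsize (Python's x[start:] lowering) zeroes the stop vector so the subtraction yields the runtime shape itself. A minimal plain-Python sketch of the same computation (the function name and runtime_len are illustrative, not from the diff):

# Standalone sketch of the arithmetic the dynamic branch encodes as TRT
# elementwise layers; runtime_len stands in for the shape tensor that
# get_shape_with_dynamic_shape produces at runtime.
def dynamic_slice_len(start: int, stop: int, step: int, runtime_len: int) -> int:
    # negative indices are rewritten as (shape - |index|), mirroring the
    # "_sub_start" / "_sub_stop" ElementWiseOperation.SUB layers
    if start < 0:
        start = runtime_len - (-start)
    if stop < 0:
        stop = runtime_len - (-stop)
    # SUB, then floor_divide, then PROD with -1: -((start - stop) // step),
    # which equals ceil((stop - start) / step) for positive step
    return -((start - stop) // step)

assert dynamic_slice_len(0, 7, 2, 10) == 4    # matches math.ceil((7 - 0) / 2)
assert dynamic_slice_len(-9, -2, 2, 10) == 4  # the "_ceil" test cases below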

tests/py/dynamo/conversion/test_slice_aten.py

Lines changed: 154 additions & 6 deletions
@@ -1,7 +1,6 @@
 import torch
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
-
 from torch_tensorrt import Input
 
 from .harness import DispatchTestCase
@@ -53,11 +52,159 @@ def forward(self, input):
 class TestSliceConverterDynamicShape(DispatchTestCase):
     @parameterized.expand(
         [
-            ("slice_dim_start_stop_step", 1, 0, 7, 2),
-            ("slice_dim_start_stop_step", 1, 0, 10, 2),
+            # (name, min_shape, opt_shape, max_shape, dim, start, stop, step)
+            ("slice_dynamic_dim_start_stop_step_offset", (1, 10, 1), (1, 10, 10), (1, 10, 10), 1, 0, 7, 2),
+            ("slice_dynamic_dim_start_stop_step", (1, 10, 1), (1, 10, 10), (1, 10, 10), 1, 0, 10, 2),
+            ("slice_dynamic_dim_start_stop_step_negatives", (1, 10, 10), (10, 10, 10), (10, 10, 10), -2, -2, -1, 1),
+            ("slice_dim_start_stop_step_max_int", (1, 10, 10), (10, 10, 10), (10, 10, 10), 2, 0, 2**63 - 1, 1),
+            ("slice_dim_start_stop_step_past_end", (1, 10, 10), (10, 10, 10), (10, 10, 10), 2, 0, 2048, 1),
+            ("slice_dim_start_stop_step_none", (1, 10, 10), (10, 10, 10), (10, 10, 10), 2, None, None, 1),
+            ("slice_dynamic_dim_start_stop_step_offset_4D", (1, 10, 1, 3), (1, 10, 10, 3), (1, 10, 10, 3), 1, 0, 7, 2),
+            ("slice_dynamic_dim_start_stop_step_4D", (1, 10, 1, 3), (1, 10, 10, 3), (1, 10, 10, 3), 1, 0, 10, 2),
+            ("slice_dynamic_dim_dyn_start_dyn_stop_step", (1, 10, 1), (1, 10, 10), (1, 10, 10), 2, -2, 10, 2),
+            ("slice_dynamic_dim_dyn_start_stop_dyn_step", (1, 10, 1), (1, 10, 10), (1, 10, 10), 2, 0, -2, 2),
+            ("slice_dynamic_dim_dyn_start_stop_None_step", (1, 10, 1), (1, 10, 10), (1, 10, 10), 2, 0, None, 2),
+            ("slice_dynamic_dim_dyn_start_dyn_stop_dyn_step", (1, 10, 1), (1, 10, 10), (1, 10, 10), 2, -8, -2, 2),
+            ("slice_dynamic_dim_dyn_start_dyn_stop_dyn_step_ceil", (1, 10, 1), (1, 10, 10), (1, 10, 10), 2, -9, -2, 2),
+            ("slice_dynamic_dim_dyn_start_dyn_stop_dyn_step_diff_dim", (1, 10, 1), (1, 10, 10), (1, 10, 10), 0, -8, -2, 2),
+            ("slice_dynamic_dim_dyn_start_dyn_stop_dyn_step_diff_dim_ceil", (1, 10, 1), (1, 10, 10), (1, 10, 10), 0, -9, -2, 2),
         ]
     )
-    def test_slice(self, _, dim, start, stop, step):
+    def test_slice(self, _, min_shape, opt_shape, max_shape, dim, start, stop, step):
         class TestModule(torch.nn.Module):
             def __init__(self):
                 super().__init__()
@@ -68,9 +215,10 @@ def forward(self, input):
 
         input_specs = [
             Input(
-                shape=(1, 10, -1),
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
                 dtype=torch.float32,
-                shape_ranges=[((1, 10, 1), (1, 10, 10), (1, 10, 10))],
             ),
         ]
         self.run_test_with_dynamic_shape(
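The hunk does not show TestModule.forward, but each parameterized case presumably exercises a single aten.slice call with the given (dim, start, stop, step). For instance, the slice_dynamic_dim_start_stop_step_negatives case corresponds to the following eager-mode slice (a hedged illustration, not code from the diff):

# Illustrative eager-mode equivalent of one test case:
# dim=-2, start=-2, stop=-1, step=1 on a (10, 10, 10) input.
import torch

x = torch.randn(10, 10, 10)
y = torch.ops.aten.slice.Tensor(x, -2, -2, -1, 1)
assert y.shape == (10, 1, 10)
assert torch.equal(y, x[:, -2:-1, :])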
