Commit d242b7e

dynamic shapes for slice converter

Adds cases for slicing on a dynamic dimension, handles the case where stop is the max int64 value on the dynamic dimension, and addresses GPT-2 cases where stop is an ITensor.
1 parent 4a0184a
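
For context, a minimal eager-mode sketch (illustrative, not part of the commit) of the two patterns this change targets:

    import sys
    import torch

    x = torch.randn(2, 10, 10)

    # An open-ended slice is traced as aten.slice with stop == sys.maxsize
    # (2**63 - 1), even when the sliced dimension is dynamic.
    y = x[:, 2:]

    # In GPT-2-style graphs the stop value is derived from a runtime shape,
    # so the converter receives it as an ITensor rather than a Python int.
    seq_len = x.shape[1]
    z = x[:, :seq_len]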

File tree

2 files changed: +318 −15 lines

py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py

Lines changed: 164 additions & 9 deletions
@@ -1,4 +1,5 @@
 import math
+import sys
 from typing import Optional, Sequence
 
 import numpy as np
@@ -14,7 +15,14 @@
     get_trt_tensor,
 )
 from torch_tensorrt.dynamo.conversion.impl.cat import cat
+from torch_tensorrt.dynamo.conversion.impl.elementwise import floor_divide
+from torch_tensorrt.dynamo.conversion.impl.elementwise.ops import (
+    convert_binary_elementwise,
+)
+from torch_tensorrt.dynamo.conversion.impl.shape import get_shape_with_dynamic_shape
+from torch_tensorrt.dynamo.conversion.impl.shape import shape as get_shape
 from torch_tensorrt.dynamo.conversion.impl.slice.base import slice
+from torch_tensorrt.dynamo.utils import DYNAMIC_DIM
 from torch_tensorrt.fx.converters.converter_utils import (
     has_dynamic_shape,
     prepend_ones,
@@ -34,29 +42,176 @@ def slice_op( # TODO: This should be slice not whatever is in base
     stop: Optional[int],
     step: int,
 ) -> TRTTensor:
+    # check if dim is same as dynamic shape dimension
+    # this is required when stop is an ITensor
+    dynamic_input_dim_equal = False
+    for i in range(len(input.shape)):
+        if input.shape[i] == DYNAMIC_DIM and i == dim:
+            dynamic_input_dim_equal = True
+
     # Special case for start being None
     if start is None:
         start = 0
 
     # Special case for stop being None
+    stop_dynamic_None = False
+    if stop is None:
+        stop_dynamic_None = True if input.shape[dim] == -1 else False
     if stop is None:
-        stop = input.shape[dim]
+        stop = 0 if input.shape[dim] == -1 else input.shape[dim]
 
     dim = get_positive_dim(dim, len(input.shape))
-    start = get_positive_dim(start, input.shape[dim])
-    stop = get_positive_dim(stop, input.shape[dim])
 
-    if has_dynamic_shape(input.shape):
-        # Check whether slice target dim is dynamic shape dim
-        assert input.shape[dim] != -1, "Can't slice on dynamic shape dimension!"
+    # Assign the initial start tensor
+    start_slice = []
+    # the add_slice will take care of dynamic input shape cases here
+    if isinstance(start, int):
+        start_slice = [0] * len(input.shape)
+        start_slice[dim] = start
+    else:
+        for i in range(len(input.shape)):
+            start_slice.append(0) if i == dim else start_slice.append(start)
+
+    # Assign the initial stop tensor
+    stop_slice = []
+    if isinstance(stop, int) and dynamic_input_dim_equal:
+        stop_slice = input.shape
+        stop_slice[dim] = stop
+    else:
+        # required for cases where stop is an ITensor and dim != dynamic dim of input
+        # not required for cases where stop is negative and dim != dynamic dim of input
+        for i in range(len(input.shape)):
+            if input.shape[i] == DYNAMIC_DIM and i != dim:
+                stop_slice.append(
+                    get_shape(
+                        ctx, target, source_ir, name + f"_shape_dim_stop_{i}", input, i
+                    )
+                )
+            elif i == dim:
+                stop_slice.append(stop)
+            else:
+                stop_slice.append(input.shape[i])
 
-    start_slice = [0] * len(input.shape)
-    start_slice[dim] = start
     stride_slice = [1] * len(input.shape)
     stride_slice[dim] = step
     output_shape = list(input.shape)
-    output_shape[dim] = math.ceil((stop - start) / step)
 
+    if input.shape[dim] != -1 and isinstance(start, int) and isinstance(stop, int):
+        start = get_positive_dim(start, input.shape[dim])
+        stop = get_positive_dim(stop, input.shape[dim])
+        start_slice[dim] = start
+    else:
+        # start or stop is None or dynamic along dim, or start or stop is an ITensor
+        if (
+            not (isinstance(start, int))
+            or not (isinstance(stop, int))
+            or start < 0
+            or stop < 0
+            or stop_dynamic_None
+            or stop == sys.maxsize
+        ):
+            # special assignments for dynamic cases
+            if isinstance(start, int) and start < 0:
+                start_slice = input.shape
+                start_slice[dim] = -1 * start
+            if (isinstance(stop, int) and stop < 0) or stop_dynamic_None:
+                stop_slice = [0] * len(input.shape)
+                stop_slice[dim] = -1 * stop
+            if stop == sys.maxsize:
+                stop_slice = [0] * len(input.shape)
+            start_slice_tensor = cat(
+                ctx,
+                target,
+                source_ir,
+                name + "_start_slice_concat",
+                tuple(start_slice),
+                0,
+                cast_dtype=trt.int32,
+            )
+            stop_slice_tensor = cat(
+                ctx,
+                target,
+                source_ir,
+                name + "_stop_slice_concat",
+                tuple(stop_slice),
+                0,
+                cast_dtype=trt.int32,
+            )
+            stride_slice_tensor = cat(
+                ctx,
+                target,
+                source_ir,
+                name + "_stride_slice_concat",
+                tuple(stride_slice),
+                0,
+                cast_dtype=trt.int32,
+            )
+
+            if isinstance(start, int) and start < 0:
+                shape = get_shape_with_dynamic_shape(
+                    ctx, target, source_ir, name, output_shape, input
+                )
+                start_slice_tensor = convert_binary_elementwise(
+                    ctx,
+                    target,
+                    source_ir,
+                    name + "_sub_start",
+                    trt.ElementWiseOperation.SUB,
+                    shape,
+                    start_slice_tensor,
+                )
+            if isinstance(stop, int) and (
+                (stop < 0) or stop_dynamic_None or stop == sys.maxsize
+            ):
+                shape = get_shape_with_dynamic_shape(
+                    ctx, target, source_ir, name, output_shape, input
+                )
+                stop_slice_tensor = convert_binary_elementwise(
+                    ctx,
+                    target,
+                    source_ir,
+                    name + "_sub_stop",
+                    trt.ElementWiseOperation.SUB,
+                    shape,
+                    stop_slice_tensor,
+                )
+
+            # this is required for the ceil operation
+            output_shape_tensor_num = convert_binary_elementwise(
+                ctx,
+                target,
+                source_ir,
+                name + "_sub_num",
+                trt.ElementWiseOperation.SUB,
+                start_slice_tensor,
+                stop_slice_tensor,
+            )
+            output_shape_tensor_neg = floor_divide(
+                ctx,
+                target,
+                source_ir,
+                name + "_div",
+                output_shape_tensor_num,
+                stride_slice_tensor,
+            )
+            output_shape_tensor = convert_binary_elementwise(
+                ctx,
+                target,
+                source_ir,
+                name + "_prod",
+                trt.ElementWiseOperation.PROD,
+                output_shape_tensor_neg,
+                -1,
+            )
+            layer = ctx.net.add_slice(
+                input, start=trt.Dims(), shape=trt.Dims(), stride=trt.Dims()
+            )
+            layer.set_input(1, start_slice_tensor)
+            layer.set_input(2, output_shape_tensor)
+            layer.set_input(3, stride_slice_tensor)
+            return layer.get_output(0)
+
+    output_shape[dim] = math.ceil((stop - start) / step)
     return slice(
         ctx, target, source_ir, name, input, start_slice, output_shape, stride_slice
    )
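
The _sub_num -> _div -> _prod sequence above computes the sliced length with integer tensor ops only: it builds -((start - stop) // step), which equals ceil((stop - start) / step) for a positive step. A minimal sketch of that identity, using plain Python ints in place of ITensors:

    import math

    def slice_len(start: int, stop: int, step: int) -> int:
        # -((start - stop) // step) == ceil((stop - start) / step) for step > 0,
        # mirroring the converter's SUB -> floor_divide -> PROD(-1) layers.
        return -((start - stop) // step)

    assert slice_len(0, 7, 2) == math.ceil((7 - 0) / 2) == 4
    assert slice_len(1, 10, 3) == math.ceil((10 - 1) / 3) == 3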

tests/py/dynamo/conversion/test_slice_aten.py

Lines changed: 154 additions & 6 deletions
@@ -1,7 +1,6 @@
 import torch
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
-
 from torch_tensorrt import Input
 
 from .harness import DispatchTestCase
@@ -53,11 +52,159 @@ def forward(self, input):
 class TestSliceConverterDynamicShape(DispatchTestCase):
     @parameterized.expand(
         [
-            ("slice_dim_start_stop_step", 1, 0, 7, 2),
-            ("slice_dim_start_stop_step", 1, 0, 10, 2),
+            (
+                "slice_dynamic_dim_start_stop_step_offset",
+                (1, 10, 1),
+                (1, 10, 10),
+                (1, 10, 10),
+                1,
+                0,
+                7,
+                2,
+            ),
+            (
+                "slice_dynamic_dim_start_stop_step",
+                (1, 10, 1),
+                (1, 10, 10),
+                (1, 10, 10),
+                1,
+                0,
+                10,
+                2,
+            ),
+            (
+                "slice_dynamic_dim_start_stop_step_negatives",
+                (1, 10, 10),
+                (10, 10, 10),
+                (10, 10, 10),
+                -2,
+                -2,
+                -1,
+                1,
+            ),
+            (
+                "slice_dim_start_stop_step_max_int",
+                (1, 10, 10),
+                (10, 10, 10),
+                (10, 10, 10),
+                2,
+                0,
+                2**63 - 1,
+                1,
+            ),
+            (
+                "slice_dim_start_stop_step_past_end",
+                (1, 10, 10),
+                (10, 10, 10),
+                (10, 10, 10),
+                2,
+                0,
+                2048,
+                1,
+            ),
+            (
+                "slice_dim_start_stop_step_none",
+                (1, 10, 10),
+                (10, 10, 10),
+                (10, 10, 10),
+                2,
+                None,
+                None,
+                1,
+            ),
+            (
+                "slice_dynamic_dim_start_stop_step_offset_4D",
+                (1, 10, 1, 3),
+                (1, 10, 10, 3),
+                (1, 10, 10, 3),
+                1,
+                0,
+                7,
+                2,
+            ),
+            (
+                "slice_dynamic_dim_start_stop_step_4D",
+                (1, 10, 1, 3),
+                (1, 10, 10, 3),
+                (1, 10, 10, 3),
+                1,
+                0,
+                10,
+                2,
+            ),
+            (
+                "slice_dynamic_dim_dyn_start_dyn_stop_step",
+                (1, 10, 1),
+                (1, 10, 10),
+                (1, 10, 10),
+                2,
+                -2,
+                10,
+                2,
+            ),
+            (
+                "slice_dynamic_dim_dyn_start_stop_dyn_step",
+                (1, 10, 1),
+                (1, 10, 10),
+                (1, 10, 10),
+                2,
+                0,
+                -2,
+                2,
+            ),
+            (
+                "slice_dynamic_dim_dyn_start_stop_None_step",
+                (1, 10, 1),
+                (1, 10, 10),
+                (1, 10, 10),
+                2,
+                0,
+                None,
+                2,
+            ),
+            (
+                "slice_dynamic_dim_dyn_start_dyn_stop_dyn_step",
+                (1, 10, 1),
+                (1, 10, 10),
+                (1, 10, 10),
+                2,
+                -8,
+                -2,
+                2,
+            ),
+            (
+                "slice_dynamic_dim_dyn_start_dyn_stop_dyn_step_ceil",
+                (1, 10, 1),
+                (1, 10, 10),
+                (1, 10, 10),
+                2,
+                -9,
+                -2,
+                2,
+            ),
+            (
+                "slice_dynamic_dim_dyn_start_dyn_stop_dyn_step_diff_dim",
+                (1, 10, 1),
+                (1, 10, 10),
+                (1, 10, 10),
+                0,
+                -8,
+                -2,
+                2,
+            ),
+            (
+                "slice_dynamic_dim_dyn_start_dyn_stop_dyn_step_diff_dim_ceil",
+                (1, 10, 1),
+                (1, 10, 10),
+                (1, 10, 10),
+                0,
+                -9,
+                -2,
+                2,
+            ),
         ]
     )
-    def test_slice(self, _, dim, start, stop, step):
+    def test_slice(self, _, min_shape, opt_shape, max_shape, dim, start, stop, step):
         class TestModule(torch.nn.Module):
             def __init__(self):
                 super().__init__()
@@ -68,9 +215,10 @@ def forward(self, input):
 
         input_specs = [
             Input(
-                shape=(1, 10, -1),
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
                 dtype=torch.float32,
-                shape_ranges=[((1, 10, 1), (1, 10, 10), (1, 10, 10))],
             ),
         ]
         self.run_test_with_dynamic_shape(