
Commit aa9873a

Adding cases for slicing on dynamic dimension
1 parent 2e1bac2 commit aa9873a

File tree

2 files changed (+167, -9 lines)

  py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py
  tests/py/dynamo/conversion/test_slice_aten.py

py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py

Lines changed: 117 additions & 9 deletions
@@ -14,6 +14,11 @@
     get_trt_tensor,
 )
 from torch_tensorrt.dynamo.conversion.impl.cat import cat
+from torch_tensorrt.dynamo.conversion.impl.elementwise import floor_divide
+from torch_tensorrt.dynamo.conversion.impl.elementwise.ops import (
+    convert_binary_elementwise,
+)
+from torch_tensorrt.dynamo.conversion.impl.shape import get_shape_with_dynamic_shape
 from torch_tensorrt.dynamo.conversion.impl.slice.base import slice
 from torch_tensorrt.fx.converters.converter_utils import (
     has_dynamic_shape,
@@ -39,24 +44,127 @@ def slice_op(  # TODO: This should be slice not whatever is in base
         start = 0
 
     # Special case for stop being None
+    stop_dynamic_None = False
+    if stop is None:
+        stop_dynamic_None = True if input.shape[dim] == -1 else False
     if stop is None:
-        stop = input.shape[dim]
+        stop = 0 if input.shape[dim] == -1 else input.shape[dim]
 
     dim = get_positive_dim(dim, len(input.shape))
-    start = get_positive_dim(start, input.shape[dim])
-    stop = get_positive_dim(stop, input.shape[dim])
-
-    if has_dynamic_shape(input.shape):
-        # Check whether slice target dim is dynamic shape dim
-        assert input.shape[dim] != -1, "Can't slice on dynamic shape dimension!"
-
     start_slice = [0] * len(input.shape)
     start_slice[dim] = start
+    stop_slice = input.shape
+    stop_slice[dim] = stop
     stride_slice = [1] * len(input.shape)
     stride_slice[dim] = step
     output_shape = list(input.shape)
-    output_shape[dim] = math.ceil((stop - start) / step)
 
+    if input.shape[dim] != -1:
+        start = get_positive_dim(start, input.shape[dim])
+        stop = get_positive_dim(stop, input.shape[dim])
+        start_slice[dim] = start
+    else:
+        # the start and stop or None is dynamic along dim
+        if start < 0 or stop < 0 or stop_dynamic_None:
+            # special assignments for dynamic cases
+            if start < 0:
+                start_slice = input.shape
+                start_slice[dim] = -1 * start
+            if stop < 0 or stop_dynamic_None:
+                stop_slice = [0] * len(input.shape)
+                stop_slice[dim] = -1 * stop
+
+            start_slice_tensor = cat(
+                ctx,
+                target,
+                source_ir,
+                name + "_start_slice_concat",
+                tuple(start_slice),
+                0,
+                cast_dtype=trt.int32,
+            )
+            stop_slice_tensor = cat(
+                ctx,
+                target,
+                source_ir,
+                name + "_stop_slice_concat",
+                tuple(stop_slice),
+                0,
+                cast_dtype=trt.int32,
+            )
+            stride_slice_tensor = cat(
+                ctx,
+                target,
+                source_ir,
+                name + "_stride_slice_concat",
+                tuple(stride_slice),
+                0,
+                cast_dtype=trt.int32,
+            )
+
+            if start < 0:
+                shape = get_shape_with_dynamic_shape(
+                    ctx, target, source_ir, name, output_shape, input
+                )
+                start_slice_tensor = convert_binary_elementwise(
+                    ctx,
+                    target,
+                    source_ir,
+                    name + "_sub_start",
+                    trt.ElementWiseOperation.SUB,
+                    shape,
+                    start_slice_tensor,
+                )
+            if (stop < 0) or stop_dynamic_None:
+                shape = get_shape_with_dynamic_shape(
+                    ctx, target, source_ir, name, output_shape, input
+                )
+                stop_slice_tensor = convert_binary_elementwise(
+                    ctx,
+                    target,
+                    source_ir,
+                    name + "_sub_stop",
+                    trt.ElementWiseOperation.SUB,
+                    shape,
+                    stop_slice_tensor,
+                )
+
+            # this is required for the ceil operation
+            output_shape_tensor_num = convert_binary_elementwise(
+                ctx,
+                target,
+                source_ir,
+                name + "_sub_num",
+                trt.ElementWiseOperation.SUB,
+                start_slice_tensor,
+                stop_slice_tensor,
+            )
+            output_shape_tensor_neg = floor_divide(
+                ctx,
+                target,
+                source_ir,
+                name + "_div",
+                output_shape_tensor_num,
+                stride_slice_tensor,
+            )
+            output_shape_tensor = convert_binary_elementwise(
+                ctx,
+                target,
+                source_ir,
+                name + "_prod",
+                trt.ElementWiseOperation.PROD,
+                output_shape_tensor_neg,
+                -1,
+            )
+            layer = ctx.net.add_slice(
+                input, start=trt.Dims(), shape=trt.Dims(), stride=trt.Dims()
+            )
+            layer.set_input(1, start_slice_tensor)
+            layer.set_input(2, output_shape_tensor)
+            layer.set_input(3, stride_slice_tensor)
+            return layer.get_output(0)
+
+    output_shape[dim] = math.ceil((stop - start) / step)
     return slice(
         ctx, target, source_ir, name, input, start_slice, output_shape, stride_slice
     )
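As a reference for the converter change above, here is a minimal eager-mode sketch of the shape arithmetic the new dynamic-dimension path assembles out of TensorRT layers: negative (or None) start/stop values are resolved against the runtime size of the sliced dimension, and the output length ceil((stop - start) / step) is computed as a negated floor division, mirroring the SUB -> floor_divide -> PROD(-1) sequence in the diff. The helper name expected_slice_len and the use of a plain int for the runtime size are illustrative assumptions, not part of the converter.

def expected_slice_len(runtime_dim_size: int, start, stop, step: int) -> int:
    # Resolve None and negative indices against the runtime dimension size,
    # as the converter does via get_shape_with_dynamic_shape + SUB.
    if start is None:
        start = 0
    if stop is None:
        stop = runtime_dim_size
    if start < 0:
        start += runtime_dim_size
    if stop < 0:
        stop += runtime_dim_size
    # ceil((stop - start) / step) written as a negated floor division,
    # matching the SUB -> floor_divide -> PROD(-1) layers above.
    return -((start - stop) // step)

assert expected_slice_len(10, -8, -2, 2) == 3   # dyn_start_dyn_stop_dyn_step case
assert expected_slice_len(10, -9, -2, 2) == 4   # ..._ceil case: ceil(7 / 2)
assert expected_slice_len(10, 0, None, 2) == 5  # stop is None on the dynamic dim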

tests/py/dynamo/conversion/test_slice_aten.py

Lines changed: 50 additions & 0 deletions
@@ -132,6 +132,56 @@ class TestSliceConverterDynamicShape(DispatchTestCase):
                 10,
                 2,
             ),
+            (
+                "slice_dynamic_dim_dyn_start_dyn_stop_step",
+                (1, 10, 1),
+                (1, 10, 10),
+                (1, 10, 10),
+                2,
+                -2,
+                10,
+                2,
+            ),
+            (
+                "slice_dynamic_dim_dyn_start_stop_dyn_step",
+                (1, 10, 1),
+                (1, 10, 10),
+                (1, 10, 10),
+                2,
+                0,
+                -2,
+                2,
+            ),
+            (
+                "slice_dynamic_dim_dyn_start_stop_None_step",
+                (1, 10, 1),
+                (1, 10, 10),
+                (1, 10, 10),
+                2,
+                0,
+                None,
+                2,
+            ),
+            (
+                "slice_dynamic_dim_dyn_start_dyn_stop_dyn_step",
+                (1, 10, 1),
+                (1, 10, 10),
+                (1, 10, 10),
+                2,
+                -8,
+                -2,
+                2,
+            ),
+            (
+                "slice_dynamic_dim_dyn_start_dyn_stop_dyn_step_ceil",
+                (1, 10, 1),
+                (1, 10, 10),
+                (1, 10, 10),
+                2,
+                -9,
+                -2,
+                2,
+            ),
         ]
     )
     def test_slice(self, _, min_shape, opt_shape, max_shape, dim, start, stop, step):
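Each new tuple feeds test_slice as (name, min_shape, opt_shape, max_shape, dim, start, stop, step), with dim 2 dynamic between the min and max shapes. As a sanity check of the expected eager-mode behavior (not part of the test harness; the class name is illustrative), the "slice_dynamic_dim_dyn_start_dyn_stop_dyn_step" case corresponds roughly to:

import torch

class SliceDynStartStop(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # dim=2, start=-8, stop=-2, step=2, i.e. aten.slice.Tensor(x, 2, -8, -2, 2)
        return x[:, :, -8:-2:2]

# At the max shape (1, 10, 10) the sliced dim keeps indices 2, 4, 6.
out = SliceDynStartStop()(torch.randn(1, 10, 10))
assert tuple(out.shape) == (1, 10, 3)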
