dynamic shape for slice converter #2901


Merged: 2 commits, Jul 11, 2024
170 changes: 161 additions & 9 deletions py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py
@@ -1,4 +1,5 @@
import math
import sys
from typing import Optional, Sequence

import numpy as np
@@ -14,6 +15,11 @@
get_trt_tensor,
)
from torch_tensorrt.dynamo.conversion.impl.cat import cat
from torch_tensorrt.dynamo.conversion.impl.elementwise import floor_divide
from torch_tensorrt.dynamo.conversion.impl.elementwise.ops import (
convert_binary_elementwise,
)
from torch_tensorrt.dynamo.conversion.impl.shape import get_shape_with_dynamic_shape
from torch_tensorrt.dynamo.conversion.impl.shape import shape as get_shape
from torch_tensorrt.dynamo.conversion.impl.slice.base import slice
from torch_tensorrt.dynamo.utils import DYNAMIC_DIM
@@ -36,29 +42,175 @@ def slice_op(  # TODO: This should be slice not whatever is in base
stop: Optional[int],
step: int,
) -> TRTTensor:
# check whether the slice dim is the dynamic-shape dimension of the input;
# this is required when stop is an ITensor
dynamic_input_dim_equal = False
for i in range(len(input.shape)):
if input.shape[i] == DYNAMIC_DIM and i == dim:
dynamic_input_dim_equal = True

# Special case for start being None
if start is None:
start = 0

# Special case for stop being None
stop_dynamic_None = False
if stop is None:
stop = input.shape[dim]
stop_dynamic_None = input.shape[dim] == -1
stop = 0 if input.shape[dim] == -1 else input.shape[dim]

dim = get_positive_dim(dim, len(input.shape))
start = get_positive_dim(start, input.shape[dim])
stop = get_positive_dim(stop, input.shape[dim])

if has_dynamic_shape(input.shape):
# Check whether slice target dim is dynamic shape dim
assert input.shape[dim] != -1, "Can't slice on dynamic shape dimension!"
# Assign the initial start tensor
start_slice = []
# the add_slice will take care of dynamic input shape cases here
if isinstance(start, int):
start_slice = [0] * len(input.shape)
start_slice[dim] = start
else:
for i in range(len(input.shape)):
start_slice.append(start if i == dim else 0)

# Assign the initial stop tensor
stop_slice = []
if isinstance(stop, int) and dynamic_input_dim_equal:
stop_slice = input.shape
stop_slice[dim] = stop
else:
# required when stop is an ITensor and dim != the dynamic dim of the input;
# not required when stop is negative and dim != the dynamic dim of the input
for i in range(len(input.shape)):
if input.shape[i] == DYNAMIC_DIM and i != dim:
stop_slice.append(
get_shape(
ctx, target, source_ir, name + f"_shape_dim_stop_{i}", input, i
)
)
elif i == dim:
stop_slice.append(stop)
else:
stop_slice.append(input.shape[i])

start_slice = [0] * len(input.shape)
start_slice[dim] = start
stride_slice = [1] * len(input.shape)
stride_slice[dim] = step
output_shape = list(input.shape)
output_shape[dim] = math.ceil((stop - start) / step)

if input.shape[dim] != -1 and isinstance(start, int) and isinstance(stop, int):
start = get_positive_dim(start, input.shape[dim])
stop = get_positive_dim(stop, input.shape[dim])
start_slice[dim] = start
else:
# dynamic case: start or stop is an ITensor, negative, None along a dynamic dim, or sys.maxsize
if (
not (isinstance(start, int))
or not (isinstance(stop, int))
or start < 0
or stop < 0
or stop_dynamic_None
or stop == sys.maxsize
):
# special assignments for dynamic cases
if isinstance(start, int) and start < 0:
start_slice = input.shape
start_slice[dim] = -1 * start
if (isinstance(stop, int) and stop < 0) or stop_dynamic_None:
stop_slice = [0] * len(input.shape)
stop_slice[dim] = -1 * stop
if stop == sys.maxsize:
stop_slice = [0] * len(input.shape)
start_slice_tensor = cat(
ctx,
target,
source_ir,
name + "_start_slice_concat",
tuple(start_slice),
0,
cast_dtype=trt.int32,
)
stop_slice_tensor = cat(
ctx,
target,
source_ir,
name + "_stop_slice_concat",
tuple(stop_slice),
0,
cast_dtype=trt.int32,
)
stride_slice_tensor = cat(
ctx,
target,
source_ir,
name + "_stride_slice_concat",
tuple(stride_slice),
0,
cast_dtype=trt.int32,
)

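# Negative and sys.maxsize bounds were encoded above as positive offsets;
# the subtractions below convert them into absolute indices against the
# runtime shape of the input.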
if isinstance(start, int) and start < 0:
shape = get_shape_with_dynamic_shape(
ctx, target, source_ir, name, output_shape, input
)
start_slice_tensor = convert_binary_elementwise(
ctx,
target,
source_ir,
name + "_sub_start",
trt.ElementWiseOperation.SUB,
shape,
start_slice_tensor,
)
if isinstance(stop, int) and (
(stop < 0) or stop_dynamic_None or stop == sys.maxsize
):
shape = get_shape_with_dynamic_shape(
ctx, target, source_ir, name, output_shape, input
)
stop_slice_tensor = convert_binary_elementwise(
ctx,
target,
source_ir,
name + "_sub_stop",
trt.ElementWiseOperation.SUB,
shape,
stop_slice_tensor,
)

# compute the output extent ceil((stop - start) / step) as -floor((start - stop) / step)
output_shape_tensor_num = convert_binary_elementwise(
ctx,
target,
source_ir,
name + "_sub_num",
trt.ElementWiseOperation.SUB,
start_slice_tensor,
stop_slice_tensor,
)
output_shape_tensor_neg = floor_divide(
ctx,
target,
source_ir,
name + "_div",
output_shape_tensor_num,
stride_slice_tensor,
)
output_shape_tensor = convert_binary_elementwise(
ctx,
target,
source_ir,
name + "_prod",
trt.ElementWiseOperation.PROD,
output_shape_tensor_neg,
-1,
)
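# Create the slice layer with placeholder Dims; the actual start, output
# shape, and stride are wired in below as shape tensors via set_input.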
layer = ctx.net.add_slice(
input, start=trt.Dims(), shape=trt.Dims(), stride=trt.Dims()
)
layer.set_input(1, start_slice_tensor)
layer.set_input(2, output_shape_tensor)
layer.set_input(3, stride_slice_tensor)
return layer.get_output(0)

output_shape[dim] = math.ceil((stop - start) / step)
return slice(
ctx, target, source_ir, name, input, start_slice, output_shape, stride_slice
)
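
Note on the shape arithmetic above: TensorRT shape-tensor math has no ceil-division op, so the converter builds ceil((stop - start) / step) from SUB, floor_divide, and a PROD with -1. A minimal sketch of the identity in plain Python (assuming floor_divide matches Python's // semantics for negative numerators; not part of the PR):

import math

def slice_output_len(start: int, stop: int, step: int) -> int:
    # ceil((stop - start) / step) computed as -floor((start - stop) / step),
    # mirroring the SUB -> floor_divide -> PROD(-1) chain in the converter
    num = start - stop  # ElementWiseOperation.SUB
    neg = num // step   # floor_divide by the stride
    return -1 * neg     # ElementWiseOperation.PROD with -1

for start, stop, step in [(0, 7, 2), (0, 10, 2), (1, 10, 3), (0, 9, 2)]:
    assert slice_output_len(start, stop, step) == math.ceil((stop - start) / step)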
160 changes: 154 additions & 6 deletions tests/py/dynamo/conversion/test_slice_aten.py
@@ -1,7 +1,6 @@
import torch
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests

from torch_tensorrt import Input

from .harness import DispatchTestCase
@@ -53,11 +52,159 @@ def forward(self, input):
class TestSliceConverterDynamicShape(DispatchTestCase):
@parameterized.expand(
[
("slice_dim_start_stop_step", 1, 0, 7, 2),
("slice_dim_start_stop_step", 1, 0, 10, 2),
(
"slice_dynamic_dim_start_stop_step_offset",
(1, 10, 1),
(1, 10, 10),
(1, 10, 10),
1,
0,
7,
2,
),
(
"slice_dynamic_dim_start_stop_step",
(1, 10, 1),
(1, 10, 10),
(1, 10, 10),
1,
0,
10,
2,
),
(
"slice_dynamic_dim_start_stop_step_negatives",
(1, 10, 10),
(10, 10, 10),
(10, 10, 10),
-2,
-2,
-1,
1,
),
(
"slice_dim_start_stop_step_max_int",
(1, 10, 10),
(10, 10, 10),
(10, 10, 10),
2,
0,
2**63 - 1,
1,
),
(
"slice_dim_start_stop_step_past_end",
(1, 10, 10),
(10, 10, 10),
(10, 10, 10),
2,
0,
2048,
1,
),
(
"slice_dim_start_stop_step_none",
(1, 10, 10),
(10, 10, 10),
(10, 10, 10),
2,
None,
None,
1,
),
(
"slice_dynamic_dim_start_stop_step_offset_4D",
(1, 10, 1, 3),
(1, 10, 10, 3),
(1, 10, 10, 3),
1,
0,
7,
2,
),
(
"slice_dynamic_dim_start_stop_step_4D",
(1, 10, 1, 3),
(1, 10, 10, 3),
(1, 10, 10, 3),
1,
0,
10,
2,
),
(
"slice_dynamic_dim_dyn_start_dyn_stop_step",
(1, 10, 1),
(1, 10, 10),
(1, 10, 10),
2,
-2,
10,
2,
),
(
"slice_dynamic_dim_dyn_start_stop_dyn_step",
(1, 10, 1),
(1, 10, 10),
(1, 10, 10),
2,
0,
-2,
2,
),
(
"slice_dynamic_dim_dyn_start_stop_None_step",
(1, 10, 1),
(1, 10, 10),
(1, 10, 10),
2,
0,
None,
2,
),
(
"slice_dynamic_dim_dyn_start_dyn_stop_dyn_step",
(1, 10, 1),
(1, 10, 10),
(1, 10, 10),
2,
-8,
-2,
2,
),
(
"slice_dynamic_dim_dyn_start_dyn_stop_dyn_step_ceil",
(1, 10, 1),
(1, 10, 10),
(1, 10, 10),
2,
-9,
-2,
2,
),
(
"slice_dynamic_dim_dyn_start_dyn_stop_dyn_step_diff_dim",
(1, 10, 1),
(1, 10, 10),
(1, 10, 10),
0,
-8,
-2,
2,
),
(
"slice_dynamic_dim_dyn_start_dyn_stop_dyn_step_diff_dim_ceil",
(1, 10, 1),
(1, 10, 10),
(1, 10, 10),
0,
-9,
-2,
2,
),
]
)
def test_slice(self, _, dim, start, stop, step):
def test_slice(self, _, min_shape, opt_shape, max_shape, dim, start, stop, step):
class TestModule(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -68,9 +215,10 @@ def forward(self, input):

input_specs = [
Input(
shape=(1, 10, -1),
min_shape=min_shape,
opt_shape=opt_shape,
max_shape=max_shape,
dtype=torch.float32,
shape_ranges=[((1, 10, 1), (1, 10, 10), (1, 10, 10))],
),
]
self.run_test_with_dynamic_shape(
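
For context, a minimal usage sketch of the path these tests exercise: slicing a dynamic dimension of a module compiled through the dynamo frontend. The module, compile settings, and device handling below are illustrative, not taken from the PR:

import torch
import torch_tensorrt
from torch_tensorrt import Input

class SliceModule(torch.nn.Module):
    def forward(self, x):
        # strided slice along the last (dynamic) dimension -> aten.slice.Tensor
        return x[:, :, ::2]

spec = Input(
    min_shape=(1, 10, 1),
    opt_shape=(1, 10, 10),
    max_shape=(1, 10, 10),
    dtype=torch.float32,
)
trt_mod = torch_tensorrt.compile(SliceModule().eval().cuda(), ir="dynamo", inputs=[spec])
out = trt_mod(torch.randn(1, 10, 4).cuda())
print(out.shape)  # (1, 10, 2): the sliced extent tracks the runtime dimension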