dynamic shape for slice converter #2901
Changes from 1 commit
@@ -1,4 +1,5 @@
 import math
+import sys
 from typing import Optional, Sequence
 
 import numpy as np
@@ -14,6 +15,11 @@
     get_trt_tensor,
 )
 from torch_tensorrt.dynamo.conversion.impl.cat import cat
+from torch_tensorrt.dynamo.conversion.impl.elementwise import floor_divide
+from torch_tensorrt.dynamo.conversion.impl.elementwise.ops import (
+    convert_binary_elementwise,
+)
 from torch_tensorrt.dynamo.conversion.impl.shape import get_shape_with_dynamic_shape
+from torch_tensorrt.dynamo.conversion.impl.shape import shape as get_shape
 from torch_tensorrt.dynamo.conversion.impl.slice.base import slice
 from torch_tensorrt.dynamo.utils import DYNAMIC_DIM
@@ -36,29 +42,176 @@ def slice_op(  # TODO: This should be slice not whatever is in base
     stop: Optional[int],
     step: int,
 ) -> TRTTensor:
+    # check if dim is same as dynamic shape dimension
+    # this is required when stop is an ITensor
+    dynamic_input_dim_equal = False
+    for i in range(len(input.shape)):
+        if input.shape[i] == DYNAMIC_DIM and i == dim:
+            dynamic_input_dim_equal = True
+
     # Special case for start being None
     if start is None:
         start = 0
 
     # Special case for stop being None
+    stop_dynamic_None = False
+    if stop is None:
+        stop_dynamic_None = True if input.shape[dim] == -1 else False
     if stop is None:
-        stop = input.shape[dim]
+        stop = 0 if input.shape[dim] == -1 else input.shape[dim]
 
     dim = get_positive_dim(dim, len(input.shape))
-    start = get_positive_dim(start, input.shape[dim])
-    stop = get_positive_dim(stop, input.shape[dim])
 
-    if has_dynamic_shape(input.shape):
-        # Check whether slice target dim is dynamic shape dim
-        assert input.shape[dim] != -1, "Can't slice on dynamic shape dimension!"
+    # Assign the initial start tensor
+    start_slice = []
+    # the add_slice will take care of dynamic input shape cases here
+    if isinstance(start, int):
+        start_slice = [0] * len(input.shape)
+        start_slice[dim] = start
+    else:
+        for i in range(len(input.shape)):
+            start_slice.append(0) if i == dim else start_slice.append(start)

Review comment: Can you explain what is happening here? Why are we appending start (an ITensor) to all the other dimensions except dim?

Reply: Thanks for pointing this out. start_slice should have start appended for i == dim.
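A minimal stand-alone sketch of the fix the reply describes (hypothetical values; in the converter, `start` may be an int or an ITensor):

```python
# Hypothetical example: rank-3 input, slicing dim 1 with offset `start`.
input_rank = 3
dim, start = 1, 5

# Corrected loop: append `start` only at i == dim, 0 for every other axis,
# so only the sliced dimension gets an offset.
start_slice = [start if i == dim else 0 for i in range(input_rank)]
assert start_slice == [0, 5, 0]
```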
+
+    # Assign the initial stop tensor
+    stop_slice = []
+    if isinstance(stop, int) and dynamic_input_dim_equal:
+        stop_slice = input.shape
+        stop_slice[dim] = stop
+    else:
+        # required for cases where stop is an ITensor and dim != dynamic dim of input
+        # not required for cases where stop is negative and dim != dynamic dim of input
+        for i in range(len(input.shape)):
+            if input.shape[i] == DYNAMIC_DIM and i != dim:
+                stop_slice.append(
+                    get_shape(
+                        ctx, target, source_ir, name + f"_shape_dim_stop_{i}", input, i
+                    )
+                )
+            elif i == dim:
+                stop_slice.append(stop)
+            else:
+                stop_slice.append(input.shape[i])
+
-    start_slice = [0] * len(input.shape)
-    start_slice[dim] = start
     stride_slice = [1] * len(input.shape)
     stride_slice[dim] = step
     output_shape = list(input.shape)
-    output_shape[dim] = math.ceil((stop - start) / step)
 
+    if input.shape[dim] != -1 and isinstance(start, int) and isinstance(stop, int):
+        start = get_positive_dim(start, input.shape[dim])
+        stop = get_positive_dim(stop, input.shape[dim])
+        start_slice[dim] = start
+    else:
+        # start or stop is negative, an ITensor, sys.maxsize, or stop was None
+        # on a dim that is dynamic
+        if (
+            not (isinstance(start, int))
+            or not (isinstance(stop, int))
+            or start < 0
+            or stop < 0
+            or stop_dynamic_None
+            or stop == sys.maxsize
+        ):
+            # special assignments for dynamic cases
+            if isinstance(start, int) and start < 0:
+                start_slice = input.shape
+                start_slice[dim] = -1 * start
+            if (isinstance(stop, int) and stop < 0) or stop_dynamic_None:
+                stop_slice = [0] * len(input.shape)
+                stop_slice[dim] = -1 * stop
+            if stop == sys.maxsize:
+                stop_slice = [0] * len(input.shape)
+            start_slice_tensor = cat(
+                ctx,
+                target,
+                source_ir,
+                name + "_start_slice_concat",
+                tuple(start_slice),
+                0,
+                cast_dtype=trt.int32,
+            )
+            stop_slice_tensor = cat(
+                ctx,
+                target,
+                source_ir,
+                name + "_stop_slice_concat",
+                tuple(stop_slice),
+                0,
+                cast_dtype=trt.int32,
+            )
+            stride_slice_tensor = cat(
+                ctx,
+                target,
+                source_ir,
+                name + "_stride_slice_concat",
+                tuple(stride_slice),
+                0,
+                cast_dtype=trt.int32,
+            )
+
+            if isinstance(start, int) and start < 0:
+                shape = get_shape_with_dynamic_shape(
+                    ctx, target, source_ir, name, output_shape, input
+                )
+                start_slice_tensor = convert_binary_elementwise(
+                    ctx,
+                    target,
+                    source_ir,
+                    name + "_sub_start",
+                    trt.ElementWiseOperation.SUB,
+                    shape,
+                    start_slice_tensor,
+                )
+            if isinstance(stop, int) and (
+                (stop < 0) or stop_dynamic_None or stop == sys.maxsize
+            ):
+                shape = get_shape_with_dynamic_shape(
+                    ctx, target, source_ir, name, output_shape, input
+                )
+                stop_slice_tensor = convert_binary_elementwise(
+                    ctx,
+                    target,
+                    source_ir,
+                    name + "_sub_stop",
+                    trt.ElementWiseOperation.SUB,
+                    shape,
+                    stop_slice_tensor,
+                )
+
+            # computes output length = ceil((stop - start) / step) in-graph:
+            # SUB gives (start - stop), floor-divide by step, then PROD by -1
+            output_shape_tensor_num = convert_binary_elementwise(
+                ctx,
+                target,
+                source_ir,
+                name + "_sub_num",
+                trt.ElementWiseOperation.SUB,
+                start_slice_tensor,
+                stop_slice_tensor,
+            )
+            output_shape_tensor_neg = floor_divide(
+                ctx,
+                target,
+                source_ir,
+                name + "_div",
+                output_shape_tensor_num,
+                stride_slice_tensor,
+            )
+            output_shape_tensor = convert_binary_elementwise(
+                ctx,
+                target,
+                source_ir,
+                name + "_prod",
+                trt.ElementWiseOperation.PROD,
+                output_shape_tensor_neg,
+                -1,
+            )
+            layer = ctx.net.add_slice(
+                input, start=trt.Dims(), shape=trt.Dims(), stride=trt.Dims()
+            )
+            layer.set_input(1, start_slice_tensor)
+            layer.set_input(2, output_shape_tensor)
+            layer.set_input(3, stride_slice_tensor)
+            return layer.get_output(0)
 
+    output_shape[dim] = math.ceil((stop - start) / step)
     return slice(
         ctx, target, source_ir, name, input, start_slice, output_shape, stride_slice
     )
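The SUB, floor_divide, PROD(-1) chain in the dynamic branch is the standard trick for obtaining a ceiling division when only floor division is available in-graph. A quick stand-alone check of the identity it relies on (plain Python, not the converter API):

```python
import math

# ceil(n / d) == -((-n) // d) for d > 0: the graph computes (start - stop),
# floor-divides by step, then multiplies by -1 to get ceil((stop - start) / step).
for n in range(-10, 11):
    for d in (1, 2, 3, 5):
        assert math.ceil(n / d) == -((-n) // d)
```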
Review comment: Can you merge the two `if stop is None` statements?
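A minimal sketch of the merge being requested, keeping the semantics of this commit (the helper `resolve_stop` is hypothetical, introduced only to make the fragment runnable):

```python
def resolve_stop(stop, dim_size):
    """Merged version: one `if stop is None` block (same semantics as the diff)."""
    stop_dynamic_None = False
    if stop is None:
        stop_dynamic_None = dim_size == -1  # -1 marks a dynamic dimension
        stop = 0 if dim_size == -1 else dim_size
    return stop, stop_dynamic_None

assert resolve_stop(None, -1) == (0, True)  # dynamic dim: stop resolved in-graph later
assert resolve_stop(None, 8) == (8, False)  # static dim: stop = dim size
assert resolve_stop(4, -1) == (4, False)    # explicit stop is left unchanged
```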