@@ -1,4 +1,5 @@
 import math
+import sys
 from typing import Optional, Sequence
 
 import numpy as np
@@ -14,6 +15,11 @@
     get_trt_tensor,
 )
 from torch_tensorrt.dynamo.conversion.impl.cat import cat
+from torch_tensorrt.dynamo.conversion.impl.elementwise import floor_divide
+from torch_tensorrt.dynamo.conversion.impl.elementwise.ops import (
+    convert_binary_elementwise,
+)
+from torch_tensorrt.dynamo.conversion.impl.shape import get_shape_with_dynamic_shape
 from torch_tensorrt.dynamo.conversion.impl.shape import shape as get_shape
 from torch_tensorrt.dynamo.conversion.impl.slice.base import slice
 from torch_tensorrt.dynamo.utils import DYNAMIC_DIM
@@ -36,29 +42,175 @@ def slice_op( # TODO: This should be slice not whatever is in base
     stop: Optional[int],
     step: int,
 ) -> TRTTensor:
+    # check whether the slice dim coincides with a dynamic-shape dimension;
+    # this is required when stop is an ITensor
+    dynamic_input_dim_equal = False
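+    # (DYNAMIC_DIM is -1: the size of such a dimension is only known at runtime)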
+    for i in range(len(input.shape)):
+        if input.shape[i] == DYNAMIC_DIM and i == dim:
+            dynamic_input_dim_equal = True
+
     # Special case for start being None
     if start is None:
         start = 0
 
     # Special case for stop being None
+    stop_dynamic_None = False
     if stop is None:
-        stop = input.shape[dim]
+        stop_dynamic_None = True if input.shape[dim] == -1 else False
+        stop = 0 if input.shape[dim] == -1 else input.shape[dim]
 
     dim = get_positive_dim(dim, len(input.shape))
-    start = get_positive_dim(start, input.shape[dim])
-    stop = get_positive_dim(stop, input.shape[dim])
 
-    if has_dynamic_shape(input.shape):
-        # Check whether slice target dim is dynamic shape dim
-        assert input.shape[dim] != -1, "Can't slice on dynamic shape dimension!"
+    # Assign the initial start tensor
+    start_slice = []
+    # the add_slice will take care of dynamic input shape cases here
+    if isinstance(start, int):
+        start_slice = [0] * len(input.shape)
+        start_slice[dim] = start
+    else:
+        for i in range(len(input.shape)):
+            start_slice.append(start if i == dim else 0)
+
+    # Assign the initial stop tensor
+    stop_slice = []
+    if isinstance(stop, int) and dynamic_input_dim_equal:
+        stop_slice = input.shape
+        stop_slice[dim] = stop
+    else:
+        # required for cases where stop is an ITensor and dim != dynamic dim of input
+        # not required for cases where stop is negative and dim != dynamic dim of input
+        for i in range(len(input.shape)):
+            if input.shape[i] == DYNAMIC_DIM and i != dim:
+                stop_slice.append(
+                    get_shape(
+                        ctx, target, source_ir, name + f"_shape_dim_stop_{i}", input, i
+                    )
+                )
+            elif i == dim:
+                stop_slice.append(stop)
+            else:
+                stop_slice.append(input.shape[i])
 
-    start_slice = [0] * len(input.shape)
-    start_slice[dim] = start
     stride_slice = [1] * len(input.shape)
     stride_slice[dim] = step
     output_shape = list(input.shape)
-    output_shape[dim] = math.ceil((stop - start) / step)
 
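+    # static path: the dim size and both bounds are known at build time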
+    if input.shape[dim] != -1 and isinstance(start, int) and isinstance(stop, int):
+        start = get_positive_dim(start, input.shape[dim])
+        stop = get_positive_dim(stop, input.shape[dim])
+        start_slice[dim] = start
+    else:
+        # the input is dynamic along dim, or start/stop is None, negative, or an ITensor
+        if (
+            not (isinstance(start, int))
+            or not (isinstance(stop, int))
+            or start < 0
+            or stop < 0
+            or stop_dynamic_None
+            or stop == sys.maxsize
+        ):
+            # special assignments for dynamic cases
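+            # negative start/stop are stored as |value| here and folded into
+            # (runtime shape - |value|) by the elementwise SUB further below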
+            if isinstance(start, int) and start < 0:
+                start_slice = input.shape
+                start_slice[dim] = -1 * start
+            if (isinstance(stop, int) and stop < 0) or stop_dynamic_None:
+                stop_slice = [0] * len(input.shape)
+                stop_slice[dim] = -1 * stop
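+            # aten.slice commonly encodes "slice to the end" as stop == sys.maxsize;
+            # the all-zero stop_slice is turned into the full runtime shape below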
+            if stop == sys.maxsize:
+                stop_slice = [0] * len(input.shape)
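+            # assemble the per-dimension start/stop/stride lists into 1-D int32 tensors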
+            start_slice_tensor = cat(
+                ctx,
+                target,
+                source_ir,
+                name + "_start_slice_concat",
+                tuple(start_slice),
+                0,
+                cast_dtype=trt.int32,
+            )
+            stop_slice_tensor = cat(
+                ctx,
+                target,
+                source_ir,
+                name + "_stop_slice_concat",
+                tuple(stop_slice),
+                0,
+                cast_dtype=trt.int32,
+            )
+            stride_slice_tensor = cat(
+                ctx,
+                target,
+                source_ir,
+                name + "_stride_slice_concat",
+                tuple(stride_slice),
+                0,
+                cast_dtype=trt.int32,
+            )
+
+            if isinstance(start, int) and start < 0:
+                shape = get_shape_with_dynamic_shape(
+                    ctx, target, source_ir, name, output_shape, input
+                )
+                start_slice_tensor = convert_binary_elementwise(
+                    ctx,
+                    target,
+                    source_ir,
+                    name + "_sub_start",
+                    trt.ElementWiseOperation.SUB,
+                    shape,
+                    start_slice_tensor,
+                )
+            if isinstance(stop, int) and (
+                (stop < 0) or stop_dynamic_None or stop == sys.maxsize
+            ):
+                shape = get_shape_with_dynamic_shape(
+                    ctx, target, source_ir, name, output_shape, input
+                )
+                stop_slice_tensor = convert_binary_elementwise(
+                    ctx,
+                    target,
+                    source_ir,
+                    name + "_sub_stop",
+                    trt.ElementWiseOperation.SUB,
+                    shape,
+                    stop_slice_tensor,
+                )
+
+            # compute ceil((stop - start) / step) in-graph as
+            # -floor((start - stop) / step), since only floor division is available
+            output_shape_tensor_num = convert_binary_elementwise(
+                ctx,
+                target,
+                source_ir,
+                name + "_sub_num",
+                trt.ElementWiseOperation.SUB,
+                start_slice_tensor,
+                stop_slice_tensor,
+            )
+            output_shape_tensor_neg = floor_divide(
+                ctx,
+                target,
+                source_ir,
+                name + "_div",
+                output_shape_tensor_num,
+                stride_slice_tensor,
+            )
+            output_shape_tensor = convert_binary_elementwise(
+                ctx,
+                target,
+                source_ir,
+                name + "_prod",
+                trt.ElementWiseOperation.PROD,
+                output_shape_tensor_neg,
+                -1,
+            )
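+            # build the slice layer with placeholder Dims; the actual start, shape
+            # and stride are supplied as runtime tensors via set_input below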
+            layer = ctx.net.add_slice(
+                input, start=trt.Dims(), shape=trt.Dims(), stride=trt.Dims()
+            )
+            layer.set_input(1, start_slice_tensor)
+            layer.set_input(2, output_shape_tensor)
+            layer.set_input(3, stride_slice_tensor)
+            return layer.get_output(0)
+
+    output_shape[dim] = math.ceil((stop - start) / step)
     return slice(
         ctx, target, source_ir, name, input, start_slice, output_shape, stride_slice
     )
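
A quick sanity check of the arithmetic identity the new in-graph ops rely on (plain Python, not part of the diff): for a positive step, ceil((stop - start) / step) equals -((start - stop) // step), which is why the SUB / floor_divide / PROD(-1) chain above reproduces the static math.ceil computation.

import math

# verify ceil((stop - start) / step) == -floor((start - stop) / step)
for start in range(-8, 9):
    for stop in range(-8, 9):
        for step in (1, 2, 3, 5):
            assert math.ceil((stop - start) / step) == -((start - stop) // step)
print("identity holds")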