@@ -1,4 +1,5 @@
 import math
+import sys
 from typing import Optional, Sequence
 
 import numpy as np
@@ -14,6 +15,11 @@
     get_trt_tensor,
 )
 from torch_tensorrt.dynamo.conversion.impl.cat import cat
+from torch_tensorrt.dynamo.conversion.impl.elementwise import floor_divide
+from torch_tensorrt.dynamo.conversion.impl.elementwise.ops import (
+    convert_binary_elementwise,
+)
+from torch_tensorrt.dynamo.conversion.impl.shape import get_shape_with_dynamic_shape
 from torch_tensorrt.dynamo.conversion.impl.shape import shape as get_shape
 from torch_tensorrt.dynamo.conversion.impl.slice.base import slice
 from torch_tensorrt.dynamo.utils import DYNAMIC_DIM
@@ -36,29 +42,176 @@ def slice_op( # TODO: This should be slice not whatever is in base
     stop: Optional[int],
     step: int,
 ) -> TRTTensor:
+    # Check whether dim is the dynamic-shape dimension of the input;
+    # this is required when stop is an ITensor
+    dynamic_input_dim_equal = False
+    for i in range(len(input.shape)):
+        if input.shape[i] == DYNAMIC_DIM and i == dim:
+            dynamic_input_dim_equal = True
+
     # Special case for start being None
     if start is None:
         start = 0
 
     # Special case for stop being None
+    stop_dynamic_None = False
+    if stop is None:
+        stop_dynamic_None = True if input.shape[dim] == -1 else False
     if stop is None:
-        stop = input.shape[dim]
+        stop = 0 if input.shape[dim] == -1 else input.shape[dim]
 
     dim = get_positive_dim(dim, len(input.shape))
-    start = get_positive_dim(start, input.shape[dim])
-    stop = get_positive_dim(stop, input.shape[dim])
 
-    if has_dynamic_shape(input.shape):
-        # Check whether slice target dim is dynamic shape dim
-        assert input.shape[dim] != -1, "Can't slice on dynamic shape dimension!"
+    # Assign the initial start tensor
+    start_slice = []
+    # the add_slice layer below will take care of dynamic input shape cases here
+    if isinstance(start, int):
+        start_slice = [0] * len(input.shape)
+        start_slice[dim] = start
+    else:
+        # start is an ITensor: it belongs at dim, with 0 everywhere else
+        for i in range(len(input.shape)):
+            start_slice.append(start if i == dim else 0)
+
+    # Assign the initial stop tensor
+    stop_slice = []
+    if isinstance(stop, int) and dynamic_input_dim_equal:
+        stop_slice = input.shape
+        stop_slice[dim] = stop
+    else:
+        # required when stop is an ITensor and dim != the dynamic dim of input;
+        # not required when stop is negative and dim != the dynamic dim of input
+        for i in range(len(input.shape)):
+            if input.shape[i] == DYNAMIC_DIM and i != dim:
+                stop_slice.append(
+                    get_shape(
+                        ctx, target, source_ir, name + f"_shape_dim_stop_{i}", input, i
+                    )
+                )
+            elif i == dim:
+                stop_slice.append(stop)
+            else:
+                stop_slice.append(input.shape[i])
 
-    start_slice = [0] * len(input.shape)
-    start_slice[dim] = start
     stride_slice = [1] * len(input.shape)
     stride_slice[dim] = step
     output_shape = list(input.shape)
-    output_shape[dim] = math.ceil((stop - start) / step)
 
+    if input.shape[dim] != -1 and isinstance(start, int) and isinstance(stop, int):
+        start = get_positive_dim(start, input.shape[dim])
+        stop = get_positive_dim(stop, input.shape[dim])
+        start_slice[dim] = start
+    else:
+        # start or stop is negative or an ITensor, or stop was None along a
+        # dynamic dim, so the slice bounds must be built as shape tensors
+        if (
+            not (isinstance(start, int))
+            or not (isinstance(stop, int))
+            or start < 0
+            or stop < 0
+            or stop_dynamic_None
+            or stop == sys.maxsize
+        ):
+            # special assignments for dynamic cases
+            if isinstance(start, int) and start < 0:
+                start_slice = input.shape
+                start_slice[dim] = -1 * start
+            if (isinstance(stop, int) and stop < 0) or stop_dynamic_None:
+                stop_slice = [0] * len(input.shape)
+                stop_slice[dim] = -1 * stop
+            if stop == sys.maxsize:
+                stop_slice = [0] * len(input.shape)
+            start_slice_tensor = cat(
+                ctx,
+                target,
+                source_ir,
+                name + "_start_slice_concat",
+                tuple(start_slice),
+                0,
+                cast_dtype=trt.int32,
+            )
+            stop_slice_tensor = cat(
+                ctx,
+                target,
+                source_ir,
+                name + "_stop_slice_concat",
+                tuple(stop_slice),
+                0,
+                cast_dtype=trt.int32,
+            )
+            stride_slice_tensor = cat(
+                ctx,
+                target,
+                source_ir,
+                name + "_stride_slice_concat",
+                tuple(stride_slice),
+                0,
+                cast_dtype=trt.int32,
+            )
+
+            # rewrite a negative bound as (dim size - |bound|) against the shape tensor
+            if isinstance(start, int) and start < 0:
+                shape = get_shape_with_dynamic_shape(
+                    ctx, target, source_ir, name, output_shape, input
+                )
+                start_slice_tensor = convert_binary_elementwise(
+                    ctx,
+                    target,
+                    source_ir,
+                    name + "_sub_start",
+                    trt.ElementWiseOperation.SUB,
+                    shape,
+                    start_slice_tensor,
+                )
+            if isinstance(stop, int) and (
+                (stop < 0) or stop_dynamic_None or stop == sys.maxsize
+            ):
+                shape = get_shape_with_dynamic_shape(
+                    ctx, target, source_ir, name, output_shape, input
+                )
+                stop_slice_tensor = convert_binary_elementwise(
+                    ctx,
+                    target,
+                    source_ir,
+                    name + "_sub_stop",
+                    trt.ElementWiseOperation.SUB,
+                    shape,
+                    stop_slice_tensor,
+                )
+
+            # compute ceil((stop - start) / step) as -((start - stop) // step)
+            output_shape_tensor_num = convert_binary_elementwise(
+                ctx,
+                target,
+                source_ir,
+                name + "_sub_num",
+                trt.ElementWiseOperation.SUB,
+                start_slice_tensor,
+                stop_slice_tensor,
+            )
+            output_shape_tensor_neg = floor_divide(
+                ctx,
+                target,
+                source_ir,
+                name + "_div",
+                output_shape_tensor_num,
+                stride_slice_tensor,
+            )
+            output_shape_tensor = convert_binary_elementwise(
+                ctx,
+                target,
+                source_ir,
+                name + "_prod",
+                trt.ElementWiseOperation.PROD,
+                output_shape_tensor_neg,
+                -1,
+            )
+            # feed start/shape/stride to the slice layer as dynamic tensor inputs
+            layer = ctx.net.add_slice(
+                input, start=trt.Dims(), shape=trt.Dims(), stride=trt.Dims()
+            )
+            layer.set_input(1, start_slice_tensor)
+            layer.set_input(2, output_shape_tensor)
+            layer.set_input(3, stride_slice_tensor)
+            return layer.get_output(0)
+
+    output_shape[dim] = math.ceil((stop - start) / step)
     return slice(
         ctx, target, source_ir, name, input, start_slice, output_shape, stride_slice
     )
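The dynamic branch above builds the output length with integer tensor ops rather than math.ceil: it subtracts (start - stop), floor-divides by the stride, and multiplies by -1, which equals ceil((stop - start) / step) for a positive step; negative bounds are first rewritten as size - |bound| via the elementwise SUB against the shape tensor. A minimal plain-Python sketch of that arithmetic (not part of this PR; expected_slice_len is a hypothetical helper, and a positive step is assumed):

import math

def expected_slice_len(size: int, start: int, stop: int, step: int) -> int:
    # Negative bounds become (size - |bound|), mirroring the "_sub_start"
    # and "_sub_stop" elementwise SUB layers in the diff.
    if start < 0:
        start = size - (-1 * start)
    if stop < 0:
        stop = size - (-1 * stop)
    # ceil((stop - start) / step) without floats, mirroring the
    # "_sub_num" -> floor_divide -> "_prod" (multiply by -1) layer chain.
    return -((start - stop) // step)

assert expected_slice_len(10, 0, -1, 1) == math.ceil((9 - 0) / 1) == 9
assert expected_slice_len(10, -6, 10, 2) == math.ceil((10 - 4) / 2) == 3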