1
1
import math
2
+ import sys
2
3
from typing import Optional , Sequence
3
4
4
5
import numpy as np
14
15
get_trt_tensor ,
15
16
)
16
17
from torch_tensorrt .dynamo .conversion .impl .cat import cat
18
+ from torch_tensorrt .dynamo .conversion .impl .elementwise import floor_divide
19
+ from torch_tensorrt .dynamo .conversion .impl .elementwise .ops import (
20
+ convert_binary_elementwise ,
21
+ )
22
+ from torch_tensorrt .dynamo .conversion .impl .shape import get_shape_with_dynamic_shape
17
23
from torch_tensorrt .dynamo .conversion .impl .shape import shape as get_shape
18
24
from torch_tensorrt .dynamo .conversion .impl .slice .base import slice
19
25
from torch_tensorrt .dynamo .utils import DYNAMIC_DIM
@@ -36,29 +42,175 @@ def slice_op( # TODO: This should be slice not whatever is in base
36
42
stop : Optional [int ],
37
43
step : int ,
38
44
) -> TRTTensor :
45
+ # check if dim is same as dynamic shape dimension
46
+ # this is required when stop is ITensor
47
+ dynamic_input_dim_equal = False
48
+ for i in range (len (input .shape )):
49
+ if input .shape [i ] == DYNAMIC_DIM and i == dim :
50
+ dynamic_input_dim_equal = True
51
+
39
52
# Special case for start being None
40
53
if start is None :
41
54
start = 0
42
55
43
56
# Special case for stop being None
57
+ stop_dynamic_None = False
44
58
if stop is None :
45
- stop = input .shape [dim ]
59
+ stop_dynamic_None = True if input .shape [dim ] == - 1 else False
60
+ stop = 0 if input .shape [dim ] == - 1 else input .shape [dim ]
46
61
47
62
dim = get_positive_dim (dim , len (input .shape ))
48
- start = get_positive_dim (start , input .shape [dim ])
49
- stop = get_positive_dim (stop , input .shape [dim ])
50
63
51
- if has_dynamic_shape (input .shape ):
52
- # Check whether slice target dim is dynamic shape dim
53
- assert input .shape [dim ] != - 1 , "Can't slice on dynamic shape dimension!"
64
+ # Assign the initial start tensor
65
+ start_slice = []
66
+ # the add_slice will take care of dynamic input shape cases here
67
+ if isinstance (start , int ):
68
+ start_slice = [0 ] * len (input .shape )
69
+ start_slice [dim ] = start
70
+ else :
71
+ for i in range (len (input .shape )):
72
+ start_slice .append (0 ) if i != dim else start_slice .append (start )
73
+
74
+ # Assign the initial stop tensor
75
+ stop_slice = []
76
+ if isinstance (stop , int ) and dynamic_input_dim_equal :
77
+ stop_slice = input .shape
78
+ stop_slice [dim ] = stop
79
+ else :
80
+ # required for cases where stop is ITensor and dim != dynamic dim of input
81
+ # not required for cases where stop is negative and dim != dynamic dim of inpu
82
+ for i in range (len (input .shape )):
83
+ if input .shape [i ] == DYNAMIC_DIM and i != dim :
84
+ stop_slice .append (
85
+ get_shape (
86
+ ctx , target , source_ir , name + f"_shape_dim_stop_{ i } " , input , i
87
+ )
88
+ )
89
+ elif i == dim :
90
+ stop_slice .append (stop )
91
+ else :
92
+ stop_slice .append (input .shape [i ])
54
93
55
- start_slice = [0 ] * len (input .shape )
56
- start_slice [dim ] = start
57
94
stride_slice = [1 ] * len (input .shape )
58
95
stride_slice [dim ] = step
59
96
output_shape = list (input .shape )
60
- output_shape [dim ] = math .ceil ((stop - start ) / step )
61
97
98
+ if input .shape [dim ] != - 1 and isinstance (start , int ) and isinstance (stop , int ):
99
+ start = get_positive_dim (start , input .shape [dim ])
100
+ stop = get_positive_dim (stop , input .shape [dim ])
101
+ start_slice [dim ] = start
102
+ else :
103
+ # the start and stop or None is dynamic along dim or or start or stop is an ITensor
104
+ if (
105
+ not (isinstance (start , int ))
106
+ or not (isinstance (stop , int ))
107
+ or start < 0
108
+ or stop < 0
109
+ or stop_dynamic_None
110
+ or stop == sys .maxsize
111
+ ):
112
+ # special assignments for dynamic cases
113
+ if isinstance (start , int ) and start < 0 :
114
+ start_slice = input .shape
115
+ start_slice [dim ] = - 1 * start
116
+ if (isinstance (stop , int ) and stop < 0 ) or stop_dynamic_None :
117
+ stop_slice = [0 ] * len (input .shape )
118
+ stop_slice [dim ] = - 1 * stop
119
+ if stop == sys .maxsize :
120
+ stop_slice = [0 ] * len (input .shape )
121
+ start_slice_tensor = cat (
122
+ ctx ,
123
+ target ,
124
+ source_ir ,
125
+ name + "_start_slice_concat" ,
126
+ tuple (start_slice ),
127
+ 0 ,
128
+ cast_dtype = trt .int32 ,
129
+ )
130
+ stop_slice_tensor = cat (
131
+ ctx ,
132
+ target ,
133
+ source_ir ,
134
+ name + "_stop_slice_concat" ,
135
+ tuple (stop_slice ),
136
+ 0 ,
137
+ cast_dtype = trt .int32 ,
138
+ )
139
+ stride_slice_tensor = cat (
140
+ ctx ,
141
+ target ,
142
+ source_ir ,
143
+ name + "_stride_slice_concat" ,
144
+ tuple (stride_slice ),
145
+ 0 ,
146
+ cast_dtype = trt .int32 ,
147
+ )
148
+
149
+ if isinstance (start , int ) and start < 0 :
150
+ shape = get_shape_with_dynamic_shape (
151
+ ctx , target , source_ir , name , output_shape , input
152
+ )
153
+ start_slice_tensor = convert_binary_elementwise (
154
+ ctx ,
155
+ target ,
156
+ source_ir ,
157
+ name + "_sub_start" ,
158
+ trt .ElementWiseOperation .SUB ,
159
+ shape ,
160
+ start_slice_tensor ,
161
+ )
162
+ if isinstance (stop , int ) and (
163
+ (stop < 0 ) or stop_dynamic_None or stop == sys .maxsize
164
+ ):
165
+ shape = get_shape_with_dynamic_shape (
166
+ ctx , target , source_ir , name , output_shape , input
167
+ )
168
+ stop_slice_tensor = convert_binary_elementwise (
169
+ ctx ,
170
+ target ,
171
+ source_ir ,
172
+ name + "_sub_stop" ,
173
+ trt .ElementWiseOperation .SUB ,
174
+ shape ,
175
+ stop_slice_tensor ,
176
+ )
177
+
178
+ # this is required for the ceil operation
179
+ output_shape_tensor_num = convert_binary_elementwise (
180
+ ctx ,
181
+ target ,
182
+ source_ir ,
183
+ name + "_sub_num" ,
184
+ trt .ElementWiseOperation .SUB ,
185
+ start_slice_tensor ,
186
+ stop_slice_tensor ,
187
+ )
188
+ output_shape_tensor_neg = floor_divide (
189
+ ctx ,
190
+ target ,
191
+ source_ir ,
192
+ name + "_div" ,
193
+ output_shape_tensor_num ,
194
+ stride_slice_tensor ,
195
+ )
196
+ output_shape_tensor = convert_binary_elementwise (
197
+ ctx ,
198
+ target ,
199
+ source_ir ,
200
+ name + "_prod" ,
201
+ trt .ElementWiseOperation .PROD ,
202
+ output_shape_tensor_neg ,
203
+ - 1 ,
204
+ )
205
+ layer = ctx .net .add_slice (
206
+ input , start = trt .Dims (), shape = trt .Dims (), stride = trt .Dims ()
207
+ )
208
+ layer .set_input (1 , start_slice_tensor )
209
+ layer .set_input (2 , output_shape_tensor )
210
+ layer .set_input (3 , stride_slice_tensor )
211
+ return layer .get_output (0 )
212
+
213
+ output_shape [dim ] = math .ceil ((stop - start ) / step )
62
214
return slice (
63
215
ctx , target , source_ir , name , input , start_slice , output_shape , stride_slice
64
216
)
0 commit comments