     get_trt_tensor,
 )
 from torch_tensorrt.dynamo.conversion.impl.cat import cat
+from torch_tensorrt.dynamo.conversion.impl.elementwise import floor_divide
+from torch_tensorrt.dynamo.conversion.impl.elementwise.ops import (
+    convert_binary_elementwise,
+)
+from torch_tensorrt.dynamo.conversion.impl.shape import get_shape_with_dynamic_shape
 from torch_tensorrt.dynamo.conversion.impl.slice.base import slice
 from torch_tensorrt.fx.converters.converter_utils import (
     has_dynamic_shape,
@@ -39,24 +44,127 @@ def slice_op(  # TODO: This should be slice not whatever is in base
     start = 0
 
     # Special case for stop being None
+    stop_dynamic_None = False
+    if stop is None:
+        stop_dynamic_None = True if input.shape[dim] == -1 else False
     if stop is None:
-        stop = input.shape[dim]
+        stop = 0 if input.shape[dim] == -1 else input.shape[dim]
 
     dim = get_positive_dim(dim, len(input.shape))
-    start = get_positive_dim(start, input.shape[dim])
-    stop = get_positive_dim(stop, input.shape[dim])
-
-    if has_dynamic_shape(input.shape):
-        # Check whether slice target dim is dynamic shape dim
-        assert input.shape[dim] != -1, "Can't slice on dynamic shape dimension!"
-
     start_slice = [0] * len(input.shape)
     start_slice[dim] = start
+    stop_slice = input.shape
+    stop_slice[dim] = stop
     stride_slice = [1] * len(input.shape)
     stride_slice[dim] = step
     output_shape = list(input.shape)
-    output_shape[dim] = math.ceil((stop - start) / step)
 
+    if input.shape[dim] != -1:
+        start = get_positive_dim(start, input.shape[dim])
+        stop = get_positive_dim(stop, input.shape[dim])
+        start_slice[dim] = start
+    else:
+        # dim is dynamic: start/stop may be negative or stop may be None
+        if start < 0 or stop < 0 or stop_dynamic_None:
+            # record the magnitudes; they are subtracted from the runtime shape below
+            if start < 0:
+                start_slice = input.shape
+                start_slice[dim] = -1 * start
+            if stop < 0 or stop_dynamic_None:
+                stop_slice = [0] * len(input.shape)
+                stop_slice[dim] = -1 * stop
+
+        start_slice_tensor = cat(
+            ctx,
+            target,
+            source_ir,
+            name + "_start_slice_concat",
+            tuple(start_slice),
+            0,
+            cast_dtype=trt.int32,
+        )
+        stop_slice_tensor = cat(
+            ctx,
+            target,
+            source_ir,
+            name + "_stop_slice_concat",
+            tuple(stop_slice),
+            0,
+            cast_dtype=trt.int32,
+        )
+        stride_slice_tensor = cat(
+            ctx,
+            target,
+            source_ir,
+            name + "_stride_slice_concat",
+            tuple(stride_slice),
+            0,
+            cast_dtype=trt.int32,
+        )
+
+        if start < 0:
+            shape = get_shape_with_dynamic_shape(
+                ctx, target, source_ir, name, output_shape, input
+            )
+            start_slice_tensor = convert_binary_elementwise(
+                ctx,
+                target,
+                source_ir,
+                name + "_sub_start",
+                trt.ElementWiseOperation.SUB,
+                shape,
+                start_slice_tensor,
+            )
+        if (stop < 0) or stop_dynamic_None:
+            shape = get_shape_with_dynamic_shape(
+                ctx, target, source_ir, name, output_shape, input
+            )
+            stop_slice_tensor = convert_binary_elementwise(
+                ctx,
+                target,
+                source_ir,
+                name + "_sub_stop",
+                trt.ElementWiseOperation.SUB,
+                shape,
+                stop_slice_tensor,
+            )
+
+        # compute ceil((stop - start) / step) as -((start - stop) // step) on shape tensors
+        output_shape_tensor_num = convert_binary_elementwise(
+            ctx,
+            target,
+            source_ir,
+            name + "_sub_num",
+            trt.ElementWiseOperation.SUB,
+            start_slice_tensor,
+            stop_slice_tensor,
+        )
+        output_shape_tensor_neg = floor_divide(
+            ctx,
+            target,
+            source_ir,
+            name + "_div",
+            output_shape_tensor_num,
+            stride_slice_tensor,
+        )
+        output_shape_tensor = convert_binary_elementwise(
+            ctx,
+            target,
+            source_ir,
+            name + "_prod",
+            trt.ElementWiseOperation.PROD,
+            output_shape_tensor_neg,
+            -1,
+        )
+        layer = ctx.net.add_slice(
+            input, start=trt.Dims(), shape=trt.Dims(), stride=trt.Dims()
+        )
+        layer.set_input(1, start_slice_tensor)
+        layer.set_input(2, output_shape_tensor)
+        layer.set_input(3, stride_slice_tensor)
+        return layer.get_output(0)
+
+    output_shape[dim] = math.ceil((stop - start) / step)
     return slice(
         ctx, target, source_ir, name, input, start_slice, output_shape, stride_slice
     )
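
Note on the output-shape arithmetic in the dynamic branch: the "_sub_num", "_div", and "_prod" layers compute the slice length without a dedicated ceil op, relying on the identity ceil((stop - start) / step) == -((start - stop) // step) for positive integer step. A minimal plain-Python sketch of that identity (the function name here is illustrative, not part of the converter):

    import math

    def sliced_len(start: int, stop: int, step: int) -> int:
        # floor-divide the negated difference, then negate: equals ceil((stop - start) / step)
        return -((start - stop) // step)

    assert sliced_len(0, 7, 2) == math.ceil((7 - 0) / 2) == 4
    assert sliced_len(1, 10, 3) == math.ceil((10 - 1) / 3) == 3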