import numpy as np
import tensorrt as trt
+import torch
from torch.fx.node import Target
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
@@ -87,6 +88,8 @@ def index(
    # _LOGGER.debug(f"The index shape is {index.shape}")
    # check if the input is dynamic
    dynamic_shape = has_dynamic_shape(input.shape)
+    # is_numpy is a flag that marks whether the indices are numpy arrays
+    is_numpy = False

    # here we need to check if all the index are broadcastable
    # if no, then we need to broadcast
@@ -95,8 +98,14 @@ def index(
        if ind is not None:
            _LOGGER.debug(f"Shape of {i} index is {ind.shape}")
            adv_indx_indices.append(i)
-            # torch.nn.parameter.Parameter => torch.Tensor
-            ind = get_trt_tensor(ctx, ind, name + f"_parameter_to_fp32_tensor_{i}")
+            # torch.nn.parameter.Parameter => numpy array
+            # numpy arrays are kept as numpy
+            # other cases are converted to TRTTensor
+            if isinstance(ind, (torch.Tensor, np.ndarray)):
+                ind = to_numpy(ind)
+                is_numpy = True
+            else:
+                ind = get_trt_tensor(ctx, ind, name + f"_parameter_to_fp32_tensor_{i}")
            if last_index is not None:
                assert broadcastable(
                    ind, last_index
@@ -131,8 +140,11 @@ def index(
    for i in range(rank):
        dim = input_shape[i]
        dim_tensor = get_trt_tensor(ctx, dim, name + f"_individual_dim_{i}")
-        # dim_tensor_list is a list of tensors
-        dim_tensor_list.append(dim_tensor)
+        # dim_tensor_list is a list of TRT tensors, or of plain ints in the numpy case
+        if is_numpy:
+            dim_tensor_list.append(dim)
+        else:
+            dim_tensor_list.append(dim_tensor)

    # for cases like
    # t: [x_1, y_1, y_2, ..., x_m, ..., y_n] -> t: [x_1, x_2, ..., x_m, y_1, y_2, ..., y_n],
@@ -150,9 +162,14 @@ def index(
        if i not in adv_indx_indices:
            new_order.append(i)
    _LOGGER.debug(f"The new transpose order is {new_order}")
-    transpose_layer.second_transpose = tuple(new_order)
-    set_layer_name(transpose_layer, target, name + "_index_transpose", source_ir)
-    transpose_tensor = transpose_layer.get_output(0)
+
+    transpose_tensor = None
+    if is_numpy:
+        transpose_tensor = input[new_order]
+    else:
+        transpose_layer.second_transpose = tuple(new_order)
+        set_layer_name(transpose_layer, target, name + "_index_transpose", source_ir)
+        transpose_tensor = transpose_layer.get_output(0)

    # Flatten [x_1, x_2, ..., x_m, y_1, y_2, ..., y_n]
    # transpose_tensor_shape = ctx.net.add_shape(transpose_tensor)
@@ -165,57 +182,70 @@ def index(
    for i in range(adv_indx_count, rank):
        mult_d1 = mult_d1 * transpose_tensor_shape[i]

-    concat_tensor_layer = ctx.net.add_concatenation(
-        [
-            get_trt_tensor(ctx, mult_d0, name + "_d0_shape"),
-            get_trt_tensor(ctx, mult_d1, name + "_d1_shape"),
-        ]
-    )
-    set_layer_name(concat_tensor_layer, target, name + "_index_Concat", source_ir)
-    concat_tensor = concat_tensor_layer.get_output(0)
+    flatten_tensor = None
+    if is_numpy:
+        flatten_tensor = transpose_tensor.reshape(mult_d0, mult_d1)
+    else:
+        concat_tensor_layer = ctx.net.add_concatenation(
+            [
+                get_trt_tensor(ctx, mult_d0, name + "_d0_shape"),
+                get_trt_tensor(ctx, mult_d1, name + "_d1_shape"),
+            ]
+        )
+        set_layer_name(concat_tensor_layer, target, name + "_index_Concat", source_ir)
+        concat_tensor = concat_tensor_layer.get_output(0)
+
+        reshape_layer = ctx.net.add_shuffle(transpose_tensor)
+        reshape_layer.set_input(1, concat_tensor)
+        flatten_tensor = reshape_layer.get_output(0)

-    reshape_layer = ctx.net.add_shuffle(transpose_tensor)
-    # check this
-    reshape_layer.set_input(1, concat_tensor)
-    flatten_tensor = reshape_layer.get_output(0)
    _LOGGER.debug(f"The flatten tensor shape is {flatten_tensor.shape}")

    # tensor index = \sum_{i=1}^m (ind_i * \prod_{j=i+1}^m (x_j)), ind_i is input indices[i], x_j is the
    # // j dimension of input x.
-    multiplier = get_trt_tensor(
-        ctx,
-        dim_tensor_list[adv_indx_indices[adv_indx_count - 1]],
-        name + "_dim_last",
-    )
-    cum_adv_index = tensor_indices[adv_indx_count - 1]
-    for i in range(adv_indx_count - 2, -1, -1):
-        adv_index = convert_binary_elementwise(
-            ctx,
-            target,
-            source_ir,
-            name + f"_index_intermediate_{i}",
-            trt.ElementWiseOperation.PROD,
-            multiplier,
-            tensor_indices[i],
-        )
-        cum_adv_index = convert_binary_elementwise(
-            ctx,
-            target,
-            source_ir,
-            name + f"_index_sum_intermediate_{i}",
-            trt.ElementWiseOperation.SUM,
-            cum_adv_index,
-            adv_index,
-        )
-        multiplier = convert_binary_elementwise(
+    if is_numpy:
+        multiplier = dim_tensor_list[adv_indx_indices[adv_indx_count - 1]]
+        cum_adv_index = tensor_indices[adv_indx_count - 1]
+        for i in range(adv_indx_count - 2, -1, -1):
+            adv_index = multiplier * tensor_indices[i]
+            cum_adv_index = cum_adv_index + adv_index
+            multiplier = multiplier * dim_tensor_list[adv_indx_indices[i]]
+    else:
+
+        multiplier = get_trt_tensor(
            ctx,
-            target,
-            source_ir,
-            name + f"_index_intermediate_xj_{i}",
-            trt.ElementWiseOperation.PROD,
-            multiplier,
-            dim_tensor_list[adv_indx_indices[i]],
+            dim_tensor_list[adv_indx_indices[adv_indx_count - 1]],
+            name + "_dim_last",
        )
+        cum_adv_index = tensor_indices[adv_indx_count - 1]
+        for i in range(adv_indx_count - 2, -1, -1):
+            adv_index = convert_binary_elementwise(
+                ctx,
+                target,
+                source_ir,
+                name + f"_index_intermediate_{i}",
+                trt.ElementWiseOperation.PROD,
+                multiplier,
+                tensor_indices[i],
+            )
+            cum_adv_index = convert_binary_elementwise(
+                ctx,
+                target,
+                source_ir,
+                name + f"_index_sum_intermediate_{i}",
+                trt.ElementWiseOperation.SUM,
+                cum_adv_index,
+                adv_index,
+            )
+            multiplier = convert_binary_elementwise(
+                ctx,
+                target,
+                source_ir,
+                name + f"_index_intermediate_xj_{i}",
+                trt.ElementWiseOperation.PROD,
+                multiplier,
+                dim_tensor_list[adv_indx_indices[i]],
+            )

    gather_layer_element = ctx.net.add_gather(flatten_tensor, cum_adv_index, 0)
    set_layer_name(
@@ -239,29 +269,36 @@ def index(
        == adv_indx_indices[adv_indx_count - 1] - adv_indx_indices[0] + 1
    ):
        _LOGGER.debug(f"The indices are continuous in this case")
-        concat_tensor_reshape.append(
-            get_trt_tensor(ctx, -1, name + "_dynamic_concat")
-        )
+        if is_numpy:
+            concat_tensor_reshape.append(-1)
+        else:
+            concat_tensor_reshape.append(
+                get_trt_tensor(ctx, -1, name + "_dynamic_concat")
+            )
        for i in range(0, rank):
            if i not in adv_indx_indices:
                curr_dim = dim_tensor_list[i]
                concat_tensor_reshape.append(curr_dim)

-        concat_tensor_layer = ctx.net.add_concatenation(concat_tensor_reshape)
-        set_layer_name(
-            concat_tensor_layer, target, name + "_index_Concat_reshape", source_ir
-        )
-        concat_tensor = concat_tensor_layer.get_output(0)
+        unfold_tensor = None
+        if is_numpy:
+            unfold_tensor = gather_out.reshape(concat_tensor_reshape)
+        else:
+            concat_tensor_layer = ctx.net.add_concatenation(concat_tensor_reshape)
+            set_layer_name(
+                concat_tensor_layer, target, name + "_index_Concat_reshape", source_ir
+            )
+            concat_tensor = concat_tensor_layer.get_output(0)

-        regular_index_shuffle_layer = ctx.net.add_shuffle(gather_out)
-        regular_index_shuffle_layer.set_input(1, concat_tensor)
-        set_layer_name(
-            regular_index_shuffle_layer,
-            target,
-            name + "_index_regular_index",
-            source_ir,
-        )
-        unfold_tensor = regular_index_shuffle_layer.get_output(0)
+            regular_index_shuffle_layer = ctx.net.add_shuffle(gather_out)
+            regular_index_shuffle_layer.set_input(1, concat_tensor)
+            set_layer_name(
+                regular_index_shuffle_layer,
+                target,
+                name + "_index_regular_index",
+                source_ir,
+            )
+            unfold_tensor = regular_index_shuffle_layer.get_output(0)
        _LOGGER.debug(f"The tensor is unfolded now")
        _LOGGER.debug(f"The unfolded tensor shape is {unfold_tensor.shape}")
@@ -275,14 +312,18 @@ def index(
            new_order.append(i)
        _LOGGER.debug(f"Transposing the indices to correct position {new_order}")

-        transpose_advanced_shuffle_layer.second_transpose = tuple(new_order)
-        set_layer_name(
-            transpose_advanced_shuffle_layer,
-            target,
-            name + "_index_advanced_shuffle_transpose",
-            source_ir,
-        )
-        transpose_tensor = transpose_advanced_shuffle_layer.get_output(0)
+        transpose_tensor = None
+        if is_numpy:
+            transpose_tensor = unfold_tensor[new_order]
+        else:
+            transpose_advanced_shuffle_layer.second_transpose = tuple(new_order)
+            set_layer_name(
+                transpose_advanced_shuffle_layer,
+                target,
+                name + "_index_advanced_shuffle_transpose",
+                source_ir,
+            )
+            transpose_tensor = transpose_advanced_shuffle_layer.get_output(0)

        # unfold advanced layer
        concat_final_tensor = []
@@ -296,25 +337,29 @@ def index(
                current_dim = dim_tensor_list[i]
                concat_final_tensor.append(current_dim)

-        concat_final_shape_layer = ctx.net.add_concatenation(concat_final_tensor)
-        set_layer_name(
-            concat_final_shape_layer,
-            target,
-            name + "_index_continuous_concat_final_shape_layer",
-            source_ir,
-        )
-        concat_final_tensor = concat_final_shape_layer.get_output(0)
-
-        unfold_advanced_shuffle_layer = ctx.net.add_shuffle(transpose_tensor)
-        # check this
-        unfold_advanced_shuffle_layer.set_input(1, concat_final_tensor)
-        set_layer_name(
-            unfold_advanced_shuffle_layer,
-            target,
-            name + "_unfold_advanced_index",
-            source_ir,
-        )
-        reshape_output = unfold_advanced_shuffle_layer.get_output(0)
+        reshape_output = []
+        if is_numpy:
+            reshape_output = transpose_tensor.reshape(concat_final_tensor)
+        else:
+            concat_final_shape_layer = ctx.net.add_concatenation(concat_final_tensor)
+            set_layer_name(
+                concat_final_shape_layer,
+                target,
+                name + "_index_continuous_concat_final_shape_layer",
+                source_ir,
+            )
+            concat_final_tensor = concat_final_shape_layer.get_output(0)
+
+            unfold_advanced_shuffle_layer = ctx.net.add_shuffle(transpose_tensor)
+            # check this
+            unfold_advanced_shuffle_layer.set_input(1, concat_final_tensor)
+            set_layer_name(
+                unfold_advanced_shuffle_layer,
+                target,
+                name + "_unfold_advanced_index",
+                source_ir,
+            )
+            reshape_output = unfold_advanced_shuffle_layer.get_output(0)

    else:
        _LOGGER.debug(f"The indices are not continuous in this case")
@@ -325,23 +370,27 @@ def index(
                curr_dim = dim_tensor_list[i]
                concat_final_tensor.append(curr_dim)

-        concat_final_shape_layer = ctx.net.add_concatenation(concat_final_tensor)
-        set_layer_name(
-            concat_final_shape_layer,
-            target,
-            name + "_index_non_continuous_concat_final_shape_layer",
-            source_ir,
-        )
-        concat_final_tensor = concat_final_shape_layer.get_output(0)
-
-        reshape_layer = ctx.net.add_shuffle(gather_out)
-        reshape_layer.set_input(1, concat_final_tensor)
-        set_layer_name(
-            reshape_layer,
-            target,
-            name + "_index_non_continuous_shuffle_final_shape_layer",
-            source_ir,
-        )
-        reshape_output = reshape_layer.get_output(0)
+        reshape_output = None
+        if is_numpy:
+            reshape_output = gather_out.reshape(concat_final_tensor)
+        else:
+            concat_final_shape_layer = ctx.net.add_concatenation(concat_final_tensor)
+            set_layer_name(
+                concat_final_shape_layer,
+                target,
+                name + "_index_non_continuous_concat_final_shape_layer",
+                source_ir,
+            )
+            concat_final_tensor = concat_final_shape_layer.get_output(0)
+
+            reshape_layer = ctx.net.add_shuffle(gather_out)
+            reshape_layer.set_input(1, concat_final_tensor)
+            set_layer_name(
+                reshape_layer,
+                target,
+                name + "_index_non_continuous_shuffle_final_shape_layer",
+                source_ir,
+            )
+            reshape_output = reshape_layer.get_output(0)

    return reshape_output
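
The change above folds the advanced-index arithmetic in numpy whenever every index is a compile-time constant, instead of emitting TensorRT PROD/SUM elementwise layers. Below is a minimal standalone sketch of that linearization (the \sum_i ind_i * \prod_{j>i} x_j formula from the comment in the diff); the helper name fold_advanced_index is illustrative and not part of torch_tensorrt, and it works on plain numpy arrays rather than TRT tensors.

import numpy as np


def fold_advanced_index(input_shape, adv_indx_indices, tensor_indices):
    # Mirror of the is_numpy branch above: walk the advanced dims from last to
    # first, accumulating the flattened gather index and the running multiplier.
    multiplier = input_shape[adv_indx_indices[-1]]
    cum_adv_index = tensor_indices[-1]
    for i in range(len(adv_indx_indices) - 2, -1, -1):
        cum_adv_index = cum_adv_index + multiplier * tensor_indices[i]
        multiplier = multiplier * input_shape[adv_indx_indices[i]]
    return cum_adv_index


# Example: x[idx0, idx1] on a (4, 5) array equals gathering from the flattened view.
x = np.arange(20).reshape(4, 5)
idx0 = np.array([0, 2, 3])
idx1 = np.array([1, 4, 0])
lin = fold_advanced_index(x.shape, [0, 1], [idx0, idx1])
assert np.array_equal(x.reshape(-1)[lin], x[idx0, idx1])

Keeping this arithmetic in numpy means the resulting cum_adv_index can be handed to a single gather layer as a constant, which appears to be the motivation for the is_numpy branches throughout the hunks above.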