@@ -88,9 +88,16 @@ def index(
88
88
# _LOGGER.debug(f"The index shape is {index.shape}")
89
89
# check if the input is dynamic
90
90
dynamic_shape = has_dynamic_shape (input .shape )
91
- #is_numpy is a flag to specify if input isa numpy
92
- is_numpy = False
93
-
91
+ # is_numpy is a flag to specify if all the indices are numpy or torchTensor.
92
+ # If any is not this flag will be set to False
93
+ is_numpy = True
94
+ _LOGGER .debug (f"Checking for the is_numpy flag" )
95
+ for i , ind in enumerate (index ):
96
+ if ind is None :
97
+ continue
98
+ if not (isinstance (ind , torch .Tensor ) or isinstance (ind , np .ndarray )):
99
+ is_numpy = False
100
+ break
94
101
# here we need to check if all the index are broadcastable
95
102
# if no, then we need to broadcast
96
103
last_index = None
@@ -101,7 +108,7 @@ def index(
101
108
# torch.nn.parameter.Parameter=> numpy array
102
109
# numpy array is kept as numpy
103
110
# other cases are kept as TRTTensor
104
- if ( isinstance ( ind , torch . Tensor ) or ( ind , np . ndarray )) :
111
+ if is_numpy :
105
112
ind = to_numpy (ind )
106
113
is_numpy = True
107
114
else :
@@ -119,8 +126,9 @@ def index(
119
126
set_layer_name (identity_layer , target , name + "_index_identity" , source_ir )
120
127
return identity_layer .get_output (0 )
121
128
elif len (tensor_indices ) == 1 :
122
- # This case works
123
- indices_tensor = tensor_indices [0 ]
129
+ indices_tensor = get_trt_tensor (
130
+ ctx , tensor_indices [0 ], name + f"_parameter_to_fp32_tensor"
131
+ )
124
132
index = adv_indx_indices [0 ]
125
133
_LOGGER .debug (f"The advanced index indices is { adv_indx_indices } " )
126
134
gather_layer = ctx .net .add_gather (input , indices_tensor , index )
@@ -136,15 +144,15 @@ def index(
136
144
rank = len (input_shape )
137
145
adv_indx_count = len (adv_indx_indices )
138
146
dim_tensor_list = []
147
+ dim_list = []
139
148
140
149
for i in range (rank ):
141
150
dim = input_shape [i ]
142
151
dim_tensor = get_trt_tensor (ctx , dim , name + f"_individual_dim_{ i } " )
143
152
# dim_tensor_list is a list of tensors or numpy
144
- if (is_numpy ):
145
- dim_tensor_list .append (dim )
146
- else :
147
- dim_tensor_list .append (dim_tensor )
153
+ if is_numpy :
154
+ dim_list .append (dim )
155
+ dim_tensor_list .append (dim_tensor )
148
156
149
157
# for cases like
150
158
# t: [x_1, y_1, y_2, ..., x_m, ..., y_n] -> t: [x_1, x_2, ..., x_m, y_1, y_2, ..., y_n],
@@ -164,12 +172,9 @@ def index(
164
172
_LOGGER .debug (f"The new transpose order is { new_order } " )
165
173
166
174
transpose_tensor = None
167
- if (is_numpy ):
168
- transpose_tensor = input [new_order ]
169
- else :
170
- transpose_layer .second_transpose = tuple (new_order )
171
- set_layer_name (transpose_layer , target , name + "_index_transpose" , source_ir )
172
- transpose_tensor = transpose_layer .get_output (0 )
175
+ transpose_layer .second_transpose = tuple (new_order )
176
+ set_layer_name (transpose_layer , target , name + "_index_transpose" , source_ir )
177
+ transpose_tensor = transpose_layer .get_output (0 )
173
178
174
179
# Flatten [x_1, x_2,.......x_m, y_1, y_2,.....y_n]
175
180
# transpose_tensor_shape = ctx.net.add_shape(transpose_tensor)
@@ -182,36 +187,34 @@ def index(
182
187
for i in range (adv_indx_count , rank ):
183
188
mult_d1 = mult_d1 * transpose_tensor_shape [i ]
184
189
185
- flatten_tensor = None
186
- if (is_numpy ):
187
- flatten_tensor = transpose_tensor .reshape (mult_d0 , mult_d1 )
188
- else :
189
- concat_tensor_layer = ctx .net .add_concatenation (
190
- [
191
- get_trt_tensor (ctx , mult_d0 , name + "_d0_shape" ),
192
- get_trt_tensor (ctx , mult_d1 , name + "_d1_shape" ),
193
- ]
194
- )
195
- set_layer_name (concat_tensor_layer , target , name + "_index_Concat" , source_ir )
196
- concat_tensor = concat_tensor_layer .get_output (0 )
190
+ concat_tensor_layer = ctx .net .add_concatenation (
191
+ [
192
+ get_trt_tensor (ctx , mult_d0 , name + "_d0_shape" ),
193
+ get_trt_tensor (ctx , mult_d1 , name + "_d1_shape" ),
194
+ ]
195
+ )
196
+ set_layer_name (concat_tensor_layer , target , name + "_index_Concat" , source_ir )
197
+ concat_tensor = concat_tensor_layer .get_output (0 )
197
198
198
- reshape_layer = ctx .net .add_shuffle (transpose_tensor )
199
- reshape_layer .set_input (1 , concat_tensor )
200
- flatten_tensor = reshape_layer .get_output (0 )
199
+ reshape_layer = ctx .net .add_shuffle (transpose_tensor )
200
+ reshape_layer .set_input (1 , concat_tensor )
201
+ flatten_tensor = reshape_layer .get_output (0 )
201
202
202
203
_LOGGER .debug (f"The flatten tensor shape is { flatten_tensor .shape } " )
203
204
204
205
# tensor index = \sum_{i=1}^m (ind_i * \prod_{j=i+1}^m (x_j)), ind_i is input indices[i], x_j is the
205
206
# // j dimension of input x.
206
- if ( is_numpy ) :
207
- multiplier = dim_tensor_list [adv_indx_indices [adv_indx_count - 1 ]]
207
+ if is_numpy :
208
+ multiplier = dim_list [adv_indx_indices [adv_indx_count - 1 ]]
208
209
cum_adv_index = tensor_indices [adv_indx_count - 1 ]
209
210
for i in range (adv_indx_count - 2 , - 1 , - 1 ):
210
211
adv_index = multiplier * tensor_indices [i ]
211
212
cum_adv_index = cum_adv_index + adv_index
212
- multiplier = multiplier * dim_tensor_list [adv_indx_indices [i ]]
213
+ multiplier = multiplier * dim_list [adv_indx_indices [i ]]
214
+ cum_adv_index = get_trt_tensor (
215
+ ctx , cum_adv_index , name + f"_index_sum_intermediate"
216
+ )
213
217
else :
214
-
215
218
multiplier = get_trt_tensor (
216
219
ctx ,
217
220
dim_tensor_list [adv_indx_indices [adv_indx_count - 1 ]],
@@ -269,36 +272,31 @@ def index(
269
272
== adv_indx_indices [adv_indx_count - 1 ] - adv_indx_indices [0 ] + 1
270
273
):
271
274
_LOGGER .debug (f"The indices are continuous in this case" )
272
- if (is_numpy ):
273
- concat_tensor_reshape .append (- 1 )
274
- else :
275
- concat_tensor_reshape .append (
276
- get_trt_tensor (ctx , - 1 , name + "_dynamic_concat" )
277
- )
275
+ concat_tensor_reshape .append (
276
+ get_trt_tensor (ctx , - 1 , name + "_dynamic_concat" )
277
+ )
278
278
for i in range (0 , rank ):
279
279
if i not in adv_indx_indices :
280
280
curr_dim = dim_tensor_list [i ]
281
281
concat_tensor_reshape .append (curr_dim )
282
282
283
283
unfold_tensor = None
284
- if (is_numpy ):
285
- unfold_tensor = gather_out .reshape (concat_tensor )
286
- else :
287
- concat_tensor_layer = ctx .net .add_concatenation (concat_tensor_reshape )
288
- set_layer_name (
289
- concat_tensor_layer , target , name + "_index_Concat_reshape" , source_ir
290
- )
291
- concat_tensor = concat_tensor_layer .get_output (0 )
292
284
293
- regular_index_shuffle_layer = ctx .net .add_shuffle (gather_out )
294
- regular_index_shuffle_layer .set_input (1 , concat_tensor )
295
- set_layer_name (
296
- regular_index_shuffle_layer ,
297
- target ,
298
- name + "_index_regular_index" ,
299
- source_ir ,
300
- )
301
- unfold_tensor = regular_index_shuffle_layer .get_output (0 )
285
+ concat_tensor_layer = ctx .net .add_concatenation (concat_tensor_reshape )
286
+ set_layer_name (
287
+ concat_tensor_layer , target , name + "_index_Concat_reshape" , source_ir
288
+ )
289
+ concat_tensor = concat_tensor_layer .get_output (0 )
290
+
291
+ regular_index_shuffle_layer = ctx .net .add_shuffle (gather_out )
292
+ regular_index_shuffle_layer .set_input (1 , concat_tensor )
293
+ set_layer_name (
294
+ regular_index_shuffle_layer ,
295
+ target ,
296
+ name + "_index_regular_index" ,
297
+ source_ir ,
298
+ )
299
+ unfold_tensor = regular_index_shuffle_layer .get_output (0 )
302
300
_LOGGER .debug (f"The tensor is unfolded now" )
303
301
_LOGGER .debug (f"The unfolded tensor shape is { unfold_tensor .shape } " )
304
302
@@ -313,17 +311,15 @@ def index(
313
311
_LOGGER .debug (f"Transposing the indices to correct position { new_order } " )
314
312
315
313
transpose_tensor = None
316
- if (is_numpy ):
317
- transpose_tensor = unfold_tensor [new_order ]
318
- else :
319
- transpose_advanced_shuffle_layer .second_transpose = tuple (new_order )
320
- set_layer_name (
321
- transpose_advanced_shuffle_layer ,
322
- target ,
323
- name + "_index_advanced_shuffle_transpose" ,
324
- source_ir ,
325
- )
326
- transpose_tensor = transpose_advanced_shuffle_layer .get_output (0 )
314
+
315
+ transpose_advanced_shuffle_layer .second_transpose = tuple (new_order )
316
+ set_layer_name (
317
+ transpose_advanced_shuffle_layer ,
318
+ target ,
319
+ name + "_index_advanced_shuffle_transpose" ,
320
+ source_ir ,
321
+ )
322
+ transpose_tensor = transpose_advanced_shuffle_layer .get_output (0 )
327
323
328
324
# unfold advanced layer
329
325
concat_final_tensor = []
@@ -338,28 +334,25 @@ def index(
338
334
concat_final_tensor .append (current_dim )
339
335
340
336
reshape_output = []
341
- if (is_numpy ):
342
- reshape_output = transpose_tensor .reshape (concat_final_tensor )
343
- else :
344
- concat_final_shape_layer = ctx .net .add_concatenation (concat_final_tensor )
345
- set_layer_name (
346
- concat_final_shape_layer ,
347
- target ,
348
- name + "_index_continuous_concat_final_shape_layer" ,
349
- source_ir ,
350
- )
351
- concat_final_tensor = concat_final_shape_layer .get_output (0 )
352
-
353
- unfold_advanced_shuffle_layer = ctx .net .add_shuffle (transpose_tensor )
354
- # check this
355
- unfold_advanced_shuffle_layer .set_input (1 , concat_final_tensor )
356
- set_layer_name (
357
- unfold_advanced_shuffle_layer ,
358
- target ,
359
- name + "_unfold_advanced_index" ,
360
- source_ir ,
361
- )
362
- reshape_output = unfold_advanced_shuffle_layer .get_output (0 )
337
+ concat_final_shape_layer = ctx .net .add_concatenation (concat_final_tensor )
338
+ set_layer_name (
339
+ concat_final_shape_layer ,
340
+ target ,
341
+ name + "_index_continuous_concat_final_shape_layer" ,
342
+ source_ir ,
343
+ )
344
+ concat_final_tensor = concat_final_shape_layer .get_output (0 )
345
+
346
+ unfold_advanced_shuffle_layer = ctx .net .add_shuffle (transpose_tensor )
347
+ # check this
348
+ unfold_advanced_shuffle_layer .set_input (1 , concat_final_tensor )
349
+ set_layer_name (
350
+ unfold_advanced_shuffle_layer ,
351
+ target ,
352
+ name + "_unfold_advanced_index" ,
353
+ source_ir ,
354
+ )
355
+ reshape_output = unfold_advanced_shuffle_layer .get_output (0 )
363
356
364
357
else :
365
358
_LOGGER .debug (f"The indices are not continuous in this case" )
@@ -371,26 +364,24 @@ def index(
371
364
concat_final_tensor .append (curr_dim )
372
365
373
366
reshape_output = None
374
- if (is_numpy ):
375
- reshape_output = gather_out .reshape (concat_final_tensor )
376
- else :
377
- concat_final_shape_layer = ctx .net .add_concatenation (concat_final_tensor )
378
- set_layer_name (
379
- concat_final_shape_layer ,
380
- target ,
381
- name + "_index_non_continuous_concat_final_shape_layer" ,
382
- source_ir ,
383
- )
384
- concat_final_tensor = concat_final_shape_layer .get_output (0 )
385
367
386
- reshape_layer = ctx .net .add_shuffle (gather_out )
387
- reshape_layer .set_input (1 , concat_final_tensor )
388
- set_layer_name (
389
- reshape_layer ,
390
- target ,
391
- name + "_index_non_continuous_shuffle_final_shape_layer" ,
392
- source_ir ,
393
- )
394
- reshape_output = reshape_layer .get_output (0 )
368
+ concat_final_shape_layer = ctx .net .add_concatenation (concat_final_tensor )
369
+ set_layer_name (
370
+ concat_final_shape_layer ,
371
+ target ,
372
+ name + "_index_non_continuous_concat_final_shape_layer" ,
373
+ source_ir ,
374
+ )
375
+ concat_final_tensor = concat_final_shape_layer .get_output (0 )
376
+
377
+ reshape_layer = ctx .net .add_shuffle (gather_out )
378
+ reshape_layer .set_input (1 , concat_final_tensor )
379
+ set_layer_name (
380
+ reshape_layer ,
381
+ target ,
382
+ name + "_index_non_continuous_shuffle_final_shape_layer" ,
383
+ source_ir ,
384
+ )
385
+ reshape_output = reshape_layer .get_output (0 )
395
386
396
387
return reshape_output
0 commit comments