Commit e31b7e8

Correction for the is_numpy cases for a mix of numpy and non-numpy inputs
1 parent b1534f8 commit e31b7e8

File tree

1 file changed (+102 −111 lines):
  • py/torch_tensorrt/dynamo/conversion/impl/select.py

py/torch_tensorrt/dynamo/conversion/impl/select.py

Lines changed: 102 additions & 111 deletions
@@ -88,9 +88,16 @@ def index(
     # _LOGGER.debug(f"The index shape is {index.shape}")
     # check if the input is dynamic
     dynamic_shape = has_dynamic_shape(input.shape)
-    #is_numpy is a flag to specify if input isa numpy
-    is_numpy = False
-
+    # is_numpy is a flag to specify if all the indices are numpy or torchTensor.
+    # If any is not this flag will be set to False
+    is_numpy = True
+    _LOGGER.debug(f"Checking for the is_numpy flag")
+    for i, ind in enumerate(index):
+        if ind is None:
+            continue
+        if not (isinstance(ind, torch.Tensor) or isinstance(ind, np.ndarray)):
+            is_numpy = False
+            break
     # here we need to check if all the index are broadcastable
     # if no, then we need to broadcast
     last_index = None
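
Aside: the new flag logic starts from True and demotes to False at the first index that is neither a torch.Tensor nor an np.ndarray (for example, an ITensor produced by another TRT layer). A minimal standalone sketch of that check (the helper name is ours, not the converter's):

    import numpy as np
    import torch

    def all_indices_host_materializable(indices):
        # Mirrors the new is_numpy computation: assume True, flip to False
        # at the first non-None index that is neither torch nor numpy.
        for ind in indices:
            if ind is None:
                continue
            if not isinstance(ind, (torch.Tensor, np.ndarray)):
                return False
        return True

    print(all_indices_host_materializable([torch.tensor([0, 2]), None]))  # True
    print(all_indices_host_materializable([np.array([1]), object()]))     # False
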
@@ -101,7 +108,7 @@ def index(
             # torch.nn.parameter.Parameter=> numpy array
             # numpy array is kept as numpy
             # other cases are kept as TRTTensor
-            if (isinstance(ind, torch.Tensor) or (ind, np.ndarray)):
+            if is_numpy:
                 ind = to_numpy(ind)
                 is_numpy = True
             else:
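
Aside: the removed condition carried the actual bug this commit fixes. `(ind, np.ndarray)` is a tuple literal, not an `isinstance` call, and a non-empty tuple is always truthy, so the old test accepted every index and forced the numpy path even when an index was a TRT tensor. A quick demonstration:

    import numpy as np

    ind = object()  # clearly not a tensor of any kind
    print(bool(isinstance(ind, np.ndarray)))  # False
    # The old test's second operand builds a 2-tuple, which is always truthy:
    print(bool((ind, np.ndarray)))            # True
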
@@ -119,8 +126,9 @@ def index(
         set_layer_name(identity_layer, target, name + "_index_identity", source_ir)
         return identity_layer.get_output(0)
     elif len(tensor_indices) == 1:
-        # This case works
-        indices_tensor = tensor_indices[0]
+        indices_tensor = get_trt_tensor(
+            ctx, tensor_indices[0], name + f"_parameter_to_fp32_tensor"
+        )
         index = adv_indx_indices[0]
         _LOGGER.debug(f"The advanced index indices is {adv_indx_indices}")
         gather_layer = ctx.net.add_gather(input, indices_tensor, index)
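
Aside: because indices may now legitimately remain numpy arrays up to this point, the single-index branch wraps tensor_indices[0] in get_trt_tensor so that add_gather always receives an ITensor rather than a raw numpy array.
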
@@ -136,15 +144,15 @@ def index(
         rank = len(input_shape)
         adv_indx_count = len(adv_indx_indices)
         dim_tensor_list = []
+        dim_list = []

         for i in range(rank):
             dim = input_shape[i]
             dim_tensor = get_trt_tensor(ctx, dim, name + f"_individual_dim_{i}")
             # dim_tensor_list is a list of tensors or numpy
-            if(is_numpy):
-                dim_tensor_list.append(dim)
-            else:
-                dim_tensor_list.append(dim_tensor)
+            if is_numpy:
+                dim_list.append(dim)
+            dim_tensor_list.append(dim_tensor)

         # for cases like
         # t: [x_1, y_1, y_2, ..., x_m, ..., y_n] -> t: [x_1, x_2, ..., x_m, y_1, y_2, ..., y_n],
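
Aside: dim_list keeps the raw host-side dimension values next to the ITensor versions in dim_tensor_list, so the numpy branch further down can do its multiplier arithmetic on plain Python numbers while the TRT branch keeps consuming tensors.
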
@@ -164,12 +172,9 @@ def index(
         _LOGGER.debug(f"The new transpose order is {new_order}")

         transpose_tensor = None
-        if(is_numpy):
-            transpose_tensor = input[new_order]
-        else:
-            transpose_layer.second_transpose = tuple(new_order)
-            set_layer_name(transpose_layer, target, name + "_index_transpose", source_ir)
-            transpose_tensor = transpose_layer.get_output(0)
+        transpose_layer.second_transpose = tuple(new_order)
+        set_layer_name(transpose_layer, target, name + "_index_transpose", source_ir)
+        transpose_tensor = transpose_layer.get_output(0)

         # Flatten [x_1, x_2,.......x_m, y_1, y_2,.....y_n]
         # transpose_tensor_shape = ctx.net.add_shape(transpose_tensor)
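
Aside: dropping the numpy fast path here also removes a latent bug. For an ndarray, input[new_order] performs advanced indexing along the first axis; it does not permute axes the way the TRT shuffle's second_transpose does. For example:

    import numpy as np

    x = np.arange(6).reshape(2, 3)
    new_order = [1, 0]
    print(x[new_order].shape)                # (2, 3) -- rows reordered, axes unchanged
    print(np.transpose(x, new_order).shape)  # (3, 2) -- axes actually permuted
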
@@ -182,36 +187,34 @@ def index(
         for i in range(adv_indx_count, rank):
             mult_d1 = mult_d1 * transpose_tensor_shape[i]

-        flatten_tensor = None
-        if(is_numpy):
-            flatten_tensor = transpose_tensor.reshape(mult_d0, mult_d1)
-        else:
-            concat_tensor_layer = ctx.net.add_concatenation(
-                [
-                    get_trt_tensor(ctx, mult_d0, name + "_d0_shape"),
-                    get_trt_tensor(ctx, mult_d1, name + "_d1_shape"),
-                ]
-            )
-            set_layer_name(concat_tensor_layer, target, name + "_index_Concat", source_ir)
-            concat_tensor = concat_tensor_layer.get_output(0)
+        concat_tensor_layer = ctx.net.add_concatenation(
+            [
+                get_trt_tensor(ctx, mult_d0, name + "_d0_shape"),
+                get_trt_tensor(ctx, mult_d1, name + "_d1_shape"),
+            ]
+        )
+        set_layer_name(concat_tensor_layer, target, name + "_index_Concat", source_ir)
+        concat_tensor = concat_tensor_layer.get_output(0)

-            reshape_layer = ctx.net.add_shuffle(transpose_tensor)
-            reshape_layer.set_input(1, concat_tensor)
-            flatten_tensor = reshape_layer.get_output(0)
+        reshape_layer = ctx.net.add_shuffle(transpose_tensor)
+        reshape_layer.set_input(1, concat_tensor)
+        flatten_tensor = reshape_layer.get_output(0)

         _LOGGER.debug(f"The flatten tensor shape is {flatten_tensor.shape}")

         # tensor index = \sum_{i=1}^m (ind_i * \prod_{j=i+1}^m (x_j)), ind_i is input indices[i], x_j is the
         # // j dimension of input x.
-        if(is_numpy):
-            multiplier = dim_tensor_list[adv_indx_indices[adv_indx_count - 1]]
+        if is_numpy:
+            multiplier = dim_list[adv_indx_indices[adv_indx_count - 1]]
             cum_adv_index = tensor_indices[adv_indx_count - 1]
             for i in range(adv_indx_count - 2, -1, -1):
                 adv_index = multiplier * tensor_indices[i]
                 cum_adv_index = cum_adv_index + adv_index
-                multiplier = multiplier * dim_tensor_list[adv_indx_indices[i]]
+                multiplier = multiplier * dim_list[adv_indx_indices[i]]
+            cum_adv_index = get_trt_tensor(
+                ctx, cum_adv_index, name + f"_index_sum_intermediate"
+            )
         else:
-
             multiplier = get_trt_tensor(
                 ctx,
                 dim_tensor_list[adv_indx_indices[adv_indx_count - 1]],
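
Aside: the comment above linearizes the advanced indices as index = \sum_{i=1}^m ind_i * \prod_{j=i+1}^m x_j. A small worked example of what the numpy branch now computes with dim_list (shapes and indices are made up for illustration):

    import numpy as np

    x = np.arange(20).reshape(4, 5)  # x_1 = 4, x_2 = 5
    i0 = np.array([0, 3])
    i1 = np.array([1, 4])

    # The loop accumulates cum_adv_index right to left:
    multiplier = 5                        # dim_list[last advanced axis]
    cum_adv_index = i1 + multiplier * i0  # ind_2 + x_2 * ind_1

    print(cum_adv_index)                            # [ 1 19]
    print(x.reshape(-1)[cum_adv_index])             # [ 1 19]
    print(x[i0, i1])                                # [ 1 19] -- same result
    print(np.ravel_multi_index((i0, i1), x.shape))  # [ 1 19]
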
@@ -269,36 +272,31 @@ def index(
             == adv_indx_indices[adv_indx_count - 1] - adv_indx_indices[0] + 1
         ):
             _LOGGER.debug(f"The indices are continuous in this case")
-            if(is_numpy):
-                concat_tensor_reshape.append(-1)
-            else:
-                concat_tensor_reshape.append(
-                    get_trt_tensor(ctx, -1, name + "_dynamic_concat")
-                )
+            concat_tensor_reshape.append(
+                get_trt_tensor(ctx, -1, name + "_dynamic_concat")
+            )
             for i in range(0, rank):
                 if i not in adv_indx_indices:
                     curr_dim = dim_tensor_list[i]
                     concat_tensor_reshape.append(curr_dim)

-            unfold_tensor = None
-            if(is_numpy):
-                unfold_tensor = gather_out.reshape(concat_tensor)
-            else:
-                concat_tensor_layer = ctx.net.add_concatenation(concat_tensor_reshape)
-                set_layer_name(
-                    concat_tensor_layer, target, name + "_index_Concat_reshape", source_ir
-                )
-                concat_tensor = concat_tensor_layer.get_output(0)

-                regular_index_shuffle_layer = ctx.net.add_shuffle(gather_out)
-                regular_index_shuffle_layer.set_input(1, concat_tensor)
-                set_layer_name(
-                    regular_index_shuffle_layer,
-                    target,
-                    name + "_index_regular_index",
-                    source_ir,
-                )
-                unfold_tensor = regular_index_shuffle_layer.get_output(0)
+            concat_tensor_layer = ctx.net.add_concatenation(concat_tensor_reshape)
+            set_layer_name(
+                concat_tensor_layer, target, name + "_index_Concat_reshape", source_ir
+            )
+            concat_tensor = concat_tensor_layer.get_output(0)
+
+            regular_index_shuffle_layer = ctx.net.add_shuffle(gather_out)
+            regular_index_shuffle_layer.set_input(1, concat_tensor)
+            set_layer_name(
+                regular_index_shuffle_layer,
+                target,
+                name + "_index_regular_index",
+                source_ir,
+            )
+            unfold_tensor = regular_index_shuffle_layer.get_output(0)
             _LOGGER.debug(f"The tensor is unfolded now")
             _LOGGER.debug(f"The unfolded tensor shape is {unfold_tensor.shape}")

@@ -313,17 +311,15 @@ def index(
             _LOGGER.debug(f"Transposing the indices to correct position {new_order}")

             transpose_tensor = None
-            if(is_numpy):
-                transpose_tensor = unfold_tensor[new_order]
-            else:
-                transpose_advanced_shuffle_layer.second_transpose = tuple(new_order)
-                set_layer_name(
-                    transpose_advanced_shuffle_layer,
-                    target,
-                    name + "_index_advanced_shuffle_transpose",
-                    source_ir,
-                )
-                transpose_tensor = transpose_advanced_shuffle_layer.get_output(0)
+
+            transpose_advanced_shuffle_layer.second_transpose = tuple(new_order)
+            set_layer_name(
+                transpose_advanced_shuffle_layer,
+                target,
+                name + "_index_advanced_shuffle_transpose",
+                source_ir,
+            )
+            transpose_tensor = transpose_advanced_shuffle_layer.get_output(0)

             # unfold advanced layer
             concat_final_tensor = []
@@ -338,28 +334,25 @@ def index(
                 concat_final_tensor.append(current_dim)

             reshape_output = []
-            if(is_numpy):
-                reshape_output = transpose_tensor.reshape(concat_final_tensor)
-            else:
-                concat_final_shape_layer = ctx.net.add_concatenation(concat_final_tensor)
-                set_layer_name(
-                    concat_final_shape_layer,
-                    target,
-                    name + "_index_continuous_concat_final_shape_layer",
-                    source_ir,
-                )
-                concat_final_tensor = concat_final_shape_layer.get_output(0)
-
-                unfold_advanced_shuffle_layer = ctx.net.add_shuffle(transpose_tensor)
-                # check this
-                unfold_advanced_shuffle_layer.set_input(1, concat_final_tensor)
-                set_layer_name(
-                    unfold_advanced_shuffle_layer,
-                    target,
-                    name + "_unfold_advanced_index",
-                    source_ir,
-                )
-                reshape_output = unfold_advanced_shuffle_layer.get_output(0)
+            concat_final_shape_layer = ctx.net.add_concatenation(concat_final_tensor)
+            set_layer_name(
+                concat_final_shape_layer,
+                target,
+                name + "_index_continuous_concat_final_shape_layer",
+                source_ir,
+            )
+            concat_final_tensor = concat_final_shape_layer.get_output(0)
+
+            unfold_advanced_shuffle_layer = ctx.net.add_shuffle(transpose_tensor)
+            # check this
+            unfold_advanced_shuffle_layer.set_input(1, concat_final_tensor)
+            set_layer_name(
+                unfold_advanced_shuffle_layer,
+                target,
+                name + "_unfold_advanced_index",
+                source_ir,
+            )
+            reshape_output = unfold_advanced_shuffle_layer.get_output(0)

         else:
             _LOGGER.debug(f"The indices are not continuous in this case")
@@ -371,26 +364,24 @@ def index(
                 concat_final_tensor.append(curr_dim)

             reshape_output = None
-            if(is_numpy):
-                reshape_output = gather_out.reshape(concat_final_tensor)
-            else:
-                concat_final_shape_layer = ctx.net.add_concatenation(concat_final_tensor)
-                set_layer_name(
-                    concat_final_shape_layer,
-                    target,
-                    name + "_index_non_continuous_concat_final_shape_layer",
-                    source_ir,
-                )
-                concat_final_tensor = concat_final_shape_layer.get_output(0)

-                reshape_layer = ctx.net.add_shuffle(gather_out)
-                reshape_layer.set_input(1, concat_final_tensor)
-                set_layer_name(
-                    reshape_layer,
-                    target,
-                    name + "_index_non_continuous_shuffle_final_shape_layer",
-                    source_ir,
-                )
-                reshape_output = reshape_layer.get_output(0)
+            concat_final_shape_layer = ctx.net.add_concatenation(concat_final_tensor)
+            set_layer_name(
+                concat_final_shape_layer,
+                target,
+                name + "_index_non_continuous_concat_final_shape_layer",
+                source_ir,
+            )
+            concat_final_tensor = concat_final_shape_layer.get_output(0)
+
+            reshape_layer = ctx.net.add_shuffle(gather_out)
+            reshape_layer.set_input(1, concat_final_tensor)
+            set_layer_name(
+                reshape_layer,
+                target,
+                name + "_index_non_continuous_shuffle_final_shape_layer",
+                source_ir,
+            )
+            reshape_output = reshape_layer.get_output(0)

         return reshape_output
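
Aside: the continuous/non-continuous split in the last two hunks mirrors numpy and PyTorch advanced-indexing semantics: when the advanced indices sit on adjacent axes the result dimension stays in place, otherwise it moves to the front. A short illustration (shapes made up):

    import numpy as np

    x = np.zeros((2, 3, 4))
    i = np.array([0, 1])
    j = np.array([1, 2])

    # Advanced indices on adjacent axes: the indexed dims stay in place.
    print(x[:, i, j].shape)  # (2, 2)

    # Advanced indices separated by a slice: the broadcast index dimension
    # is moved to the front, hence the separate non-continuous handling.
    print(x[i, :, j].shape)  # (2, 3)
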
