Commit 7a5c7f6

Correction for the is_numpy cases for a mix of numpy and non-numpy inputs
1 parent b1534f8 commit 7a5c7f6

File tree: 1 file changed (+100 -117 lines)

  • py/torch_tensorrt/dynamo/conversion/impl

py/torch_tensorrt/dynamo/conversion/impl/select.py

Lines changed: 100 additions & 117 deletions
@@ -88,9 +88,16 @@ def index(
     # _LOGGER.debug(f"The index shape is {index.shape}")
     # check if the input is dynamic
     dynamic_shape = has_dynamic_shape(input.shape)
-    #is_numpy is a flag to specify if input isa numpy
-    is_numpy = False
-
+    # is_numpy is a flag to specify if all the indices are numpy or torch.Tensor.
+    # If any is not, this flag will be set to False
+    is_numpy = True
+    _LOGGER.debug(f"Checking for the is_numpy flag")
+    for i, ind in enumerate(index):
+        if ind is None:
+            continue
+        if not (isinstance(ind, torch.Tensor) or isinstance(ind, np.ndarray)):
+            is_numpy = False
+            break
     # here we need to check if all the index are broadcastable
     # if no, then we need to broadcast
     last_index = None
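The hunk above turns is_numpy into an all-or-nothing property of the whole index list. A minimal standalone sketch of the same check (the helper name is illustrative; a plain string stands in for a TRT ITensor, which is neither a torch tensor nor a numpy array):

    import numpy as np
    import torch

    def indices_all_numpy_or_torch(index):
        # Start optimistic; any non-None entry that is neither a
        # torch.Tensor nor an np.ndarray (e.g. a TRT ITensor) means
        # the numpy fast path cannot be used, as in the committed loop.
        for ind in index:
            if ind is None:
                continue
            if not isinstance(ind, (torch.Tensor, np.ndarray)):
                return False
        return True

    print(indices_all_numpy_or_torch([np.array([0, 1]), None]))       # True
    print(indices_all_numpy_or_torch([np.array([0, 1]), "itensor"]))  # False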
@@ -101,7 +108,7 @@ def index(
         # torch.nn.parameter.Parameter=> numpy array
         # numpy array is kept as numpy
         # other cases are kept as TRTTensor
-        if (isinstance(ind, torch.Tensor) or (ind, np.ndarray)):
+        if is_numpy:
             ind = to_numpy(ind)
             is_numpy = True
         else:
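For context on the condition removed above: (ind, np.ndarray) is a two-element tuple, and any non-empty tuple is truthy, so the old test passed for every index and mixed numpy/TRT inputs were all routed down the numpy path. A quick demonstration of the missing isinstance:

    import numpy as np

    ind = object()  # stands in for a TRT ITensor: neither torch nor numpy
    print(bool((ind, np.ndarray)))      # True regardless of ind (the bug)
    print(isinstance(ind, np.ndarray))  # False, the intended check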
@@ -119,8 +126,9 @@ def index(
         set_layer_name(identity_layer, target, name + "_index_identity", source_ir)
         return identity_layer.get_output(0)
     elif len(tensor_indices) == 1:
-        # This case works
-        indices_tensor = tensor_indices[0]
+        indices_tensor = get_trt_tensor(
+            ctx, tensor_indices[0], name + f"_parameter_to_fp32_tensor"
+        )
         index = adv_indx_indices[0]
         _LOGGER.debug(f"The advanced index indices is {adv_indx_indices}")
         gather_layer = ctx.net.add_gather(input, indices_tensor, index)
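The single-advanced-index case lowers to one gather along the indexed axis, and the index is now wrapped with get_trt_tensor so Parameter/ndarray inputs become a proper ITensor first. In eager PyTorch the same selection is index_select; a rough equivalence sketch (values arbitrary):

    import torch

    x = torch.arange(12).reshape(3, 4)
    idx = torch.tensor([2, 0])
    # add_gather(input, indices, axis=0) computes the same selection:
    assert torch.equal(x[idx], torch.index_select(x, 0, idx))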
@@ -136,15 +144,15 @@ def index(
         rank = len(input_shape)
         adv_indx_count = len(adv_indx_indices)
         dim_tensor_list = []
+        dim_list = []

         for i in range(rank):
             dim = input_shape[i]
             dim_tensor = get_trt_tensor(ctx, dim, name + f"_individual_dim_{i}")
             # dim_tensor_list is a list of tensors or numpy
-            if(is_numpy):
-                dim_tensor_list.append(dim)
-            else:
-                dim_tensor_list.append(dim_tensor)
+            if is_numpy:
+                dim_list.append(dim)
+            dim_tensor_list.append(dim_tensor)

         # for cases like
         # t: [x_1, y_1, y_2, ..., x_m, ..., y_n] -> t: [x_1, x_2, ..., x_m, y_1, y_2, ..., y_n],
@@ -163,13 +171,9 @@ def index(
             new_order.append(i)
         _LOGGER.debug(f"The new transpose order is {new_order}")

-        transpose_tensor = None
-        if(is_numpy):
-            transpose_tensor = input[new_order]
-        else:
-            transpose_layer.second_transpose = tuple(new_order)
-            set_layer_name(transpose_layer, target, name + "_index_transpose", source_ir)
-            transpose_tensor = transpose_layer.get_output(0)
+        transpose_layer.second_transpose = tuple(new_order)
+        set_layer_name(transpose_layer, target, name + "_index_transpose", source_ir)
+        transpose_tensor = transpose_layer.get_output(0)

         # Flatten [x_1, x_2,.......x_m, y_1, y_2,.....y_n]
         # transpose_tensor_shape = ctx.net.add_shape(transpose_tensor)
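Dropping the numpy-only branch here also appears to remove a latent inconsistency: input[new_order] is fancy indexing along axis 0, not an axis permutation, so it never matched what second_transpose computes. A small check (shapes arbitrary):

    import numpy as np

    a = np.arange(6).reshape(2, 3)
    new_order = [1, 0]
    print(a[new_order].shape)                # (2, 3): rows reordered, not transposed
    print(np.transpose(a, new_order).shape)  # (3, 2): the actual permutation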
@@ -182,36 +186,34 @@ def index(
         for i in range(adv_indx_count, rank):
             mult_d1 = mult_d1 * transpose_tensor_shape[i]

-        flatten_tensor = None
-        if(is_numpy):
-            flatten_tensor = transpose_tensor.reshape(mult_d0, mult_d1)
-        else:
-            concat_tensor_layer = ctx.net.add_concatenation(
-                [
-                    get_trt_tensor(ctx, mult_d0, name + "_d0_shape"),
-                    get_trt_tensor(ctx, mult_d1, name + "_d1_shape"),
-                ]
-            )
-            set_layer_name(concat_tensor_layer, target, name + "_index_Concat", source_ir)
-            concat_tensor = concat_tensor_layer.get_output(0)
+        concat_tensor_layer = ctx.net.add_concatenation(
+            [
+                get_trt_tensor(ctx, mult_d0, name + "_d0_shape"),
+                get_trt_tensor(ctx, mult_d1, name + "_d1_shape"),
+            ]
+        )
+        set_layer_name(concat_tensor_layer, target, name + "_index_Concat", source_ir)
+        concat_tensor = concat_tensor_layer.get_output(0)

-            reshape_layer = ctx.net.add_shuffle(transpose_tensor)
-            reshape_layer.set_input(1, concat_tensor)
-            flatten_tensor = reshape_layer.get_output(0)
+        reshape_layer = ctx.net.add_shuffle(transpose_tensor)
+        reshape_layer.set_input(1, concat_tensor)
+        flatten_tensor = reshape_layer.get_output(0)

         _LOGGER.debug(f"The flatten tensor shape is {flatten_tensor.shape}")

         # tensor index = \sum_{i=1}^m (ind_i * \prod_{j=i+1}^m (x_j)), ind_i is input indices[i], x_j is the
         # // j dimension of input x.
-        if(is_numpy):
-            multiplier = dim_tensor_list[adv_indx_indices[adv_indx_count - 1]]
+        if is_numpy:
+            multiplier = dim_list[adv_indx_indices[adv_indx_count - 1]]
             cum_adv_index = tensor_indices[adv_indx_count - 1]
             for i in range(adv_indx_count - 2, -1, -1):
                 adv_index = multiplier * tensor_indices[i]
                 cum_adv_index = cum_adv_index + adv_index
-                multiplier = multiplier * dim_tensor_list[adv_indx_indices[i]]
+                multiplier = multiplier * dim_list[adv_indx_indices[i]]
+            cum_adv_index = get_trt_tensor(
+                ctx, cum_adv_index, name + f"_index_sum_intermediate"
+            )
         else:
-
             multiplier = get_trt_tensor(
                 ctx,
                 dim_tensor_list[adv_indx_indices[adv_indx_count - 1]],
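The loop in this hunk implements the linearization in the comment, tensor index = \sum_{i=1}^m (ind_i * \prod_{j=i+1}^m (x_j)), now over the plain-int dim_list in the numpy case, with the accumulated index promoted to an ITensor at the end. A worked numpy example of the same right-to-left accumulation (shapes and indices made up):

    import numpy as np

    x = np.arange(24).reshape(2, 3, 4)
    indices = [np.array([0, 1]), np.array([2, 0]), np.array([3, 1])]

    # cum = ind_2 + ind_1 * 4 + ind_0 * (4 * 3), built as in the diff's loop
    multiplier = x.shape[-1]
    cum_adv_index = indices[-1]
    for i in range(len(indices) - 2, -1, -1):
        cum_adv_index = cum_adv_index + multiplier * indices[i]
        multiplier = multiplier * x.shape[i]

    # The linearized index gathers the same elements as advanced indexing:
    assert np.array_equal(x.reshape(-1)[cum_adv_index], x[tuple(indices)])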
@@ -269,36 +271,29 @@ def index(
             == adv_indx_indices[adv_indx_count - 1] - adv_indx_indices[0] + 1
         ):
             _LOGGER.debug(f"The indices are continuous in this case")
-            if(is_numpy):
-                concat_tensor_reshape.append(-1)
-            else:
-                concat_tensor_reshape.append(
-                    get_trt_tensor(ctx, -1, name + "_dynamic_concat")
-                )
+            concat_tensor_reshape.append(
+                get_trt_tensor(ctx, -1, name + "_dynamic_concat")
+            )
             for i in range(0, rank):
                 if i not in adv_indx_indices:
                     curr_dim = dim_tensor_list[i]
                     concat_tensor_reshape.append(curr_dim)

-            unfold_tensor = None
-            if(is_numpy):
-                unfold_tensor = gather_out.reshape(concat_tensor)
-            else:
-                concat_tensor_layer = ctx.net.add_concatenation(concat_tensor_reshape)
-                set_layer_name(
-                    concat_tensor_layer, target, name + "_index_Concat_reshape", source_ir
-                )
-                concat_tensor = concat_tensor_layer.get_output(0)
+            concat_tensor_layer = ctx.net.add_concatenation(concat_tensor_reshape)
+            set_layer_name(
+                concat_tensor_layer, target, name + "_index_Concat_reshape", source_ir
+            )
+            concat_tensor = concat_tensor_layer.get_output(0)

-                regular_index_shuffle_layer = ctx.net.add_shuffle(gather_out)
-                regular_index_shuffle_layer.set_input(1, concat_tensor)
-                set_layer_name(
-                    regular_index_shuffle_layer,
-                    target,
-                    name + "_index_regular_index",
-                    source_ir,
-                )
-                unfold_tensor = regular_index_shuffle_layer.get_output(0)
+            regular_index_shuffle_layer = ctx.net.add_shuffle(gather_out)
+            regular_index_shuffle_layer.set_input(1, concat_tensor)
+            set_layer_name(
+                regular_index_shuffle_layer,
+                target,
+                name + "_index_regular_index",
+                source_ir,
+            )
+            unfold_tensor = regular_index_shuffle_layer.get_output(0)
             _LOGGER.debug(f"The tensor is unfolded now")
             _LOGGER.debug(f"The unfolded tensor shape is {unfold_tensor.shape}")
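In the continuous case the reshape target now always begins with a -1 wildcard built by get_trt_tensor, letting the shuffle layer infer the combined extent of the gathered axes; numpy's reshape treats -1 the same way:

    import numpy as np

    gather_out = np.arange(24).reshape(6, 4)  # 6 = 2*3 gathered positions
    # -1 lets the framework infer that dimension from the others:
    print(gather_out.reshape(-1, 4).shape)    # (6, 4)
    print(gather_out.reshape(2, 3, 4).shape)  # (2, 3, 4)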
@@ -312,18 +307,14 @@ def index(
                 new_order.append(i)
             _LOGGER.debug(f"Transposing the indices to correct position {new_order}")

-            transpose_tensor = None
-            if(is_numpy):
-                transpose_tensor = unfold_tensor[new_order]
-            else:
-                transpose_advanced_shuffle_layer.second_transpose = tuple(new_order)
-                set_layer_name(
-                    transpose_advanced_shuffle_layer,
-                    target,
-                    name + "_index_advanced_shuffle_transpose",
-                    source_ir,
-                )
-                transpose_tensor = transpose_advanced_shuffle_layer.get_output(0)
+            transpose_advanced_shuffle_layer.second_transpose = tuple(new_order)
+            set_layer_name(
+                transpose_advanced_shuffle_layer,
+                target,
+                name + "_index_advanced_shuffle_transpose",
+                source_ir,
+            )
+            transpose_tensor = transpose_advanced_shuffle_layer.get_output(0)

             # unfold advanced layer
             concat_final_tensor = []
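second_transpose applies a static axis permutation after the reshape; in eager terms it is a permute. For example:

    import torch

    t = torch.arange(24).reshape(2, 3, 4)
    new_order = (1, 0, 2)
    # equivalent of shuffle_layer.second_transpose = tuple(new_order)
    print(t.permute(new_order).shape)  # torch.Size([3, 2, 4])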
@@ -337,29 +328,25 @@ def index(
                     current_dim = dim_tensor_list[i]
                     concat_final_tensor.append(current_dim)

-            reshape_output = []
-            if(is_numpy):
-                reshape_output = transpose_tensor.reshape(concat_final_tensor)
-            else:
-                concat_final_shape_layer = ctx.net.add_concatenation(concat_final_tensor)
-                set_layer_name(
-                    concat_final_shape_layer,
-                    target,
-                    name + "_index_continuous_concat_final_shape_layer",
-                    source_ir,
-                )
-                concat_final_tensor = concat_final_shape_layer.get_output(0)
-
-                unfold_advanced_shuffle_layer = ctx.net.add_shuffle(transpose_tensor)
-                # check this
-                unfold_advanced_shuffle_layer.set_input(1, concat_final_tensor)
-                set_layer_name(
-                    unfold_advanced_shuffle_layer,
-                    target,
-                    name + "_unfold_advanced_index",
-                    source_ir,
-                )
-                reshape_output = unfold_advanced_shuffle_layer.get_output(0)
+            concat_final_shape_layer = ctx.net.add_concatenation(concat_final_tensor)
+            set_layer_name(
+                concat_final_shape_layer,
+                target,
+                name + "_index_continuous_concat_final_shape_layer",
+                source_ir,
+            )
+            concat_final_tensor = concat_final_shape_layer.get_output(0)
+
+            unfold_advanced_shuffle_layer = ctx.net.add_shuffle(transpose_tensor)
+            # check this
+            unfold_advanced_shuffle_layer.set_input(1, concat_final_tensor)
+            set_layer_name(
+                unfold_advanced_shuffle_layer,
+                target,
+                name + "_unfold_advanced_index",
+                source_ir,
+            )
+            reshape_output = unfold_advanced_shuffle_layer.get_output(0)

         else:
             _LOGGER.debug(f"The indices are not continuous in this case")
@@ -370,27 +357,23 @@ def index(
                     curr_dim = dim_tensor_list[i]
                     concat_final_tensor.append(curr_dim)

-            reshape_output = None
-            if(is_numpy):
-                reshape_output = gather_out.reshape(concat_final_tensor)
-            else:
-                concat_final_shape_layer = ctx.net.add_concatenation(concat_final_tensor)
-                set_layer_name(
-                    concat_final_shape_layer,
-                    target,
-                    name + "_index_non_continuous_concat_final_shape_layer",
-                    source_ir,
-                )
-                concat_final_tensor = concat_final_shape_layer.get_output(0)
-
-                reshape_layer = ctx.net.add_shuffle(gather_out)
-                reshape_layer.set_input(1, concat_final_tensor)
-                set_layer_name(
-                    reshape_layer,
-                    target,
-                    name + "_index_non_continuous_shuffle_final_shape_layer",
-                    source_ir,
-                )
-                reshape_output = reshape_layer.get_output(0)
+            concat_final_shape_layer = ctx.net.add_concatenation(concat_final_tensor)
+            set_layer_name(
+                concat_final_shape_layer,
+                target,
+                name + "_index_non_continuous_concat_final_shape_layer",
+                source_ir,
+            )
+            concat_final_tensor = concat_final_shape_layer.get_output(0)
+
+            reshape_layer = ctx.net.add_shuffle(gather_out)
+            reshape_layer.set_input(1, concat_final_tensor)
+            set_layer_name(
+                reshape_layer,
+                target,
+                name + "_index_non_continuous_shuffle_final_shape_layer",
+                source_ir,
+            )
+            reshape_output = reshape_layer.get_output(0)

     return reshape_output
