
Commit 0fc9c75

Correction for the is_numpy cases for a mix of numpy and non-numpy inputs
1 parent b1534f8 commit 0fc9c75

File tree

1 file changed: +108 -118 lines

  • py/torch_tensorrt/dynamo/conversion/impl

py/torch_tensorrt/dynamo/conversion/impl/select.py

Lines changed: 108 additions & 118 deletions
@@ -81,16 +81,30 @@ def index(
     source_ir: Optional[SourceIR],
     name: str,
     input: TRTTensor,
-    index: Union[TRTTensor, Sequence[TRTTensor]],
+    index: Union[
+        TRTTensor,
+        Sequence[TRTTensor],
+        np.ndarray,
+        Sequence[np.ndarray],
+        torch.Tensor,
+        Sequence[torch.Tensor],
+    ],
 ) -> TRTTensor:
     adv_indx_indices = []
     tensor_indices = []
     # _LOGGER.debug(f"The index shape is {index.shape}")
     # check if the input is dynamic
     dynamic_shape = has_dynamic_shape(input.shape)
-    #is_numpy is a flag to specify if input isa numpy
-    is_numpy = False
-
+    # is_numpy is a flag to specify if all the indices are numpy or torch.Tensor.
+    # If any is not, this flag will be set to False
+    is_numpy = True
+    _LOGGER.debug(f"Checking for the is_numpy flag")
+    for i, ind in enumerate(index):
+        if ind is None:
+            continue
+        if not (isinstance(ind, torch.Tensor) or isinstance(ind, np.ndarray)):
+            is_numpy = False
+            break
     # here we need to check if all the index are broadcastable
     # if no, then we need to broadcast
     last_index = None
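
The core behavioral change is in this first hunk: `is_numpy` used to start as False and be decided per index inside the conversion loop; it is now computed once, up front, over the whole index list, and a single entry that is neither a torch.Tensor nor an np.ndarray disables the numpy path for every index. A minimal standalone sketch of that scan, with an illustrative name (`all_static_indices` is not part of the module):

import numpy as np
import torch

def all_static_indices(indices) -> bool:
    # Mirrors the new is_numpy scan: None entries (dims without advanced
    # indexing) are skipped; anything that is not a torch.Tensor or an
    # np.ndarray (e.g. a TRTTensor) forces the TensorRT layer path.
    for ind in indices:
        if ind is None:
            continue
        if not isinstance(ind, (torch.Tensor, np.ndarray)):
            return False
    return True

assert all_static_indices([torch.tensor([0, 1]), None, np.array([2])])
assert not all_static_indices([torch.tensor([0, 1]), "not-a-static-index"])
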
@@ -101,7 +115,7 @@ def index(
             # torch.nn.parameter.Parameter=> numpy array
             # numpy array is kept as numpy
             # other cases are kept as TRTTensor
-            if (isinstance(ind, torch.Tensor) or (ind, np.ndarray)):
+            if is_numpy:
                 ind = to_numpy(ind)
                 is_numpy = True
             else:
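
The removed condition was the actual bug behind mixed-input failures: `(ind, np.ndarray)` is a tuple literal, not an `isinstance` call, and a non-empty tuple is always truthy, so every index, TRTTensor included, was routed through `to_numpy`. A quick demonstration of why the old test always passed:

import numpy as np

ind = object()  # stand-in for any index type, including a TRTTensor
# Old test: isinstance(ind, torch.Tensor) or (ind, np.ndarray)
# The right-hand operand is a 2-tuple, and non-empty tuples are truthy,
# so the whole expression is True regardless of ind's actual type.
assert bool((ind, np.ndarray)) is True
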
@@ -119,8 +133,9 @@ def index(
         set_layer_name(identity_layer, target, name + "_index_identity", source_ir)
         return identity_layer.get_output(0)
     elif len(tensor_indices) == 1:
-        # This case works
-        indices_tensor = tensor_indices[0]
+        indices_tensor = get_trt_tensor(
+            ctx, tensor_indices[0], name + f"_parameter_to_fp32_tensor"
+        )
         index = adv_indx_indices[0]
         _LOGGER.debug(f"The advanced index indices is {adv_indx_indices}")
         gather_layer = ctx.net.add_gather(input, indices_tensor, index)
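
Wrapping `tensor_indices[0]` in `get_trt_tensor` guarantees the single-index fast path hands `add_gather` an ITensor even when the index arrived as a numpy array or torch.Tensor. The gather itself selects slices along one axis, analogous to `numpy.take`; a rough equivalent with made-up shapes:

import numpy as np

x = np.arange(24).reshape(2, 4, 3)
idx = np.array([3, 0])
# Roughly what ctx.net.add_gather(input, indices_tensor, index=1) computes:
out = np.take(x, idx, axis=1)
assert out.shape == (2, 2, 3)
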
@@ -136,15 +151,15 @@ def index(
         rank = len(input_shape)
         adv_indx_count = len(adv_indx_indices)
         dim_tensor_list = []
+        dim_list = []

         for i in range(rank):
             dim = input_shape[i]
             dim_tensor = get_trt_tensor(ctx, dim, name + f"_individual_dim_{i}")
             # dim_tensor_list is a list of tensors or numpy
-            if(is_numpy):
-                dim_tensor_list.append(dim)
-            else:
-                dim_tensor_list.append(dim_tensor)
+            if is_numpy:
+                dim_list.append(dim)
+            dim_tensor_list.append(dim_tensor)

         # for cases like
         # t: [x_1, y_1, y_2, ..., x_m, ..., y_n] -> t: [x_1, x_2, ..., x_m, y_1, y_2, ..., y_n],
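
Note the split bookkeeping introduced here: since the numpy-only shortcuts are deleted in the hunks below, the shuffle/concatenation layers now always need ITensor dims, so `dim_tensor_list` is populated unconditionally, while `dim_list` keeps plain Python ints for the numpy fast path's index arithmetic. A sketch of why the int copies matter (values are made up):

import numpy as np

# On the numpy fast path the dims must stay plain Python ints, so that the
# multiplier/cum_adv_index arithmetic below remains ordinary numpy math:
dim_list = [2, 3, 4]               # illustrative shape values
tensor_index = np.array([1, 0, 2])
adv = dim_list[1] * tensor_index   # int * ndarray: fine
assert (adv == np.array([3, 0, 6])).all()
# An ITensor from dim_tensor_list supports no such arithmetic, which is why
# the commit keeps dim_list and dim_tensor_list separately.
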
@@ -163,13 +178,9 @@ def index(
             new_order.append(i)
         _LOGGER.debug(f"The new transpose order is {new_order}")

-        transpose_tensor = None
-        if(is_numpy):
-            transpose_tensor = input[new_order]
-        else:
-            transpose_layer.second_transpose = tuple(new_order)
-            set_layer_name(transpose_layer, target, name + "_index_transpose", source_ir)
-            transpose_tensor = transpose_layer.get_output(0)
+        transpose_layer.second_transpose = tuple(new_order)
+        set_layer_name(transpose_layer, target, name + "_index_transpose", source_ir)
+        transpose_tensor = transpose_layer.get_output(0)

         # Flatten [x_1, x_2,.......x_m, y_1, y_2,.....y_n]
         # transpose_tensor_shape = ctx.net.add_shape(transpose_tensor)
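
This transpose gathers every advanced-index dimension at the front of the tensor, which is what makes the upcoming [mult_d0, mult_d1] flattening legal; with the numpy shortcut removed, it always goes through the IShuffleLayer. The reordering itself is a plain permutation, e.g.:

import numpy as np

x = np.zeros((2, 3, 4, 5))
adv_indx_indices = [1, 3]  # dims carrying advanced indices (illustrative)
new_order = adv_indx_indices + [
    i for i in range(x.ndim) if i not in adv_indx_indices
]
assert new_order == [1, 3, 0, 2]
assert x.transpose(new_order).shape == (3, 5, 2, 4)
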
@@ -182,36 +193,34 @@ def index(
         for i in range(adv_indx_count, rank):
             mult_d1 = mult_d1 * transpose_tensor_shape[i]

-        flatten_tensor = None
-        if(is_numpy):
-            flatten_tensor = transpose_tensor.reshape(mult_d0, mult_d1)
-        else:
-            concat_tensor_layer = ctx.net.add_concatenation(
-                [
-                    get_trt_tensor(ctx, mult_d0, name + "_d0_shape"),
-                    get_trt_tensor(ctx, mult_d1, name + "_d1_shape"),
-                ]
-            )
-            set_layer_name(concat_tensor_layer, target, name + "_index_Concat", source_ir)
-            concat_tensor = concat_tensor_layer.get_output(0)
+        concat_tensor_layer = ctx.net.add_concatenation(
+            [
+                get_trt_tensor(ctx, mult_d0, name + "_d0_shape"),
+                get_trt_tensor(ctx, mult_d1, name + "_d1_shape"),
+            ]
+        )
+        set_layer_name(concat_tensor_layer, target, name + "_index_Concat", source_ir)
+        concat_tensor = concat_tensor_layer.get_output(0)

-            reshape_layer = ctx.net.add_shuffle(transpose_tensor)
-            reshape_layer.set_input(1, concat_tensor)
-            flatten_tensor = reshape_layer.get_output(0)
+        reshape_layer = ctx.net.add_shuffle(transpose_tensor)
+        reshape_layer.set_input(1, concat_tensor)
+        flatten_tensor = reshape_layer.get_output(0)

         _LOGGER.debug(f"The flatten tensor shape is {flatten_tensor.shape}")

         # tensor index = \sum_{i=1}^m (ind_i * \prod_{j=i+1}^m (x_j)), ind_i is input indices[i], x_j is the
         # // j dimension of input x.
-        if(is_numpy):
-            multiplier = dim_tensor_list[adv_indx_indices[adv_indx_count - 1]]
+        if is_numpy:
+            multiplier = dim_list[adv_indx_indices[adv_indx_count - 1]]
             cum_adv_index = tensor_indices[adv_indx_count - 1]
             for i in range(adv_indx_count - 2, -1, -1):
                 adv_index = multiplier * tensor_indices[i]
                 cum_adv_index = cum_adv_index + adv_index
-                multiplier = multiplier * dim_tensor_list[adv_indx_indices[i]]
+                multiplier = multiplier * dim_list[adv_indx_indices[i]]
+            cum_adv_index = get_trt_tensor(
+                ctx, cum_adv_index, name + f"_index_sum_intermediate"
+            )
         else:
-
             multiplier = get_trt_tensor(
                 ctx,
                 dim_tensor_list[adv_indx_indices[adv_indx_count - 1]],
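
The formula in the comment is row-major linearization over the advanced dimensions: once those dims sit in front and everything is flattened to [mult_d0, mult_d1], element (ind_1, ..., ind_m) lives at row \sum_i ind_i * \prod_{j>i} x_j, which is exactly what the multiplier/cum_adv_index loop accumulates (and what the new `get_trt_tensor` wrap converts to an ITensor at the end of the numpy path). A small numpy check of the identity, with made-up shapes:

import numpy as np

x = np.arange(2 * 3 * 4).reshape(2, 3, 4)     # advanced indices on dims 0, 1
i0, i1 = np.array([0, 1, 1]), np.array([2, 0, 1])

multiplier = x.shape[1]                       # dim of the last advanced index
cum_adv_index = i1 + multiplier * i0          # sum_i ind_i * prod_{j>i} x_j

flat = x.reshape(2 * 3, 4)                    # the [mult_d0, mult_d1] flatten
assert (flat[cum_adv_index] == x[i0, i1]).all()
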
@@ -269,36 +278,29 @@ def index(
             == adv_indx_indices[adv_indx_count - 1] - adv_indx_indices[0] + 1
         ):
             _LOGGER.debug(f"The indices are continuous in this case")
-            if(is_numpy):
-                concat_tensor_reshape.append(-1)
-            else:
-                concat_tensor_reshape.append(
-                    get_trt_tensor(ctx, -1, name + "_dynamic_concat")
-                )
+            concat_tensor_reshape.append(
+                get_trt_tensor(ctx, -1, name + "_dynamic_concat")
+            )
             for i in range(0, rank):
                 if i not in adv_indx_indices:
                     curr_dim = dim_tensor_list[i]
                     concat_tensor_reshape.append(curr_dim)

-            unfold_tensor = None
-            if(is_numpy):
-                unfold_tensor = gather_out.reshape(concat_tensor)
-            else:
-                concat_tensor_layer = ctx.net.add_concatenation(concat_tensor_reshape)
-                set_layer_name(
-                    concat_tensor_layer, target, name + "_index_Concat_reshape", source_ir
-                )
-                concat_tensor = concat_tensor_layer.get_output(0)
+            concat_tensor_layer = ctx.net.add_concatenation(concat_tensor_reshape)
+            set_layer_name(
+                concat_tensor_layer, target, name + "_index_Concat_reshape", source_ir
+            )
+            concat_tensor = concat_tensor_layer.get_output(0)

-                regular_index_shuffle_layer = ctx.net.add_shuffle(gather_out)
-                regular_index_shuffle_layer.set_input(1, concat_tensor)
-                set_layer_name(
-                    regular_index_shuffle_layer,
-                    target,
-                    name + "_index_regular_index",
-                    source_ir,
-                )
-                unfold_tensor = regular_index_shuffle_layer.get_output(0)
+            regular_index_shuffle_layer = ctx.net.add_shuffle(gather_out)
+            regular_index_shuffle_layer.set_input(1, concat_tensor)
+            set_layer_name(
+                regular_index_shuffle_layer,
+                target,
+                name + "_index_regular_index",
+                source_ir,
+            )
+            unfold_tensor = regular_index_shuffle_layer.get_output(0)
             _LOGGER.debug(f"The tensor is unfolded now")
             _LOGGER.debug(f"The unfolded tensor shape is {unfold_tensor.shape}")
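
The continuity check above exists because advanced indexing places the gathered dimension differently in the two cases: when the advanced dims are contiguous it stays in position, and when they are not it moves to the front, so each case needs its own reshape/transpose sequence. Numpy and PyTorch follow the same placement rule:

import numpy as np

x = np.zeros((2, 3, 4, 5))
i = np.array([0, 1])
# Contiguous advanced dims (1, 2): gathered dim stays in position.
assert x[:, i, i, :].shape == (2, 2, 5)
# Non-contiguous advanced dims (0, 2): gathered dim moves to the front.
assert x[i, :, i, :].shape == (2, 3, 5)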

@@ -312,18 +314,14 @@ def index(
                 new_order.append(i)
             _LOGGER.debug(f"Transposing the indices to correct position {new_order}")

-            transpose_tensor = None
-            if(is_numpy):
-                transpose_tensor = unfold_tensor[new_order]
-            else:
-                transpose_advanced_shuffle_layer.second_transpose = tuple(new_order)
-                set_layer_name(
-                    transpose_advanced_shuffle_layer,
-                    target,
-                    name + "_index_advanced_shuffle_transpose",
-                    source_ir,
-                )
-                transpose_tensor = transpose_advanced_shuffle_layer.get_output(0)
+            transpose_advanced_shuffle_layer.second_transpose = tuple(new_order)
+            set_layer_name(
+                transpose_advanced_shuffle_layer,
+                target,
+                name + "_index_advanced_shuffle_transpose",
+                source_ir,
+            )
+            transpose_tensor = transpose_advanced_shuffle_layer.get_output(0)

             # unfold advanced layer
             concat_final_tensor = []
@@ -337,29 +335,25 @@ def index(
                     current_dim = dim_tensor_list[i]
                     concat_final_tensor.append(current_dim)

-            reshape_output = []
-            if(is_numpy):
-                reshape_output = transpose_tensor.reshape(concat_final_tensor)
-            else:
-                concat_final_shape_layer = ctx.net.add_concatenation(concat_final_tensor)
-                set_layer_name(
-                    concat_final_shape_layer,
-                    target,
-                    name + "_index_continuous_concat_final_shape_layer",
-                    source_ir,
-                )
-                concat_final_tensor = concat_final_shape_layer.get_output(0)
-
-                unfold_advanced_shuffle_layer = ctx.net.add_shuffle(transpose_tensor)
-                # check this
-                unfold_advanced_shuffle_layer.set_input(1, concat_final_tensor)
-                set_layer_name(
-                    unfold_advanced_shuffle_layer,
-                    target,
-                    name + "_unfold_advanced_index",
-                    source_ir,
-                )
-                reshape_output = unfold_advanced_shuffle_layer.get_output(0)
+            concat_final_shape_layer = ctx.net.add_concatenation(concat_final_tensor)
+            set_layer_name(
+                concat_final_shape_layer,
+                target,
+                name + "_index_continuous_concat_final_shape_layer",
+                source_ir,
+            )
+            concat_final_tensor = concat_final_shape_layer.get_output(0)
+
+            unfold_advanced_shuffle_layer = ctx.net.add_shuffle(transpose_tensor)
+            # check this
+            unfold_advanced_shuffle_layer.set_input(1, concat_final_tensor)
+            set_layer_name(
+                unfold_advanced_shuffle_layer,
+                target,
+                name + "_unfold_advanced_index",
+                source_ir,
+            )
+            reshape_output = unfold_advanced_shuffle_layer.get_output(0)

         else:
             _LOGGER.debug(f"The indices are not continuous in this case")
@@ -370,27 +364,23 @@ def index(
                     curr_dim = dim_tensor_list[i]
                     concat_final_tensor.append(curr_dim)

-            reshape_output = None
-            if(is_numpy):
-                reshape_output = gather_out.reshape(concat_final_tensor)
-            else:
-                concat_final_shape_layer = ctx.net.add_concatenation(concat_final_tensor)
-                set_layer_name(
-                    concat_final_shape_layer,
-                    target,
-                    name + "_index_non_continuous_concat_final_shape_layer",
-                    source_ir,
-                )
-                concat_final_tensor = concat_final_shape_layer.get_output(0)
-
-                reshape_layer = ctx.net.add_shuffle(gather_out)
-                reshape_layer.set_input(1, concat_final_tensor)
-                set_layer_name(
-                    reshape_layer,
-                    target,
-                    name + "_index_non_continuous_shuffle_final_shape_layer",
-                    source_ir,
-                )
-                reshape_output = reshape_layer.get_output(0)
+            concat_final_shape_layer = ctx.net.add_concatenation(concat_final_tensor)
+            set_layer_name(
+                concat_final_shape_layer,
+                target,
+                name + "_index_non_continuous_concat_final_shape_layer",
+                source_ir,
+            )
+            concat_final_tensor = concat_final_shape_layer.get_output(0)
+
+            reshape_layer = ctx.net.add_shuffle(gather_out)
+            reshape_layer.set_input(1, concat_final_tensor)
+            set_layer_name(
+                reshape_layer,
+                target,
+                name + "_index_non_continuous_shuffle_final_shape_layer",
+                source_ir,
+            )
+            reshape_output = reshape_layer.get_output(0)

         return reshape_output
