Commit b1534f8

Numpy changes for index
1 parent 4e5b0f6 commit b1534f8

2 files changed (+163, -115 lines)

2 files changed

+163
-115
lines changed

py/torch_tensorrt/dynamo/conversion/converter_utils.py

Lines changed: 1 addition & 2 deletions
@@ -180,8 +180,7 @@ def cast_int_int_div_trt_tensor(
 
 
 def broadcastable(
-    a: TRTTensor,
-    b: TRTTensor,
+    a: Union[TRTTensor, np.ndarray], b: Union[TRTTensor, np.ndarray]
 ) -> bool:
     "Check if two tensors are broadcastable according to torch rules"
     a_shape = tuple(a.shape)
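
Note on the change above: broadcastable now accepts numpy arrays as well as TRT tensors. As a reference for the rule it checks, here is a minimal standalone sketch of torch-style broadcast compatibility (an illustration, not the library implementation):

import numpy as np

def broadcastable_sketch(a, b) -> bool:
    # Compare trailing dimensions; each pair must match or contain a 1.
    for da, db in zip(reversed(tuple(a.shape)), reversed(tuple(b.shape))):
        if da != db and da != 1 and db != 1:
            return False
    return True

print(broadcastable_sketch(np.zeros((4, 1, 3)), np.zeros((5, 3))))  # True
print(broadcastable_sketch(np.zeros((4, 2, 3)), np.zeros((5, 3))))  # False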

py/torch_tensorrt/dynamo/conversion/impl/select.py

Lines changed: 162 additions & 113 deletions
@@ -3,6 +3,7 @@
 
 import numpy as np
 import tensorrt as trt
+import torch
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
@@ -87,6 +88,8 @@ def index(
     # _LOGGER.debug(f"The index shape is {index.shape}")
     # check if the input is dynamic
     dynamic_shape = has_dynamic_shape(input.shape)
+    # is_numpy is a flag that marks whether the indices are handled as numpy arrays
+    is_numpy = False
 
     # here we need to check if all the index are broadcastable
     # if no, then we need to broadcast
@@ -95,8 +98,14 @@ def index(
         if ind is not None:
             _LOGGER.debug(f"Shape of {i} index is {ind.shape}")
             adv_indx_indices.append(i)
-            # torch.nn.parameter.Parameter=> torch.Tensor
-            ind = get_trt_tensor(ctx, ind, name + f"_parameter_to_fp32_tensor_{i}")
+            # torch.nn.parameter.Parameter / torch.Tensor => numpy array
+            # numpy arrays are kept as numpy
+            # other cases are converted to TRTTensor
+            if isinstance(ind, (torch.Tensor, np.ndarray)):
+                ind = to_numpy(ind)
+                is_numpy = True
+            else:
+                ind = get_trt_tensor(ctx, ind, name + f"_parameter_to_fp32_tensor_{i}")
             if last_index is not None:
                 assert broadcastable(
                     ind, last_index
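
A minimal sketch of the dispatch this hunk introduces, with made-up index values; the real converter routes the non-numpy case through get_trt_tensor and uses the library's to_numpy helper:

import numpy as np
import torch

def classify_index(ind):
    # Torch tensors and numpy arrays stay on the numpy fast path;
    # anything else would be materialized as a TRTTensor (not shown here).
    if isinstance(ind, (torch.Tensor, np.ndarray)):
        ind = ind.cpu().numpy() if isinstance(ind, torch.Tensor) else ind
        return ind, True
    return ind, False

ind, is_numpy = classify_index(torch.tensor([0, 2, 1]))
print(type(ind).__name__, is_numpy)  # ndarray True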
@@ -131,8 +140,11 @@ def index(
     for i in range(rank):
         dim = input_shape[i]
         dim_tensor = get_trt_tensor(ctx, dim, name + f"_individual_dim_{i}")
-        # dim_tensor_list is a list of tensors
-        dim_tensor_list.append(dim_tensor)
+        # dim_tensor_list holds TRTTensors, or plain dims on the numpy path
+        if is_numpy:
+            dim_tensor_list.append(dim)
+        else:
+            dim_tensor_list.append(dim_tensor)
 
     # for cases like
     # t: [x_1, y_1, y_2, ..., x_m, ..., y_n] -> t: [x_1, x_2, ..., x_m, y_1, y_2, ..., y_n],
@@ -150,9 +162,14 @@ def index(
         if i not in adv_indx_indices:
             new_order.append(i)
     _LOGGER.debug(f"The new transpose order is {new_order}")
-    transpose_layer.second_transpose = tuple(new_order)
-    set_layer_name(transpose_layer, target, name + "_index_transpose", source_ir)
-    transpose_tensor = transpose_layer.get_output(0)
+
+    transpose_tensor = None
+    if is_numpy:
+        transpose_tensor = input[new_order]
+    else:
+        transpose_layer.second_transpose = tuple(new_order)
+        set_layer_name(transpose_layer, target, name + "_index_transpose", source_ir)
+        transpose_tensor = transpose_layer.get_output(0)
 
     # Flatten [x_1, x_2,.......x_m, y_1, y_2,.....y_n]
     # transpose_tensor_shape = ctx.net.add_shape(transpose_tensor)
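
For concreteness, a small sketch of how new_order is built; the rank and index positions below are assumed for illustration, not taken from the source:

# Suppose the input has rank 4 and advanced indices were supplied for dims 1 and 3.
rank = 4
adv_indx_indices = [1, 3]

new_order = list(adv_indx_indices)      # advanced dims come first: [1, 3]
for i in range(rank):
    if i not in adv_indx_indices:
        new_order.append(i)             # remaining dims keep their relative order
print(new_order)                        # [1, 3, 0, 2]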
@@ -165,57 +182,70 @@ def index(
     for i in range(adv_indx_count, rank):
         mult_d1 = mult_d1 * transpose_tensor_shape[i]
 
-    concat_tensor_layer = ctx.net.add_concatenation(
-        [
-            get_trt_tensor(ctx, mult_d0, name + "_d0_shape"),
-            get_trt_tensor(ctx, mult_d1, name + "_d1_shape"),
-        ]
-    )
-    set_layer_name(concat_tensor_layer, target, name + "_index_Concat", source_ir)
-    concat_tensor = concat_tensor_layer.get_output(0)
-
-    reshape_layer = ctx.net.add_shuffle(transpose_tensor)
-    # check this
-    reshape_layer.set_input(1, concat_tensor)
-    flatten_tensor = reshape_layer.get_output(0)
+    flatten_tensor = None
+    if is_numpy:
+        flatten_tensor = transpose_tensor.reshape(mult_d0, mult_d1)
+    else:
+        concat_tensor_layer = ctx.net.add_concatenation(
+            [
+                get_trt_tensor(ctx, mult_d0, name + "_d0_shape"),
+                get_trt_tensor(ctx, mult_d1, name + "_d1_shape"),
+            ]
+        )
+        set_layer_name(concat_tensor_layer, target, name + "_index_Concat", source_ir)
+        concat_tensor = concat_tensor_layer.get_output(0)
+
+        reshape_layer = ctx.net.add_shuffle(transpose_tensor)
+        reshape_layer.set_input(1, concat_tensor)
+        flatten_tensor = reshape_layer.get_output(0)
+
     _LOGGER.debug(f"The flatten tensor shape is {flatten_tensor.shape}")
 
     # tensor index = \sum_{i=1}^m (ind_i * \prod_{j=i+1}^m (x_j)), ind_i is input indices[i], x_j is the
     # // j dimension of input x.
-    multiplier = get_trt_tensor(
-        ctx,
-        dim_tensor_list[adv_indx_indices[adv_indx_count - 1]],
-        name + "_dim_last",
-    )
-    cum_adv_index = tensor_indices[adv_indx_count - 1]
-    for i in range(adv_indx_count - 2, -1, -1):
-        adv_index = convert_binary_elementwise(
-            ctx,
-            target,
-            source_ir,
-            name + f"_index_intermediate_{i}",
-            trt.ElementWiseOperation.PROD,
-            multiplier,
-            tensor_indices[i],
-        )
-        cum_adv_index = convert_binary_elementwise(
-            ctx,
-            target,
-            source_ir,
-            name + f"_index_sum_intermediate_{i}",
-            trt.ElementWiseOperation.SUM,
-            cum_adv_index,
-            adv_index,
-        )
-        multiplier = convert_binary_elementwise(
-            ctx,
-            target,
-            source_ir,
-            name + f"_index_intermediate_xj_{i}",
-            trt.ElementWiseOperation.PROD,
-            multiplier,
-            dim_tensor_list[adv_indx_indices[i]],
-        )
+    if is_numpy:
+        multiplier = dim_tensor_list[adv_indx_indices[adv_indx_count - 1]]
+        cum_adv_index = tensor_indices[adv_indx_count - 1]
+        for i in range(adv_indx_count - 2, -1, -1):
+            adv_index = multiplier * tensor_indices[i]
+            cum_adv_index = cum_adv_index + adv_index
+            multiplier = multiplier * dim_tensor_list[adv_indx_indices[i]]
+    else:
+
+        multiplier = get_trt_tensor(
+            ctx,
+            dim_tensor_list[adv_indx_indices[adv_indx_count - 1]],
+            name + "_dim_last",
+        )
+        cum_adv_index = tensor_indices[adv_indx_count - 1]
+        for i in range(adv_indx_count - 2, -1, -1):
+            adv_index = convert_binary_elementwise(
+                ctx,
+                target,
+                source_ir,
+                name + f"_index_intermediate_{i}",
+                trt.ElementWiseOperation.PROD,
+                multiplier,
+                tensor_indices[i],
+            )
+            cum_adv_index = convert_binary_elementwise(
+                ctx,
+                target,
+                source_ir,
+                name + f"_index_sum_intermediate_{i}",
+                trt.ElementWiseOperation.SUM,
+                cum_adv_index,
+                adv_index,
+            )
+            multiplier = convert_binary_elementwise(
+                ctx,
+                target,
+                source_ir,
+                name + f"_index_intermediate_xj_{i}",
+                trt.ElementWiseOperation.PROD,
+                multiplier,
+                dim_tensor_list[adv_indx_indices[i]],
+            )
 
     gather_layer_element = ctx.net.add_gather(flatten_tensor, cum_adv_index, 0)
     set_layer_name(
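
The multiplier/cum_adv_index loop implements the comment's formula, tensor index = sum_i (ind_i * prod_{j>i} x_j), i.e. a row-major flattened offset over the advanced dimensions. A small numpy sketch with assumed shapes (not from the source) shows the idea:

import numpy as np

# Assumed example: a (4, 5) input with advanced indices on both dims (adv_indx_count = 2).
x = np.arange(20).reshape(4, 5)
rows = np.array([0, 2, 3])
cols = np.array([1, 1, 4])

# Row-major flattened offset: rows * dim_1 + cols, the same recurrence the
# numpy branch above unrolls with multiplier and cum_adv_index.
cum_adv_index = rows * x.shape[1] + cols
flat = x.reshape(-1)
print(flat[cum_adv_index])   # [ 1 11 19]
print(x[rows, cols])         # identical result via plain advanced indexing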
@@ -239,29 +269,36 @@ def index(
         == adv_indx_indices[adv_indx_count - 1] - adv_indx_indices[0] + 1
     ):
         _LOGGER.debug(f"The indices are continuous in this case")
-        concat_tensor_reshape.append(
-            get_trt_tensor(ctx, -1, name + "_dynamic_concat")
-        )
+        if is_numpy:
+            concat_tensor_reshape.append(-1)
+        else:
+            concat_tensor_reshape.append(
+                get_trt_tensor(ctx, -1, name + "_dynamic_concat")
+            )
         for i in range(0, rank):
             if i not in adv_indx_indices:
                 curr_dim = dim_tensor_list[i]
                 concat_tensor_reshape.append(curr_dim)
 
-        concat_tensor_layer = ctx.net.add_concatenation(concat_tensor_reshape)
-        set_layer_name(
-            concat_tensor_layer, target, name + "_index_Concat_reshape", source_ir
-        )
-        concat_tensor = concat_tensor_layer.get_output(0)
+        unfold_tensor = None
+        if is_numpy:
+            unfold_tensor = gather_out.reshape(concat_tensor_reshape)
+        else:
+            concat_tensor_layer = ctx.net.add_concatenation(concat_tensor_reshape)
+            set_layer_name(
+                concat_tensor_layer, target, name + "_index_Concat_reshape", source_ir
+            )
+            concat_tensor = concat_tensor_layer.get_output(0)
 
-        regular_index_shuffle_layer = ctx.net.add_shuffle(gather_out)
-        regular_index_shuffle_layer.set_input(1, concat_tensor)
-        set_layer_name(
-            regular_index_shuffle_layer,
-            target,
-            name + "_index_regular_index",
-            source_ir,
-        )
-        unfold_tensor = regular_index_shuffle_layer.get_output(0)
+            regular_index_shuffle_layer = ctx.net.add_shuffle(gather_out)
+            regular_index_shuffle_layer.set_input(1, concat_tensor)
+            set_layer_name(
+                regular_index_shuffle_layer,
+                target,
+                name + "_index_regular_index",
+                source_ir,
+            )
+            unfold_tensor = regular_index_shuffle_layer.get_output(0)
         _LOGGER.debug(f"The tensor is unfolded now")
         _LOGGER.debug(f"The unfolded tensor shape is {unfold_tensor.shape}")
 
@@ -275,14 +312,18 @@ def index(
             new_order.append(i)
         _LOGGER.debug(f"Transposing the indices to correct position {new_order}")
 
-        transpose_advanced_shuffle_layer.second_transpose = tuple(new_order)
-        set_layer_name(
-            transpose_advanced_shuffle_layer,
-            target,
-            name + "_index_advanced_shuffle_transpose",
-            source_ir,
-        )
-        transpose_tensor = transpose_advanced_shuffle_layer.get_output(0)
+        transpose_tensor = None
+        if is_numpy:
+            transpose_tensor = unfold_tensor[new_order]
+        else:
+            transpose_advanced_shuffle_layer.second_transpose = tuple(new_order)
+            set_layer_name(
+                transpose_advanced_shuffle_layer,
+                target,
+                name + "_index_advanced_shuffle_transpose",
+                source_ir,
+            )
+            transpose_tensor = transpose_advanced_shuffle_layer.get_output(0)
 
         # unfold advanced layer
         concat_final_tensor = []
@@ -296,25 +337,29 @@ def index(
             current_dim = dim_tensor_list[i]
             concat_final_tensor.append(current_dim)
 
-        concat_final_shape_layer = ctx.net.add_concatenation(concat_final_tensor)
-        set_layer_name(
-            concat_final_shape_layer,
-            target,
-            name + "_index_continuous_concat_final_shape_layer",
-            source_ir,
-        )
-        concat_final_tensor = concat_final_shape_layer.get_output(0)
-
-        unfold_advanced_shuffle_layer = ctx.net.add_shuffle(transpose_tensor)
-        # check this
-        unfold_advanced_shuffle_layer.set_input(1, concat_final_tensor)
-        set_layer_name(
-            unfold_advanced_shuffle_layer,
-            target,
-            name + "_unfold_advanced_index",
-            source_ir,
-        )
-        reshape_output = unfold_advanced_shuffle_layer.get_output(0)
+        reshape_output = []
+        if is_numpy:
+            reshape_output = transpose_tensor.reshape(concat_final_tensor)
+        else:
+            concat_final_shape_layer = ctx.net.add_concatenation(concat_final_tensor)
+            set_layer_name(
+                concat_final_shape_layer,
+                target,
+                name + "_index_continuous_concat_final_shape_layer",
+                source_ir,
+            )
+            concat_final_tensor = concat_final_shape_layer.get_output(0)
+
+            unfold_advanced_shuffle_layer = ctx.net.add_shuffle(transpose_tensor)
+            # check this
+            unfold_advanced_shuffle_layer.set_input(1, concat_final_tensor)
+            set_layer_name(
+                unfold_advanced_shuffle_layer,
+                target,
+                name + "_unfold_advanced_index",
+                source_ir,
+            )
+            reshape_output = unfold_advanced_shuffle_layer.get_output(0)
 
     else:
         _LOGGER.debug(f"The indices are not continuous in this case")
@@ -325,23 +370,27 @@ def index(
             curr_dim = dim_tensor_list[i]
             concat_final_tensor.append(curr_dim)
 
-        concat_final_shape_layer = ctx.net.add_concatenation(concat_final_tensor)
-        set_layer_name(
-            concat_final_shape_layer,
-            target,
-            name + "_index_non_continuous_concat_final_shape_layer",
-            source_ir,
-        )
-        concat_final_tensor = concat_final_shape_layer.get_output(0)
-
-        reshape_layer = ctx.net.add_shuffle(gather_out)
-        reshape_layer.set_input(1, concat_final_tensor)
-        set_layer_name(
-            reshape_layer,
-            target,
-            name + "_index_non_continuous_shuffle_final_shape_layer",
-            source_ir,
-        )
-        reshape_output = reshape_layer.get_output(0)
+        reshape_output = None
+        if is_numpy:
+            reshape_output = gather_out.reshape(concat_final_tensor)
+        else:
+            concat_final_shape_layer = ctx.net.add_concatenation(concat_final_tensor)
+            set_layer_name(
+                concat_final_shape_layer,
+                target,
+                name + "_index_non_continuous_concat_final_shape_layer",
+                source_ir,
+            )
+            concat_final_tensor = concat_final_shape_layer.get_output(0)
+
+            reshape_layer = ctx.net.add_shuffle(gather_out)
+            reshape_layer.set_input(1, concat_final_tensor)
+            set_layer_name(
+                reshape_layer,
+                target,
+                name + "_index_non_continuous_shuffle_final_shape_layer",
+                source_ir,
+            )
+            reshape_output = reshape_layer.get_output(0)
 
     return reshape_output
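
As a sanity check of what the numpy fast path is meant to reproduce, here is a small hedged example (shapes and indices assumed for illustration) comparing the flatten-gather-reshape pipeline against plain advanced indexing:

import numpy as np

# Assumed example: rank-3 input, advanced indices on dims 0 and 1 (the continuous case).
x = np.arange(24).reshape(2, 3, 4)
i0 = np.array([0, 1, 1])
i1 = np.array([2, 0, 1])

# Flatten the advanced dims into one, gather with the cumulative index,
# which mirrors the gather/reshape sequence in the converter.
flat = x.reshape(2 * 3, 4)
cum_adv_index = i0 * 3 + i1
gathered = flat[cum_adv_index]                 # shape (3, 4)

assert np.array_equal(gathered, x[i0, i1])     # matches advanced indexing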
