
Commit e432bf2

Aten::Index converter (#2277)
1 parent 5bb8cb0 commit e432bf2

File tree

3 files changed: +487 -3 lines


py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 23 additions & 0 deletions
@@ -137,6 +137,29 @@ def aten_ops_sigmoid(
     )


+@dynamo_tensorrt_converter(torch.ops.aten.index.Tensor)
+@enforce_tensor_types(
+    {
+        0: (TRTTensor,),
+    }
+)
+def aten_ops_index(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.select.index(
+        ctx,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args[1],
+    )
+
+
 @dynamo_tensorrt_converter(torch.ops.aten.tanh.default)  # type: ignore[misc]
 def aten_ops_tanh(
     ctx: ConversionContext,
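For context (not part of this commit): in the aten/dynamo graph, Python advanced indexing such as x[:, idx] lowers to torch.ops.aten.index.Tensor(x, [None, idx]), so args[0] is the input tensor and args[1] is the per-dimension list of optional index tensors that impl.select.index receives. A minimal eager-mode sketch of that equivalence:

    import torch

    x = torch.randn(4, 8, 16)
    idx = torch.tensor([0, 2, 5])

    # ":" dimensions show up as None entries in the indices list; only dim 1
    # carries a real index tensor here, so the converter's single-index
    # gather path would apply.
    out = torch.ops.aten.index.Tensor(x, [None, idx])
    assert torch.equal(out, x[:, idx])
    print(out.shape)  # torch.Size([4, 3, 16])

With several index tensors (e.g. x[:, ind_1, ind_2]), the converter falls through to the general transpose/flatten/gather path implemented in impl/select.py below.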

py/torch_tensorrt/dynamo/conversion/impl/select.py

Lines changed: 289 additions & 3 deletions
@@ -1,14 +1,27 @@
-from typing import Optional, cast
+import logging
+from typing import Optional, Sequence, Union, cast

 import numpy as np
+import tensorrt as trt
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
-from torch_tensorrt.dynamo.conversion.converter_utils import get_positive_dim, to_numpy
+from torch_tensorrt.dynamo.conversion.converter_utils import (
+    broadcastable,
+    get_positive_dim,
+    get_trt_tensor,
+    to_numpy,
+)
+from torch_tensorrt.dynamo.conversion.impl.elementwise import convert_binary_elementwise
 from torch_tensorrt.dynamo.conversion.impl.shape import get_shape_with_dynamic_shape
-from torch_tensorrt.fx.converters.converter_utils import has_dynamic_shape
+from torch_tensorrt.fx.converters.converter_utils import (
+    has_dynamic_shape,
+    set_layer_name,
+)
 from torch_tensorrt.fx.types import Shape, TRTTensor

+_LOGGER: logging.Logger = logging.getLogger(__name__)
+

 def select(
     ctx: ConversionContext,
@@ -59,3 +72,276 @@ def select(
     if len(out.shape) != 1:
         layer = ctx.net.add_shuffle(out)
     return layer.get_output(0)
+
+
+def index(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input: TRTTensor,
+    index: Union[TRTTensor, Sequence[TRTTensor]],
+) -> TRTTensor:
+    adv_indx_indices = []
+    tensor_indices = []
+    # _LOGGER.debug(f"The index shape is {index.shape}")
+    # check if the input is dynamic
+    dynamic_shape = has_dynamic_shape(input.shape)
+
+    # here we need to check if all the index are broadcastable
+    # if no, then we need to broadcast
+    last_index = None
+    for i, ind in enumerate(index):
+        if ind is not None:
+            _LOGGER.debug(f"Shape of {i} index is {ind.shape}")
+            adv_indx_indices.append(i)
+            # torch.nn.parameter.Parameter=> torch.Tensor
+            ind = get_trt_tensor(ctx, ind, name + f"_parameter_to_fp32_tensor_{i}")
+            if last_index is not None:
+                assert broadcastable(
+                    ind, last_index
+                ), "The indices should be broadcastable!"
+            last_index = ind
+            tensor_indices.append(ind)
+
+    if not tensor_indices:
+        identity_layer = ctx.net.add_identity(input)
+        identity_layer.set_output_type(0, trt.int32)
+        set_layer_name(identity_layer, target, name + "_index_identity", source_ir)
+        return identity_layer.get_output(0)
+    elif len(tensor_indices) == 1:
+        # This case works
+        indices_tensor = tensor_indices[0]
+        index = adv_indx_indices[0]
+        _LOGGER.debug(f"The advanced index indices is {adv_indx_indices}")
+        gather_layer = ctx.net.add_gather(input, indices_tensor, index)
+        set_layer_name(gather_layer, target, name + "_index_gather", source_ir)
+        return gather_layer.get_output(0)
+    else:
+        input_shape = input.shape
+        _LOGGER.debug(f"The input shape is {input.shape}")
+        if dynamic_shape:
+            input_shape = get_shape_with_dynamic_shape(
+                ctx.net, target, source_ir, name, input_shape, input
+            )
+        rank = len(input_shape)
+        adv_indx_count = len(adv_indx_indices)
+        dim_tensor_list = []
+
+        for i in range(rank):
+            dim = input_shape[i]
+            dim_tensor = get_trt_tensor(ctx, dim, name + f"_individual_dim_{i}")
+            # dim_tensor_list is a list of tensors
+            dim_tensor_list.append(dim_tensor)
+
+        # for cases like
+        # t: [x_1, y_1, y_2, ..., x_m, ..., y_n] -> t: [x_1, x_2, ..., x_m, y_1, y_2, ..., y_n],
+        # where t is a tensor of rank m+n, {x_i} are axes where tensor index is provided, and {y_i} are axes
+        # for ":"
+        # Examples: x.shape = (10,20,30,40,50)
+        # ind_1, ind_2 broadcasted to (2,3,4)
+        # x[:, ind_1, ind_2] = 10, 2, 3, 4, 40, 50
+        # x[:,ind_1, :, ind_2] = 2, 3, 4, 10, 30, 50
+        transpose_layer = ctx.net.add_shuffle(input)
+        new_order = []
+        for i in range(adv_indx_count):
+            new_order.append(adv_indx_indices[i])
+        for i in range(rank):
+            if i not in adv_indx_indices:
+                new_order.append(i)
+        _LOGGER.debug(f"The new transpose order is {new_order}")
+        transpose_layer.second_transpose = tuple(new_order)
+        set_layer_name(transpose_layer, target, name + "_index_transpose", source_ir)
+        transpose_tensor = transpose_layer.get_output(0)
+
+        # Flatten [x_1, x_2,.......x_m, y_1, y_2,.....y_n]
+        # transpose_tensor_shape = ctx.net.add_shape(transpose_tensor)
+        transpose_tensor_shape = transpose_tensor.shape
+        _LOGGER.debug(f"The shape of transpose tensor is {transpose_tensor_shape}")
+        mult_d0 = 1
+        for i in range(adv_indx_count):
+            mult_d0 = mult_d0 * transpose_tensor_shape[i]
+        mult_d1 = 1
+        for i in range(adv_indx_count, rank):
+            mult_d1 = mult_d1 * transpose_tensor_shape[i]
+
+        concat_tensor_layer = ctx.net.add_concatenation(
+            [
+                get_trt_tensor(ctx, mult_d0, name + "_d0_shape"),
+                get_trt_tensor(ctx, mult_d1, name + "_d1_shape"),
+            ]
+        )
+        set_layer_name(concat_tensor_layer, target, name + "_index_Concat", source_ir)
+        concat_tensor = concat_tensor_layer.get_output(0)
+
+        reshape_layer = ctx.net.add_shuffle(transpose_tensor)
+        # check this
+        reshape_layer.set_input(1, concat_tensor)
+        flatten_tensor = reshape_layer.get_output(0)
+        _LOGGER.debug(f"The flatten tensor shape is {flatten_tensor.shape}")
+
+        # tensor index = \sum_{i=1}^m (ind_i * \prod_{j=i+1}^m (x_j)), ind_i is input indices[i], x_j is the
+        # // j dimension of input x.
+        multiplier = get_trt_tensor(
+            ctx,
+            dim_tensor_list[adv_indx_indices[adv_indx_count - 1]],
+            name + "_dim_last",
+        )
+        cum_adv_index = tensor_indices[adv_indx_count - 1]
+        for i in range(adv_indx_count - 2, -1, -1):
+            adv_index = convert_binary_elementwise(
+                ctx,
+                target,
+                source_ir,
+                name + f"_index_intermediate_{i}",
+                trt.ElementWiseOperation.PROD,
+                multiplier,
+                tensor_indices[i],
+            )
+            cum_adv_index = convert_binary_elementwise(
+                ctx,
+                target,
+                source_ir,
+                name + f"_index_sum_intermediate_{i}",
+                trt.ElementWiseOperation.SUM,
+                cum_adv_index,
+                adv_index,
+            )
+            multiplier = convert_binary_elementwise(
+                ctx,
+                target,
+                source_ir,
+                name + f"_index_intermediate_xj_{i}",
+                trt.ElementWiseOperation.PROD,
+                multiplier,
+                dim_tensor_list[adv_indx_indices[i]],
+            )
+
+        gather_layer_element = ctx.net.add_gather(flatten_tensor, cum_adv_index, 0)
+        set_layer_name(
+            gather_layer_element, target, name + "_index_gather_element", source_ir
+        )
+        gather_out = gather_layer_element.get_output(0)
+        _LOGGER.debug(f"The shape after cumultative gather is {gather_out.shape}")
+        _LOGGER.debug(f"The shape for cumulative adv index is {cum_adv_index}")
+
+        cum_adv_index_shape_layer = ctx.net.add_shape(cum_adv_index)
+        set_layer_name(
+            cum_adv_index_shape_layer, target, name + "_cum_adv_index_shape", source_ir
+        )
+        cum_adv_index_shape_tensor = cum_adv_index_shape_layer.get_output(0)
+        cum_adv_index_shape = cum_adv_index.shape
+        _LOGGER.debug(f"The shape for cumulative adv index is {cum_adv_index_shape}")
+        # check if all advanced indices are consecutive
+        concat_tensor_reshape = []
+        if (
+            adv_indx_count
+            == adv_indx_indices[adv_indx_count - 1] - adv_indx_indices[0] + 1
+        ):
+            _LOGGER.debug(f"The indices are continuous in this case")
+            concat_tensor_reshape.append(
+                get_trt_tensor(ctx, -1, name + "_dynamic_concat")
+            )
+            for i in range(0, rank):
+                if i not in adv_indx_indices:
+                    curr_dim = dim_tensor_list[i]
+                    concat_tensor_reshape.append(curr_dim)
+
+            concat_tensor_layer = ctx.net.add_concatenation(concat_tensor_reshape)
+            set_layer_name(
+                concat_tensor_layer, target, name + "_index_Concat_reshape", source_ir
+            )
+            concat_tensor = concat_tensor_layer.get_output(0)
+
+            regular_index_shuffle_layer = ctx.net.add_shuffle(gather_out)
+            regular_index_shuffle_layer.set_input(1, concat_tensor)
+            set_layer_name(
+                regular_index_shuffle_layer,
+                target,
+                name + "_index_regular_index",
+                source_ir,
+            )
+            unfold_tensor = regular_index_shuffle_layer.get_output(0)
+            _LOGGER.debug(f"The tensor is unfolded now")
+            _LOGGER.debug(f"The unfolded tensor shape is {unfold_tensor.shape}")
+
+            # Transpose folded advanced indexed axis to its original location.
+            transpose_advanced_shuffle_layer = ctx.net.add_shuffle(unfold_tensor)
+            new_order = []
+            for i in range(1, adv_indx_indices[0] + 1):
+                new_order.append(i)
+            new_order.append(0)
+            for i in range(adv_indx_indices[0] + 1, rank - adv_indx_count + 1):
+                new_order.append(i)
+            _LOGGER.debug(f"Transposing the indices to correct position {new_order}")
+
+            transpose_advanced_shuffle_layer.second_transpose = tuple(new_order)
+            set_layer_name(
+                transpose_advanced_shuffle_layer,
+                target,
+                name + "_index_advanced_shuffle_transpose",
+                source_ir,
+            )
+            transpose_tensor = transpose_advanced_shuffle_layer.get_output(0)
+
+            # unfold advanced layer
+            concat_final_tensor = []
+            for i in range(0, adv_indx_indices[0]):
+                current_dim = dim_tensor_list[i]
+                concat_final_tensor.append(current_dim)
+
+            concat_final_tensor.append(cum_adv_index_shape_tensor)
+            for i in range(adv_indx_indices[0], rank):
+                if i not in (adv_indx_indices):
+                    current_dim = dim_tensor_list[i]
+                    concat_final_tensor.append(current_dim)
+
+            concat_final_shape_layer = ctx.net.add_concatenation(concat_final_tensor)
+            set_layer_name(
+                concat_final_shape_layer,
+                target,
+                name + "_index_continuous_concat_final_shape_layer",
+                source_ir,
+            )
+            concat_final_tensor = concat_final_shape_layer.get_output(0)
+
+            unfold_advanced_shuffle_layer = ctx.net.add_shuffle(transpose_tensor)
+            # check this
+            unfold_advanced_shuffle_layer.set_input(1, concat_final_tensor)
+            set_layer_name(
+                unfold_advanced_shuffle_layer,
+                target,
+                name + "_unfold_advanced_index",
+                source_ir,
+            )
+            reshape_output = unfold_advanced_shuffle_layer.get_output(0)
+
+        else:
+            _LOGGER.debug(f"The indices are not continuous in this case")
+            concat_final_tensor = []
+            concat_final_tensor.append(cum_adv_index_shape_tensor)
+            for i in range(0, rank):
+                if i not in adv_indx_indices:
+                    curr_dim = dim_tensor_list[i]
+                    concat_final_tensor.append(curr_dim)
+
+            concat_final_shape_layer = ctx.net.add_concatenation(concat_final_tensor)
+            set_layer_name(
+                concat_final_shape_layer,
+                target,
+                name + "_index_non_continuous_concat_final_shape_layer",
+                source_ir,
+            )
+            concat_final_tensor = concat_final_shape_layer.get_output(0)
+
+            reshape_layer = ctx.net.add_shuffle(gather_out)
+            reshape_layer.set_input(1, concat_final_tensor)
+            set_layer_name(
+                reshape_layer,
+                target,
+                name + "_index_non_continuous_shuffle_final_shape_layer",
+                source_ir,
+            )
+            reshape_output = reshape_layer.get_output(0)
+
+        return reshape_output
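The arithmetic behind the general path can be checked in plain PyTorch. The sketch below (illustration only, not part of the commit) reproduces the transpose -> flatten -> linearized-index gather -> reshape sequence for the example from the comments (x.shape = (10, 20, 30, 40, 50), advanced indices at dims 1 and 2 broadcast to (2, 3, 4)), matching the formula sum_i ind_i * prod_{j>i} x_j used for cum_adv_index:

    import torch

    x = torch.randn(10, 20, 30, 40, 50)
    ind_1 = torch.randint(0, 20, (2, 3, 4))
    ind_2 = torch.randint(0, 30, (2, 3, 4))

    reference = x[:, ind_1, ind_2]            # shape (10, 2, 3, 4, 40, 50)

    # 1. Move the advanced axes (1, 2) to the front (transpose_layer).
    xt = x.permute(1, 2, 0, 3, 4)             # (20, 30, 10, 40, 50)

    # 2. Flatten to (prod of advanced dims, prod of remaining dims) (reshape_layer).
    flat = xt.reshape(20 * 30, 10 * 40 * 50)  # (600, 20000)

    # 3. Linearize the indices: cum_adv_index = ind_1 * 30 + ind_2.
    cum_adv_index = ind_1 * 30 + ind_2        # (2, 3, 4)

    # 4. Gather rows, then restore the ":" dims and move them back in place
    #    (gather_layer_element plus the final shuffle layers).
    gathered = flat[cum_adv_index.reshape(-1)]          # (24, 20000)
    out = (
        gathered.reshape(-1, 10, 40, 50)                # (24, 10, 40, 50)
        .permute(1, 0, 2, 3)                            # (10, 24, 40, 50)
        .reshape(10, 2, 3, 4, 40, 50)
    )

    assert torch.equal(out, reference)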
