Commit 7029e91

Numpy changes for aten::index converter (#2396)

1 parent 88f6812 · commit 7029e91

File tree

3 files changed: +93 additions, −43 deletions


py/torch_tensorrt/dynamo/conversion/converter_utils.py

Lines changed: 1 addition & 2 deletions

```diff
@@ -181,8 +181,7 @@ def cast_int_int_div_trt_tensor(
 
 
 def broadcastable(
-    a: TRTTensor,
-    b: TRTTensor,
+    a: Union[TRTTensor, np.ndarray], b: Union[TRTTensor, np.ndarray]
 ) -> bool:
     "Check if two tensors are broadcastable according to torch rules"
     a_shape = tuple(a.shape)
```
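For reference, the torch broadcast rule that `broadcastable` checks can be sketched in plain numpy: shapes are aligned from the trailing dimension backwards, and each aligned pair of dimensions must either match or contain a 1. A minimal sketch of the rule (not the function's actual body):

```python
import numpy as np

def broadcastable_sketch(a: np.ndarray, b: np.ndarray) -> bool:
    # Walk the shapes from the trailing dimension backwards, per torch rules:
    # each aligned pair of dims must be equal, or one of them must be 1.
    for da, db in zip(reversed(a.shape), reversed(b.shape)):
        if da != db and da != 1 and db != 1:
            return False
    return True

assert broadcastable_sketch(np.zeros((4, 1)), np.zeros((1, 5)))      # -> (4, 5)
assert not broadcastable_sketch(np.zeros((4, 2)), np.zeros((3, 2)))  # 4 vs 3 clash
```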

py/torch_tensorrt/dynamo/conversion/impl/select.py

Lines changed: 64 additions & 39 deletions

```diff
@@ -3,6 +3,7 @@
 
 import numpy as np
 import tensorrt as trt
+import torch
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
@@ -80,23 +81,34 @@ def index(
     source_ir: Optional[SourceIR],
     name: str,
     input: TRTTensor,
-    index: Union[TRTTensor, Sequence[TRTTensor]],
+    index: Sequence[Union[TRTTensor, np.ndarray, torch.Tensor]],
 ) -> TRTTensor:
     adv_indx_indices = []
     tensor_indices = []
-    # _LOGGER.debug(f"The index shape is {index.shape}")
     # check if the input is dynamic
     dynamic_shape = has_dynamic_shape(input.shape)
-
+    # is_numpy is a flag specifying whether all the indices are numpy arrays or
+    # torch.Tensors; if any index is not, the flag is set to False.
+    _LOGGER.debug(
+        "Determining whether aten.index constant-index optimization can be invoked"
+    )
+    is_numpy = all(
+        isinstance(ind, (torch.Tensor, np.ndarray)) for ind in index if ind is not None
+    )
     # here we need to check if all the indices are broadcastable
     # if not, then we need to broadcast them
     last_index = None
     for i, ind in enumerate(index):
         if ind is not None:
             _LOGGER.debug(f"Shape of {i} index is {ind.shape}")
             adv_indx_indices.append(i)
-            # torch.nn.parameter.Parameter => torch.Tensor
-            ind = get_trt_tensor(ctx, ind, name + f"_parameter_to_fp32_tensor_{i}")
+            # torch.nn.parameter.Parameter => numpy array
+            # numpy arrays are kept as numpy
+            # other cases are kept as TRTTensor
+            if is_numpy:
+                ind = to_numpy(ind)
+            else:
+                ind = get_trt_tensor(ctx, ind, name + f"_parameter_to_fp32_tensor_{i}")
         if last_index is not None:
             assert broadcastable(
                 ind, last_index
```
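The `is_numpy` flag is what gates the new constant-index fast path: it is True only when every non-`None` index is a host-side constant (`torch.Tensor` or `np.ndarray`), so the gather offsets can be precomputed in numpy rather than built out of TensorRT layers; a single runtime ITensor index forces the layer-based path. A standalone illustration of the same predicate (`all_constant` is a hypothetical stand-in for the inline check; the diff continues below):

```python
import numpy as np
import torch

def all_constant(index):
    # Mirrors the is_numpy check: every non-None index must be a host-side
    # constant for the numpy precomputation path to be safe.
    return all(
        isinstance(ind, (torch.Tensor, np.ndarray)) for ind in index if ind is not None
    )

print(all_constant([None, torch.tensor([0, 1]), np.array([1, 0])]))  # True
print(all_constant([None, torch.tensor([0, 1]), object()]))          # False (stand-in
                                                                     # for an ITensor)
```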
```diff
@@ -110,8 +122,9 @@ def index(
         set_layer_name(identity_layer, target, name + "_index_identity", source_ir)
         return identity_layer.get_output(0)
     elif len(tensor_indices) == 1:
-        # This case works
-        indices_tensor = tensor_indices[0]
+        indices_tensor = get_trt_tensor(
+            ctx, tensor_indices[0], name + f"_parameter_to_fp32_tensor"
+        )
         index = adv_indx_indices[0]
         _LOGGER.debug(f"The advanced index indices is {adv_indx_indices}")
         gather_layer = ctx.net.add_gather(input, indices_tensor, index)
@@ -150,6 +163,7 @@ def index(
             if i not in adv_indx_indices:
                 new_order.append(i)
         _LOGGER.debug(f"The new transpose order is {new_order}")
+
         transpose_layer.second_transpose = tuple(new_order)
         set_layer_name(transpose_layer, target, name + "_index_transpose", source_ir)
         transpose_tensor = transpose_layer.get_output(0)
@@ -175,47 +189,58 @@ def index(
         concat_tensor = concat_tensor_layer.get_output(0)
 
         reshape_layer = ctx.net.add_shuffle(transpose_tensor)
-        # check this
         reshape_layer.set_input(1, concat_tensor)
         flatten_tensor = reshape_layer.get_output(0)
+
         _LOGGER.debug(f"The flatten tensor shape is {flatten_tensor.shape}")
 
         # tensor index = \sum_{i=1}^m (ind_i * \prod_{j=i+1}^m (x_j)), where ind_i is
         # input indices[i] and x_j is the j-th dimension of input x.
-        multiplier = get_trt_tensor(
-            ctx,
-            dim_tensor_list[adv_indx_indices[adv_indx_count - 1]],
-            name + "_dim_last",
-        )
-        cum_adv_index = tensor_indices[adv_indx_count - 1]
-        for i in range(adv_indx_count - 2, -1, -1):
-            adv_index = convert_binary_elementwise(
-                ctx,
-                target,
-                source_ir,
-                name + f"_index_intermediate_{i}",
-                trt.ElementWiseOperation.PROD,
-                multiplier,
-                tensor_indices[i],
+        if is_numpy:
+            multiplier = input_shape[adv_indx_indices[adv_indx_count - 1]]
+            cum_adv_index = tensor_indices[adv_indx_count - 1]
+            for i in range(adv_indx_count - 2, -1, -1):
+                adv_index = multiplier * tensor_indices[i]
+                cum_adv_index = cum_adv_index + adv_index
+                multiplier = multiplier * input_shape[adv_indx_indices[i]]
+            cum_adv_index = get_trt_tensor(
+                ctx, cum_adv_index, name + f"_index_sum_intermediate"
             )
-            cum_adv_index = convert_binary_elementwise(
-                ctx,
-                target,
-                source_ir,
-                name + f"_index_sum_intermediate_{i}",
-                trt.ElementWiseOperation.SUM,
-                cum_adv_index,
-                adv_index,
-            )
-            multiplier = convert_binary_elementwise(
+        else:
+            multiplier = get_trt_tensor(
                 ctx,
-                target,
-                source_ir,
-                name + f"_index_intermediate_xj_{i}",
-                trt.ElementWiseOperation.PROD,
-                multiplier,
-                dim_tensor_list[adv_indx_indices[i]],
+                dim_tensor_list[adv_indx_indices[adv_indx_count - 1]],
+                name + "_dim_last",
             )
+            cum_adv_index = tensor_indices[adv_indx_count - 1]
+            for i in range(adv_indx_count - 2, -1, -1):
+                adv_index = convert_binary_elementwise(
+                    ctx,
+                    target,
+                    source_ir,
+                    name + f"_index_intermediate_{i}",
+                    trt.ElementWiseOperation.PROD,
+                    multiplier,
+                    tensor_indices[i],
+                )
+                cum_adv_index = convert_binary_elementwise(
+                    ctx,
+                    target,
+                    source_ir,
+                    name + f"_index_sum_intermediate_{i}",
+                    trt.ElementWiseOperation.SUM,
+                    cum_adv_index,
+                    adv_index,
+                )
+                multiplier = convert_binary_elementwise(
+                    ctx,
+                    target,
+                    source_ir,
+                    name + f"_index_intermediate_xj_{i}",
+                    trt.ElementWiseOperation.PROD,
+                    multiplier,
+                    dim_tensor_list[adv_indx_indices[i]],
+                )
 
         gather_layer_element = ctx.net.add_gather(flatten_tensor, cum_adv_index, 0)
         set_layer_name(
```
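In the `is_numpy` branch above, `cum_adv_index` is computed entirely on the host: the comment's formula, cum_adv_index = \sum_i (ind_i * \prod_{j>i} x_j), is simply row-major linearization, so gathering at those offsets on the flattened tensor reproduces advanced indexing. A worked numpy example of that identity (illustration only, independent of the converter):

```python
import numpy as np

# Reproduce x[ind0, ind1] with a single gather on the flattened tensor,
# using the same linearization as the converter.
x = np.arange(12).reshape(3, 4)
ind0 = np.array([0, 2, 1])
ind1 = np.array([3, 1, 0])

multiplier = x.shape[1]                   # product of dims after the first index
cum_adv_index = ind0 * multiplier + ind1  # row-major flat offsets

assert (x.reshape(-1)[cum_adv_index] == x[ind0, ind1]).all()
```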

tests/py/dynamo/conversion/test_index_aten.py

Lines changed: 28 additions & 2 deletions

```diff
@@ -2,11 +2,10 @@
 
 import torch
 import torch.nn as nn
+from .harness import DispatchTestCase
 from torch.testing._internal.common_utils import run_tests
 from torch_tensorrt import Input
 
-from .harness import DispatchTestCase
-
 
 class TestIndexConverter(DispatchTestCase):
     def test_index_zero_two_dim(self):
@@ -27,6 +26,21 @@ def forward(self, x):
             input,
         )
 
+    def test_index_zero_two_dim_ITensor(self):
+        class TestModule(nn.Module):
+            def forward(self, x, index0):
+                indices = [None, index0]
+                out = torch.ops.aten.index.Tensor(x, indices)
+                return out
+
+        input = torch.randn(2, 2)
+        index0 = torch.randint(0, 1, (1, 1))
+        index0 = index0.to(torch.int32)
+        self.run_test(
+            TestModule(),
+            [input, index0],
+        )
+
     def test_index_zero_index_three_dim(self):
         class TestModule(nn.Module):
             def __init__(self):
@@ -44,6 +58,18 @@ def forward(self, x):
             input,
         )
 
+    def test_index_zero_index_three_dim_ITensor(self):
+        class TestModule(nn.Module):
+            def forward(self, x, index0):
+                indices = [None, index0, None]
+                out = torch.ops.aten.index.Tensor(x, indices)
+                return out
+
+        input = torch.randn(2, 2, 2)
+        index0 = torch.randint(0, 1, (1, 1))
+        index0 = index0.to(torch.int32)
+        self.run_test(TestModule(), [input, index0])
+
     def test_index_zero_index_one_index_two_three_dim(self):
         class TestModule(nn.Module):
             def __init__(self):
```
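The two new `_ITensor` tests cover the opposite path: the index arrives as a graph input, so at conversion time it is a runtime ITensor, `is_numpy` is False, and the converter must take the TensorRT layer path. For reference, the eager semantics the tests compare against (a standalone illustration):

```python
import torch

x = torch.randn(2, 2)
index0 = torch.randint(0, 1, (1, 1)).to(torch.int32)

# aten.index.Tensor with [None, index0] indexes dim 1 only; it is
# equivalent to x[:, index0] and broadcasts to shape (2, 1, 1).
out = torch.ops.aten.index.Tensor(x, [None, index0])
assert out.shape == torch.Size([2, 1, 1])
assert torch.equal(out, x[:, index0])
```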
