 import numpy as np
 import tensorrt as trt
+import torch
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
@@ -80,23 +81,34 @@ def index(
     source_ir: Optional[SourceIR],
     name: str,
     input: TRTTensor,
-    index: Union[TRTTensor, Sequence[TRTTensor]],
+    index: Sequence[Union[TRTTensor, np.ndarray, torch.Tensor]],
 ) -> TRTTensor:
     adv_indx_indices = []
     tensor_indices = []
-    # _LOGGER.debug(f"The index shape is {index.shape}")
     # check if the input is dynamic
     dynamic_shape = has_dynamic_shape(input.shape)
-
+    # is_numpy is a flag to specify if all the indices are numpy or torchTensor.
+    # If any is not this flag will be set to False
+    _LOGGER.debug(
+        f"Determining whether aten.index constant-index optimization can be invoked"
+    )
+    is_numpy = all(
+        isinstance(ind, (torch.Tensor, np.ndarray)) for ind in index if ind is not None
+    )
     # here we need to check if all the index are broadcastable
     # if no, then we need to broadcast
     last_index = None
     for i, ind in enumerate(index):
         if ind is not None:
             _LOGGER.debug(f"Shape of {i} index is {ind.shape}")
             adv_indx_indices.append(i)
-            # torch.nn.parameter.Parameter=> torch.Tensor
-            ind = get_trt_tensor(ctx, ind, name + f"_parameter_to_fp32_tensor_{i}")
+            # torch.nn.parameter.Parameter=> numpy array
+            # numpy array is kept as numpy
+            # other cases are kept as TRTTensor
+            if is_numpy:
+                ind = to_numpy(ind)
+            else:
+                ind = get_trt_tensor(ctx, ind, name + f"_parameter_to_fp32_tensor_{i}")
             if last_index is not None:
                 assert broadcastable(
                     ind, last_index
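The is_numpy gate added above enables a constant-folding fast path only when every non-None index is already host-resident (a torch.Tensor or numpy array) rather than a TRTTensor in the network. A minimal standalone sketch of that check, with made-up index values:

import numpy as np
import torch

# Hypothetical index list for aten.index: dim 0 indexed by a numpy array,
# dim 1 skipped (None), dim 2 indexed by a torch.Tensor.
index = [np.array([0, 2]), None, torch.tensor([1, 3])]

# Mirrors the is_numpy check above: True only if every non-None index is a
# host tensor/array, so the gather indices can be precomputed on the host.
is_numpy = all(
    isinstance(ind, (torch.Tensor, np.ndarray)) for ind in index if ind is not None
)
print(is_numpy)  # True -- the constant-index path can be taken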
@@ -110,8 +122,9 @@ def index(
         set_layer_name(identity_layer, target, name + "_index_identity", source_ir)
         return identity_layer.get_output(0)
     elif len(tensor_indices) == 1:
-        # This case works
-        indices_tensor = tensor_indices[0]
+        indices_tensor = get_trt_tensor(
+            ctx, tensor_indices[0], name + f"_parameter_to_fp32_tensor"
+        )
         index = adv_indx_indices[0]
         _LOGGER.debug(f"The advanced index indices is {adv_indx_indices}")
         gather_layer = ctx.net.add_gather(input, indices_tensor, index)
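The single-advanced-index branch above lowers to one gather along that axis; the change is only that the (possibly numpy) index is first materialized with get_trt_tensor. A rough numpy analogue of what that gather computes, using toy shapes (not the converter API):

import numpy as np

# Toy input and a single advanced index on dim 1, i.e. x[:, idx] in PyTorch terms.
x = np.arange(12).reshape(3, 4)
idx = np.array([2, 0])

# One gather along the indexed axis, equivalent to numpy's take:
gathered = np.take(x, idx, axis=1)
assert gathered.shape == (3, 2)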
@@ -150,6 +163,7 @@ def index(
             if i not in adv_indx_indices:
                 new_order.append(i)
         _LOGGER.debug(f"The new transpose order is {new_order}")
+
         transpose_layer.second_transpose = tuple(new_order)
         set_layer_name(transpose_layer, target, name + "_index_transpose", source_ir)
         transpose_tensor = transpose_layer.get_output(0)
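From the surrounding context, new_order appears to place the advanced-index dimensions first and the remaining dimensions after them, so the later flatten and gather operate on a contiguous leading block. A toy illustration with hypothetical dimensions:

# Rank-4 input with advanced indices on dims 1 and 3 (hypothetical values).
rank = 4
adv_indx_indices = [1, 3]

# Advanced dims first, remaining dims after -- mirroring the loop above.
new_order = list(adv_indx_indices)
for i in range(rank):
    if i not in adv_indx_indices:
        new_order.append(i)
print(new_order)  # [1, 3, 0, 2]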
@@ -175,47 +189,58 @@ def index(
         concat_tensor = concat_tensor_layer.get_output(0)

         reshape_layer = ctx.net.add_shuffle(transpose_tensor)
-        # check this
         reshape_layer.set_input(1, concat_tensor)
         flatten_tensor = reshape_layer.get_output(0)
+
         _LOGGER.debug(f"The flatten tensor shape is {flatten_tensor.shape}")

         # tensor index = \sum_{i=1}^m (ind_i * \prod_{j=i+1}^m (x_j)), ind_i is input indices[i], x_j is the
         # // j dimension of input x.
-        multiplier = get_trt_tensor(
-            ctx,
-            dim_tensor_list[adv_indx_indices[adv_indx_count - 1]],
-            name + "_dim_last",
-        )
-        cum_adv_index = tensor_indices[adv_indx_count - 1]
-        for i in range(adv_indx_count - 2, -1, -1):
-            adv_index = convert_binary_elementwise(
-                ctx,
-                target,
-                source_ir,
-                name + f"_index_intermediate_{i}",
-                trt.ElementWiseOperation.PROD,
-                multiplier,
-                tensor_indices[i],
+        if is_numpy:
+            multiplier = input_shape[adv_indx_indices[adv_indx_count - 1]]
+            cum_adv_index = tensor_indices[adv_indx_count - 1]
+            for i in range(adv_indx_count - 2, -1, -1):
+                adv_index = multiplier * tensor_indices[i]
+                cum_adv_index = cum_adv_index + adv_index
+                multiplier = multiplier * input_shape[adv_indx_indices[i]]
+            cum_adv_index = get_trt_tensor(
+                ctx, cum_adv_index, name + f"_index_sum_intermediate"
             )
-            cum_adv_index = convert_binary_elementwise(
-                ctx,
-                target,
-                source_ir,
-                name + f"_index_sum_intermediate_{i}",
-                trt.ElementWiseOperation.SUM,
-                cum_adv_index,
-                adv_index,
-            )
-            multiplier = convert_binary_elementwise(
+        else:
+            multiplier = get_trt_tensor(
                 ctx,
-                target,
-                source_ir,
-                name + f"_index_intermediate_xj_{i}",
-                trt.ElementWiseOperation.PROD,
-                multiplier,
-                dim_tensor_list[adv_indx_indices[i]],
+                dim_tensor_list[adv_indx_indices[adv_indx_count - 1]],
+                name + "_dim_last",
             )
+            cum_adv_index = tensor_indices[adv_indx_count - 1]
+            for i in range(adv_indx_count - 2, -1, -1):
+                adv_index = convert_binary_elementwise(
+                    ctx,
+                    target,
+                    source_ir,
+                    name + f"_index_intermediate_{i}",
+                    trt.ElementWiseOperation.PROD,
+                    multiplier,
+                    tensor_indices[i],
+                )
+                cum_adv_index = convert_binary_elementwise(
+                    ctx,
+                    target,
+                    source_ir,
+                    name + f"_index_sum_intermediate_{i}",
+                    trt.ElementWiseOperation.SUM,
+                    cum_adv_index,
+                    adv_index,
+                )
+                multiplier = convert_binary_elementwise(
+                    ctx,
+                    target,
+                    source_ir,
+                    name + f"_index_intermediate_xj_{i}",
+                    trt.ElementWiseOperation.PROD,
+                    multiplier,
+                    dim_tensor_list[adv_indx_indices[i]],
+                )

         gather_layer_element = ctx.net.add_gather(flatten_tensor, cum_adv_index, 0)
         set_layer_name(
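For reference, a standalone sketch of what the new is_numpy branch precomputes on the host: each tuple of advanced indices is linearized as \sum_i ind_i * \prod_{j>i} x_j (row-major strides), so a single gather over the flattened input reproduces advanced indexing. Shapes and index values below are made up for illustration:

import numpy as np

# Hypothetical input of shape (4, 5) with advanced indices on both dims,
# i.e. x[rows, cols] in PyTorch advanced-indexing terms.
input_shape = (4, 5)
rows = np.array([0, 2, 3])
cols = np.array([1, 1, 4])
tensor_indices = [rows, cols]
adv_indx_indices = [0, 1]  # both dims are advanced-indexed in this toy case
adv_indx_count = len(adv_indx_indices)

# Linearize exactly as in the is_numpy branch: row-major strides over the
# advanced dims, accumulated from the last indexed dim backwards.
multiplier = input_shape[adv_indx_indices[adv_indx_count - 1]]
cum_adv_index = tensor_indices[adv_indx_count - 1]
for i in range(adv_indx_count - 2, -1, -1):
    adv_index = multiplier * tensor_indices[i]
    cum_adv_index = cum_adv_index + adv_index
    multiplier = multiplier * input_shape[adv_indx_indices[i]]

x = np.arange(np.prod(input_shape)).reshape(input_shape)
# Gathering with the linearized index from the flattened input matches
# advanced indexing on the original tensor.
assert np.array_equal(x.reshape(-1)[cum_adv_index], x[rows, cols])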