|
1 |
| -from typing import Optional, cast |
| 1 | +import logging |
| 2 | +from typing import Optional, Sequence, Union, cast |
2 | 3 |
|
3 | 4 | import numpy as np
|
| 5 | +import tensorrt as trt |
4 | 6 | from torch.fx.node import Target
|
5 | 7 | from torch_tensorrt.dynamo._SourceIR import SourceIR
|
6 | 8 | from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
|
7 |
| -from torch_tensorrt.dynamo.conversion.converter_utils import get_positive_dim, to_numpy |
| 9 | +from torch_tensorrt.dynamo.conversion.converter_utils import ( |
| 10 | + broadcastable, |
| 11 | + get_positive_dim, |
| 12 | + get_trt_tensor, |
| 13 | + to_numpy, |
| 14 | +) |
| 15 | +from torch_tensorrt.dynamo.conversion.impl.elementwise import convert_binary_elementwise |
8 | 16 | from torch_tensorrt.dynamo.conversion.impl.shape import get_shape_with_dynamic_shape
|
9 |
| -from torch_tensorrt.fx.converters.converter_utils import has_dynamic_shape |
| 17 | +from torch_tensorrt.fx.converters.converter_utils import ( |
| 18 | + has_dynamic_shape, |
| 19 | + set_layer_name, |
| 20 | +) |
10 | 21 | from torch_tensorrt.fx.types import Shape, TRTTensor
|
11 | 22 |
|
| 23 | +_LOGGER: logging.Logger = logging.getLogger(__name__) |
| 24 | + |
12 | 25 |
|
13 | 26 | def select(
|
14 | 27 | ctx: ConversionContext,
|
@@ -59,3 +72,276 @@ def select(
|
59 | 72 | if len(out.shape) != 1:
|
60 | 73 | layer = ctx.net.add_shuffle(out)
|
61 | 74 | return layer.get_output(0)
|
| 75 | + |
| 76 | + |
| 77 | +def index( |
| 78 | + ctx: ConversionContext, |
| 79 | + target: Target, |
| 80 | + source_ir: Optional[SourceIR], |
| 81 | + name: str, |
| 82 | + input: TRTTensor, |
| 83 | + index: Union[TRTTensor, Sequence[TRTTensor]], |
| 84 | +) -> TRTTensor: |
| 85 | + adv_indx_indices = [] |
| 86 | + tensor_indices = [] |
| 87 | + # _LOGGER.debug(f"The index shape is {index.shape}") |
| 88 | + # check if the input is dynamic |
| 89 | + dynamic_shape = has_dynamic_shape(input.shape) |
| 90 | + |
| 91 | + # here we need to check if all the index are broadcastable |
| 92 | + # if no, then we need to broadcast |
| 93 | + last_index = None |
| 94 | + for i, ind in enumerate(index): |
| 95 | + if ind is not None: |
| 96 | + _LOGGER.debug(f"Shape of {i} index is {ind.shape}") |
| 97 | + adv_indx_indices.append(i) |
| 98 | + # torch.nn.parameter.Parameter=> torch.Tensor |
| 99 | + ind = get_trt_tensor(ctx, ind, name + f"_parameter_to_fp32_tensor_{i}") |
| 100 | + if last_index is not None: |
| 101 | + assert broadcastable( |
| 102 | + ind, last_index |
| 103 | + ), "The indices should be broadcastable!" |
| 104 | + last_index = ind |
| 105 | + tensor_indices.append(ind) |
| 106 | + |
| 107 | + if not tensor_indices: |
| 108 | + identity_layer = ctx.net.add_identity(input) |
| 109 | + identity_layer.set_output_type(0, trt.int32) |
| 110 | + set_layer_name(identity_layer, target, name + "_index_identity", source_ir) |
| 111 | + return identity_layer.get_output(0) |
| 112 | + elif len(tensor_indices) == 1: |
| 113 | + # This case works |
| 114 | + indices_tensor = tensor_indices[0] |
| 115 | + index = adv_indx_indices[0] |
| 116 | + _LOGGER.debug(f"The advanced index indices is {adv_indx_indices}") |
| 117 | + gather_layer = ctx.net.add_gather(input, indices_tensor, index) |
| 118 | + set_layer_name(gather_layer, target, name + "_index_gather", source_ir) |
| 119 | + return gather_layer.get_output(0) |
| 120 | + else: |
| 121 | + input_shape = input.shape |
| 122 | + _LOGGER.debug(f"The input shape is {input.shape}") |
| 123 | + if dynamic_shape: |
| 124 | + input_shape = get_shape_with_dynamic_shape( |
| 125 | + ctx.net, target, source_ir, name, input_shape, input |
| 126 | + ) |
| 127 | + rank = len(input_shape) |
| 128 | + adv_indx_count = len(adv_indx_indices) |
| 129 | + dim_tensor_list = [] |
| 130 | + |
| 131 | + for i in range(rank): |
| 132 | + dim = input_shape[i] |
| 133 | + dim_tensor = get_trt_tensor(ctx, dim, name + f"_individual_dim_{i}") |
| 134 | + # dim_tensor_list is a list of tensors |
| 135 | + dim_tensor_list.append(dim_tensor) |
| 136 | + |
| 137 | + # for cases like |
| 138 | + # t: [x_1, y_1, y_2, ..., x_m, ..., y_n] -> t: [x_1, x_2, ..., x_m, y_1, y_2, ..., y_n], |
| 139 | + # where t is a tensor of rank m+n, {x_i} are axes where tensor index is provided, and {y_i} are axes |
| 140 | + # for ":" |
| 141 | + # Examples: x.shape = (10,20,30,40,50) |
| 142 | + # ind_1, ind_2 broadcasted to (2,3,4) |
| 143 | + # x[:, ind_1, ind_2] = 10, 2, 3, 4, 40, 50 |
| 144 | + # x[:,ind_1, :, ind_2] = 2, 3, 4, 10, 30, 50 |
| 145 | + transpose_layer = ctx.net.add_shuffle(input) |
| 146 | + new_order = [] |
| 147 | + for i in range(adv_indx_count): |
| 148 | + new_order.append(adv_indx_indices[i]) |
| 149 | + for i in range(rank): |
| 150 | + if i not in adv_indx_indices: |
| 151 | + new_order.append(i) |
| 152 | + _LOGGER.debug(f"The new transpose order is {new_order}") |
| 153 | + transpose_layer.second_transpose = tuple(new_order) |
| 154 | + set_layer_name(transpose_layer, target, name + "_index_transpose", source_ir) |
| 155 | + transpose_tensor = transpose_layer.get_output(0) |
| 156 | + |
| 157 | + # Flatten [x_1, x_2,.......x_m, y_1, y_2,.....y_n] |
| 158 | + # transpose_tensor_shape = ctx.net.add_shape(transpose_tensor) |
| 159 | + transpose_tensor_shape = transpose_tensor.shape |
| 160 | + _LOGGER.debug(f"The shape of transpose tensor is {transpose_tensor_shape}") |
| 161 | + mult_d0 = 1 |
| 162 | + for i in range(adv_indx_count): |
| 163 | + mult_d0 = mult_d0 * transpose_tensor_shape[i] |
| 164 | + mult_d1 = 1 |
| 165 | + for i in range(adv_indx_count, rank): |
| 166 | + mult_d1 = mult_d1 * transpose_tensor_shape[i] |
| 167 | + |
| 168 | + concat_tensor_layer = ctx.net.add_concatenation( |
| 169 | + [ |
| 170 | + get_trt_tensor(ctx, mult_d0, name + "_d0_shape"), |
| 171 | + get_trt_tensor(ctx, mult_d1, name + "_d1_shape"), |
| 172 | + ] |
| 173 | + ) |
| 174 | + set_layer_name(concat_tensor_layer, target, name + "_index_Concat", source_ir) |
| 175 | + concat_tensor = concat_tensor_layer.get_output(0) |
| 176 | + |
| 177 | + reshape_layer = ctx.net.add_shuffle(transpose_tensor) |
| 178 | + # check this |
| 179 | + reshape_layer.set_input(1, concat_tensor) |
| 180 | + flatten_tensor = reshape_layer.get_output(0) |
| 181 | + _LOGGER.debug(f"The flatten tensor shape is {flatten_tensor.shape}") |
| 182 | + |
| 183 | + # tensor index = \sum_{i=1}^m (ind_i * \prod_{j=i+1}^m (x_j)), ind_i is input indices[i], x_j is the |
| 184 | + # // j dimension of input x. |
| 185 | + multiplier = get_trt_tensor( |
| 186 | + ctx, |
| 187 | + dim_tensor_list[adv_indx_indices[adv_indx_count - 1]], |
| 188 | + name + "_dim_last", |
| 189 | + ) |
| 190 | + cum_adv_index = tensor_indices[adv_indx_count - 1] |
| 191 | + for i in range(adv_indx_count - 2, -1, -1): |
| 192 | + adv_index = convert_binary_elementwise( |
| 193 | + ctx, |
| 194 | + target, |
| 195 | + source_ir, |
| 196 | + name + f"_index_intermediate_{i}", |
| 197 | + trt.ElementWiseOperation.PROD, |
| 198 | + multiplier, |
| 199 | + tensor_indices[i], |
| 200 | + ) |
| 201 | + cum_adv_index = convert_binary_elementwise( |
| 202 | + ctx, |
| 203 | + target, |
| 204 | + source_ir, |
| 205 | + name + f"_index_sum_intermediate_{i}", |
| 206 | + trt.ElementWiseOperation.SUM, |
| 207 | + cum_adv_index, |
| 208 | + adv_index, |
| 209 | + ) |
| 210 | + multiplier = convert_binary_elementwise( |
| 211 | + ctx, |
| 212 | + target, |
| 213 | + source_ir, |
| 214 | + name + f"_index_intermediate_xj_{i}", |
| 215 | + trt.ElementWiseOperation.PROD, |
| 216 | + multiplier, |
| 217 | + dim_tensor_list[adv_indx_indices[i]], |
| 218 | + ) |
| 219 | + |
| 220 | + gather_layer_element = ctx.net.add_gather(flatten_tensor, cum_adv_index, 0) |
| 221 | + set_layer_name( |
| 222 | + gather_layer_element, target, name + "_index_gather_element", source_ir |
| 223 | + ) |
| 224 | + gather_out = gather_layer_element.get_output(0) |
| 225 | + _LOGGER.debug(f"The shape after cumultative gather is {gather_out.shape}") |
| 226 | + _LOGGER.debug(f"The shape for cumulative adv index is {cum_adv_index}") |
| 227 | + |
| 228 | + cum_adv_index_shape_layer = ctx.net.add_shape(cum_adv_index) |
| 229 | + set_layer_name( |
| 230 | + cum_adv_index_shape_layer, target, name + "_cum_adv_index_shape", source_ir |
| 231 | + ) |
| 232 | + cum_adv_index_shape_tensor = cum_adv_index_shape_layer.get_output(0) |
| 233 | + cum_adv_index_shape = cum_adv_index.shape |
| 234 | + _LOGGER.debug(f"The shape for cumulative adv index is {cum_adv_index_shape}") |
| 235 | + # check if all advanced indices are consecutive |
| 236 | + concat_tensor_reshape = [] |
| 237 | + if ( |
| 238 | + adv_indx_count |
| 239 | + == adv_indx_indices[adv_indx_count - 1] - adv_indx_indices[0] + 1 |
| 240 | + ): |
| 241 | + _LOGGER.debug(f"The indices are continuous in this case") |
| 242 | + concat_tensor_reshape.append( |
| 243 | + get_trt_tensor(ctx, -1, name + "_dynamic_concat") |
| 244 | + ) |
| 245 | + for i in range(0, rank): |
| 246 | + if i not in adv_indx_indices: |
| 247 | + curr_dim = dim_tensor_list[i] |
| 248 | + concat_tensor_reshape.append(curr_dim) |
| 249 | + |
| 250 | + concat_tensor_layer = ctx.net.add_concatenation(concat_tensor_reshape) |
| 251 | + set_layer_name( |
| 252 | + concat_tensor_layer, target, name + "_index_Concat_reshape", source_ir |
| 253 | + ) |
| 254 | + concat_tensor = concat_tensor_layer.get_output(0) |
| 255 | + |
| 256 | + regular_index_shuffle_layer = ctx.net.add_shuffle(gather_out) |
| 257 | + regular_index_shuffle_layer.set_input(1, concat_tensor) |
| 258 | + set_layer_name( |
| 259 | + regular_index_shuffle_layer, |
| 260 | + target, |
| 261 | + name + "_index_regular_index", |
| 262 | + source_ir, |
| 263 | + ) |
| 264 | + unfold_tensor = regular_index_shuffle_layer.get_output(0) |
| 265 | + _LOGGER.debug(f"The tensor is unfolded now") |
| 266 | + _LOGGER.debug(f"The unfolded tensor shape is {unfold_tensor.shape}") |
| 267 | + |
| 268 | + # Transpose folded advanced indexed axis to its original location. |
| 269 | + transpose_advanced_shuffle_layer = ctx.net.add_shuffle(unfold_tensor) |
| 270 | + new_order = [] |
| 271 | + for i in range(1, adv_indx_indices[0] + 1): |
| 272 | + new_order.append(i) |
| 273 | + new_order.append(0) |
| 274 | + for i in range(adv_indx_indices[0] + 1, rank - adv_indx_count + 1): |
| 275 | + new_order.append(i) |
| 276 | + _LOGGER.debug(f"Transposing the indices to correct position {new_order}") |
| 277 | + |
| 278 | + transpose_advanced_shuffle_layer.second_transpose = tuple(new_order) |
| 279 | + set_layer_name( |
| 280 | + transpose_advanced_shuffle_layer, |
| 281 | + target, |
| 282 | + name + "_index_advanced_shuffle_transpose", |
| 283 | + source_ir, |
| 284 | + ) |
| 285 | + transpose_tensor = transpose_advanced_shuffle_layer.get_output(0) |
| 286 | + |
| 287 | + # unfold advanced layer |
| 288 | + concat_final_tensor = [] |
| 289 | + for i in range(0, adv_indx_indices[0]): |
| 290 | + current_dim = dim_tensor_list[i] |
| 291 | + concat_final_tensor.append(current_dim) |
| 292 | + |
| 293 | + concat_final_tensor.append(cum_adv_index_shape_tensor) |
| 294 | + for i in range(adv_indx_indices[0], rank): |
| 295 | + if i not in (adv_indx_indices): |
| 296 | + current_dim = dim_tensor_list[i] |
| 297 | + concat_final_tensor.append(current_dim) |
| 298 | + |
| 299 | + concat_final_shape_layer = ctx.net.add_concatenation(concat_final_tensor) |
| 300 | + set_layer_name( |
| 301 | + concat_final_shape_layer, |
| 302 | + target, |
| 303 | + name + "_index_continuous_concat_final_shape_layer", |
| 304 | + source_ir, |
| 305 | + ) |
| 306 | + concat_final_tensor = concat_final_shape_layer.get_output(0) |
| 307 | + |
| 308 | + unfold_advanced_shuffle_layer = ctx.net.add_shuffle(transpose_tensor) |
| 309 | + # check this |
| 310 | + unfold_advanced_shuffle_layer.set_input(1, concat_final_tensor) |
| 311 | + set_layer_name( |
| 312 | + unfold_advanced_shuffle_layer, |
| 313 | + target, |
| 314 | + name + "_unfold_advanced_index", |
| 315 | + source_ir, |
| 316 | + ) |
| 317 | + reshape_output = unfold_advanced_shuffle_layer.get_output(0) |
| 318 | + |
| 319 | + else: |
| 320 | + _LOGGER.debug(f"The indices are not continuous in this case") |
| 321 | + concat_final_tensor = [] |
| 322 | + concat_final_tensor.append(cum_adv_index_shape_tensor) |
| 323 | + for i in range(0, rank): |
| 324 | + if i not in adv_indx_indices: |
| 325 | + curr_dim = dim_tensor_list[i] |
| 326 | + concat_final_tensor.append(curr_dim) |
| 327 | + |
| 328 | + concat_final_shape_layer = ctx.net.add_concatenation(concat_final_tensor) |
| 329 | + set_layer_name( |
| 330 | + concat_final_shape_layer, |
| 331 | + target, |
| 332 | + name + "_index_non_continuous_concat_final_shape_layer", |
| 333 | + source_ir, |
| 334 | + ) |
| 335 | + concat_final_tensor = concat_final_shape_layer.get_output(0) |
| 336 | + |
| 337 | + reshape_layer = ctx.net.add_shuffle(gather_out) |
| 338 | + reshape_layer.set_input(1, concat_final_tensor) |
| 339 | + set_layer_name( |
| 340 | + reshape_layer, |
| 341 | + target, |
| 342 | + name + "_index_non_continuous_shuffle_final_shape_layer", |
| 343 | + source_ir, |
| 344 | + ) |
| 345 | + reshape_output = reshape_layer.get_output(0) |
| 346 | + |
| 347 | + return reshape_output |
0 commit comments