@@ -5,10 +5,10 @@
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion.impl.slice.base import slice
 from torch_tensorrt.fx.converters.converter_utils import (
-    broadcast,
     get_positive_dim,
-    get_trt_tensor,
     has_dynamic_shape,
+    prepend_ones,
+    set_layer_name,
 )
 from torch_tensorrt.fx.types import Shape, TRTNetwork, TRTTensor
 
@@ -65,33 +65,46 @@ def expand(
     target: Target,
     source_ir: Optional[SourceIR],
     name: str,
-    input: TRTTensor,
-    sizes: Shape,
+    input_t: TRTTensor,
+    shape: Shape,
 ) -> TRTTensor:
-    shape = list(sizes)
-
-    input_val = get_trt_tensor(network, input, f"{name}_input")
+    if not isinstance(input_t, TRTTensor):
+        raise RuntimeError(
+            f"expand received input {input_t} that is not a TensorRT ITensor"
+        )
 
-    if network.has_implicit_batch_dimension:
-        shape = shape[1:]
+    shape_rank = len(shape)
+    initial_tensor_rank = len(input_t.shape)
 
-    ranks = len(input_val.shape)
-    # TRT does not support different dimension size
-    # though this condition is not seen in the case of bmm
-    # where input_t and shape dimensions are not equal
-    assert len(shape) >= ranks
-    if len(shape) != ranks:
-        shape_tuple = tuple([0] * len(shape))
-        shape_tensor = get_trt_tensor(network, input, f"{name}_shape")
-        input_val, shape_tensor = broadcast(
-            network, input_val, shape_tensor, f"{name}_input_val", f"{name}_shape_val"
+    # If the rank of the input tensor is less than the shape's rank, pad with ones
+    if initial_tensor_rank < shape_rank:
+        input_t = prepend_ones(
+            network,
+            input_t,
+            name + "_expand_broadcast",
+            shape_rank - initial_tensor_rank,
         )
-        ranks = len(shape)
+    # If the rank of the input tensor is greater than the shape's rank, raise an error
+    elif initial_tensor_rank > shape_rank:
+        raise RuntimeError(
+            f"expand called with {shape_rank}-dimensional shape on a tensor with {initial_tensor_rank} dimensions. "
+            "Cannot expand to a shape with rank smaller than that of the original tensor."
+        )
+
+    # After the padding above, the shape rank and the tensor rank must be equal
+    assert len(input_t.shape) == shape_rank
+
+    # A size of -1 means keep the size of that dimension from the original input tensor
+    shape = tuple(
+        [input_t.shape[i] if shape[i] == -1 else shape[i] for i in range(shape_rank)]
+    )
 
-    inshape = tuple(input_val.shape)
-    shape_t = tuple(shape)
-    start = tuple([0] * ranks)
+    # Establish the desired output shape, strides, and starting indices
+    input_tensor_shape = tuple(input_t.shape)
+    start = tuple([0] * shape_rank)
     stride = tuple(
-        [int(i == o) for i, o in zip(inshape, shape)]
+        [int(i == o) for i, o in zip(input_tensor_shape, shape)]
     )  # stride == 1 if dimensions match, 0 otherwise
-    return slice(network, target, source_ir, name, input_val, start, shape_t, stride)
+    layer = network.add_slice(input_t, start=start, shape=shape, stride=stride)
+    set_layer_name(layer, target, name, source_ir)
+    return layer.get_output(0)
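Note for reviewers: the 0/1 stride tuple passed to the slice layer is the same trick torch.Tensor.expand uses eagerly, where a broadcast dimension is realized by re-reading the single source element with stride 0. The snippet below is only an illustrative sketch in plain PyTorch (the tensor shape and values are made up for the example and are not part of this change); the converter builds the equivalent stride pattern in TensorRT through network.add_slice.

import torch

# A (1, 3) tensor expanded to (4, 3): the broadcast dimension gets stride 0,
# mirroring the stride tuple the converter hands to the TensorRT slice layer.
x = torch.arange(3.0).reshape(1, 3)
y = x.expand(4, 3)      # -1 would keep the original size, e.g. x.expand(4, -1)

print(y.stride())       # (0, 1): repeat along dim 0, step normally along dim 1
print(y.shape)          # torch.Size([4, 3])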