Skip to content

Commit 5910cf2

Browse files
committed
fix: Allow rank differences in aten.expand
- Add support for `aten.expand.default` in Dynamo converter registry - Build converter to support rank-padding for input Tensors, in line with the existing Torch behavior - Add test case to validate new behavior, in addition to existing cases validating old behavior
1 parent b57d83e commit 5910cf2

File tree

3 files changed

+74
-5
lines changed

3 files changed

+74
-5
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import logging
22
from typing import Any, Dict, Optional, Sequence, Tuple, Union
33

4+
import tensorrt as trt
45
import torch
56
from torch.fx.node import Argument, Node, Target
67
from torch_tensorrt.dynamo._SourceIR import SourceIR
@@ -12,8 +13,6 @@
1213
from torch_tensorrt.fx.converters import acc_ops_converters
1314
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
1415

15-
import tensorrt as trt
16-
1716
from .converter_registry import dynamo_tensorrt_converter
1817

1918
_LOGGER: logging.Logger = logging.getLogger(__name__)
@@ -380,3 +379,21 @@ def aten_ops_permute(
380379
args[0],
381380
args[1],
382381
)
382+
383+
384+
@dynamo_tensorrt_converter(torch.ops.aten.expand.default)
385+
def aten_ops_expand(
386+
network: TRTNetwork,
387+
target: Target,
388+
args: Tuple[Argument, ...],
389+
kwargs: Dict[str, Argument],
390+
name: str,
391+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
392+
return impl.unsqueeze.expand(
393+
network,
394+
target,
395+
SourceIR.ATEN,
396+
name,
397+
args[0],
398+
args[1],
399+
)

py/torch_tensorrt/dynamo/conversion/impl/unsqueeze.py

Lines changed: 53 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from torch_tensorrt.fx.converters.converter_utils import (
66
get_positive_dim,
77
get_trt_tensor,
8+
prepend_ones,
89
set_layer_name,
910
)
1011
from torch_tensorrt.fx.types import Shape, TRTNetwork, TRTTensor
@@ -27,7 +28,7 @@ def unsqueeze(
2728
)
2829

2930
dim = cast(int, dim)
30-
input_shape = input_val.shape
31+
3132
input_shape_size = (
3233
len(input_val.shape) + 1
3334
if network.has_implicit_batch_dimension
@@ -46,5 +47,55 @@ def unsqueeze(
4647
layer.reshape_dims = (
4748
tuple(input_val.shape)[:dim] + (1,) + tuple(input_val.shape)[dim:]
4849
)
49-
set_layer_name(layer, target, name)
50+
set_layer_name(layer, target, name, source_ir)
51+
return layer.get_output(0)
52+
53+
54+
def expand(
55+
network: TRTNetwork,
56+
target: Target,
57+
source_ir: Optional[SourceIR],
58+
name: str,
59+
input_t: TRTTensor,
60+
shape: Shape,
61+
) -> TRTTensor:
62+
if not isinstance(input_t, TRTTensor):
63+
raise RuntimeError(
64+
f"expand received input {input_t} that is not a TensorRT ITensor"
65+
)
66+
67+
shape_rank = len(shape)
68+
initial_tensor_rank = len(input_t.shape)
69+
70+
# If the rank of the input tensor is less than the shape's rank, pad with ones
71+
if initial_tensor_rank < shape_rank:
72+
input_t = prepend_ones(
73+
network,
74+
input_t,
75+
name + "_expand_ones_padding",
76+
shape_rank - initial_tensor_rank,
77+
)
78+
# If the rank of the input tensor is more than the shape's rank, raise error
79+
elif initial_tensor_rank > shape_rank:
80+
raise RuntimeError(
81+
f"expand called with {shape_rank}-dimensional shape on Tensor with {len(shape)} dimensions. "
82+
"Cannot expand to shape with rank smaller than original tensor."
83+
)
84+
85+
# After the above padding, the shape and tensor rank must be equal
86+
assert len(input_t.shape) == shape_rank
87+
88+
# -1 denotes taking the shape from the original input tensor
89+
shape = tuple(
90+
[input_t.shape[i] if shape[i] == -1 else shape[i] for i in range(shape_rank)]
91+
)
92+
93+
# Establish the desired output shape, strides, and starting indices
94+
input_tensor_shape = tuple(input_t.shape)
95+
start = tuple([0] * shape_rank)
96+
stride = tuple(
97+
[int(i == o) for i, o in zip(input_tensor_shape, shape)]
98+
) # stride == 1 if dimensions match, 0 otherwise
99+
layer = network.add_slice(input_t, start=start, shape=shape, stride=stride)
100+
set_layer_name(layer, target, name, source_ir)
50101
return layer.get_output(0)

tests/py/dynamo/converters/test_expand_aten.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import torch
22
import torch.nn as nn
3+
from harness import DispatchTestCase
34
from parameterized import parameterized
45
from torch.testing._internal.common_utils import run_tests
5-
from harness import DispatchTestCase
66

77

88
class TestExpandConverter(DispatchTestCase):
@@ -12,6 +12,7 @@ class TestExpandConverter(DispatchTestCase):
1212
("3d_dim", (2, 3, 4), (2, 1, 1)),
1313
("4d_dim", (2, 3, 4, 5), (2, 1, 1, 1)),
1414
("keep_dim", (2, 3, -1, -1), (2, 1, 5, 5)),
15+
("different_ranks", (2, 3, -1, -1), (1, 5, 7)),
1516
]
1617
)
1718
def test_expand(self, _, sizes, init_size):

0 commit comments

Comments
 (0)