
Commit 56b8950

fix: Allow rank differences in aten.expand (#2234)
1 parent 1133432 commit 56b8950

4 files changed (+60, -28 lines)
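
For context, a minimal eager-mode sketch (not part of the commit) of the rank-difference behavior this change supports: PyTorch's Tensor.expand accepts a target shape of higher rank than the input, prepending the new leading dimensions.

import torch

# A rank-3 tensor expanded to a rank-4 shape; -1 keeps the original size.
x = torch.randn(1, 5, 7)
y = x.expand(2, 3, -1, -1)
print(y.shape)  # torch.Size([2, 3, 5, 7])

The diffs below make the dynamo TensorRT converter handle this case by padding the input with leading size-1 dimensions (prepend_ones) before performing the slice-based expand.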

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 18 additions & 0 deletions
@@ -420,3 +420,21 @@ def aten_ops_clone(
         name,
         args[0],
     )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.expand.default)
+def aten_ops_expand(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.slice.expand(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args[1],
+    )
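
A hedged usage sketch (not from the commit): a module whose traced ATen graph contains torch.ops.aten.expand.default, which the registration above routes to impl.slice.expand. The torch_tensorrt.compile call with ir="dynamo" reflects the dynamo frontend entry point as an assumption about the surrounding API, not something this diff shows.

import torch
import torch_tensorrt  # assumed entry point for the dynamo frontend

class Tile(torch.nn.Module):
    def forward(self, x):
        # Traces to torch.ops.aten.expand.default in the ATen graph
        return x.expand(2, 3, -1, -1)

model = Tile().eval().cuda()
inputs = [torch.randn(1, 5, 7).cuda()]
# Assumption: the dynamo path dispatches aten.expand.default to aten_ops_expand
trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)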

py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py

Lines changed: 38 additions & 25 deletions
@@ -5,10 +5,10 @@
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion.impl.slice.base import slice
 from torch_tensorrt.fx.converters.converter_utils import (
-    broadcast,
     get_positive_dim,
-    get_trt_tensor,
     has_dynamic_shape,
+    prepend_ones,
+    set_layer_name,
 )
 from torch_tensorrt.fx.types import Shape, TRTNetwork, TRTTensor

@@ -65,33 +65,46 @@ def expand(
     target: Target,
     source_ir: Optional[SourceIR],
     name: str,
-    input: TRTTensor,
-    sizes: Shape,
+    input_t: TRTTensor,
+    shape: Shape,
 ) -> TRTTensor:
-    shape = list(sizes)
-
-    input_val = get_trt_tensor(network, input, f"{name}_input")
+    if not isinstance(input_t, TRTTensor):
+        raise RuntimeError(
+            f"expand received input {input_t} that is not a TensorRT ITensor"
+        )
 
-    if network.has_implicit_batch_dimension:
-        shape = shape[1:]
+    shape_rank = len(shape)
+    initial_tensor_rank = len(input_t.shape)
 
-    ranks = len(input_val.shape)
-    # TRT does not support different dimension size
-    # though this condition is not seen in the case of bmm
-    # where input_t and shape dimensions are not equal
-    assert len(shape) >= ranks
-    if len(shape) != ranks:
-        shape_tuple = tuple([0] * len(shape))
-        shape_tensor = get_trt_tensor(network, input, f"{name}_shape")
-        input_val, shape_tensor = broadcast(
-            network, input_val, shape_tensor, f"{name}_input_val", f"{name}_shape_val"
+    # If the rank of the input tensor is less than the shape's rank, pad with ones
+    if initial_tensor_rank < shape_rank:
+        input_t = prepend_ones(
+            network,
+            input_t,
+            name + "_expand_broadcast",
+            shape_rank - initial_tensor_rank,
         )
-        ranks = len(shape)
+    # If the rank of the input tensor is more than the shape's rank, raise error
+    elif initial_tensor_rank > shape_rank:
+        raise RuntimeError(
+            f"expand called with {shape_rank}-dimensional shape on Tensor with {len(shape)} dimensions. "
+            "Cannot expand to shape with rank smaller than original tensor."
+        )
+
+    # After the above padding, the shape and tensor rank must be equal
+    assert len(input_t.shape) == shape_rank
+
+    # -1 denotes taking the shape from the original input tensor
+    shape = tuple(
+        [input_t.shape[i] if shape[i] == -1 else shape[i] for i in range(shape_rank)]
+    )
 
-    inshape = tuple(input_val.shape)
-    shape_t = tuple(shape)
-    start = tuple([0] * ranks)
+    # Establish the desired output shape, strides, and starting indices
+    input_tensor_shape = tuple(input_t.shape)
+    start = tuple([0] * shape_rank)
     stride = tuple(
-        [int(i == o) for i, o in zip(inshape, shape)]
+        [int(i == o) for i, o in zip(input_tensor_shape, shape)]
     )  # stride == 1 if dimensions match, 0 otherwise
-    return slice(network, target, source_ir, name, input_val, start, shape_t, stride)
+    layer = network.add_slice(input_t, start=start, shape=shape, stride=stride)
+    set_layer_name(layer, target, name, source_ir)
+    return layer.get_output(0)
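
To make the new control flow concrete, here is a plain-Python sketch (illustrative names, not the commit's code) of how the converter derives the slice-layer parameters: pad the input shape with leading 1s as prepend_ones does on the ITensor, resolve -1 entries from the padded input shape, and use stride 0 to broadcast the padded dimensions.

def expand_slice_params(input_shape, target_shape):
    # Plain tuples stand in for the ITensor shapes the real code works on.
    in_rank, out_rank = len(input_shape), len(target_shape)
    assert in_rank <= out_rank, "cannot expand to a lower-rank shape"

    # Pad with leading 1s (the effect of prepend_ones on the input tensor)
    padded = (1,) * (out_rank - in_rank) + tuple(input_shape)

    # -1 in the target shape keeps the (padded) input dimension
    resolved = tuple(p if t == -1 else t for p, t in zip(padded, target_shape))

    start = (0,) * out_rank
    # stride 1 where dimensions already match, stride 0 to broadcast size-1 dims
    stride = tuple(int(p == r) for p, r in zip(padded, resolved))
    return padded, resolved, start, stride

print(expand_slice_params((1, 5, 7), (2, 3, -1, -1)))
# ((1, 1, 5, 7), (2, 3, 5, 7), (0, 0, 0, 0), (0, 0, 1, 1))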

py/torch_tensorrt/dynamo/conversion/impl/unsqueeze.py

Lines changed: 2 additions & 2 deletions
@@ -27,7 +27,7 @@ def unsqueeze(
         )
 
     dim = cast(int, dim)
-    input_shape = input_val.shape
+
     input_shape_size = (
         len(input_val.shape) + 1
         if network.has_implicit_batch_dimension
@@ -46,5 +46,5 @@ def unsqueeze(
     layer.reshape_dims = (
         tuple(input_val.shape)[:dim] + (1,) + tuple(input_val.shape)[dim:]
     )
-    set_layer_name(layer, target, name)
+    set_layer_name(layer, target, name, source_ir)
     return layer.get_output(0)

tests/py/dynamo/converters/test_expand_aten.py

Lines changed: 2 additions & 1 deletion
@@ -1,8 +1,8 @@
 import torch
 import torch.nn as nn
+from harness import DispatchTestCase
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
-from harness import DispatchTestCase
 
 
 class TestExpandConverter(DispatchTestCase):
@@ -12,6 +12,7 @@ class TestExpandConverter(DispatchTestCase):
             ("3d_dim", (2, 3, 4), (2, 1, 1)),
             ("4d_dim", (2, 3, 4, 5), (2, 1, 1, 1)),
             ("keep_dim", (2, 3, -1, -1), (2, 1, 5, 5)),
+            ("different_ranks", (2, 3, -1, -1), (1, 5, 7)),
         ]
     )
     def test_expand(self, _, sizes, init_size):
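
The new different_ranks case expands a rank-3 input to a rank-4 target. A hedged sketch of the eager behavior the test compares against (the DispatchTestCase harness wiring is not shown in this diff; the module below is illustrative):

import torch
import torch.nn as nn

class Expand(nn.Module):
    def __init__(self, sizes):
        super().__init__()
        self.sizes = sizes

    def forward(self, x):
        return x.expand(*self.sizes)

# ("different_ranks", (2, 3, -1, -1), (1, 5, 7)):
# a (1, 5, 7) input expanded to shape (2, 3, 5, 7)
out = Expand((2, 3, -1, -1))(torch.randn(1, 5, 7))
assert out.shape == (2, 3, 5, 7)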
