Commit e3363df
chore: remove aten.full decomposition (#2954)
1 parent d245716 commit e3363df

File tree

3 files changed: +60 −19 lines changed


py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py

Lines changed: 24 additions & 10 deletions
@@ -14,7 +14,9 @@
     get_trt_tensor,
 )
 from torch_tensorrt.dynamo.conversion.impl.cat import cat
+from torch_tensorrt.dynamo.conversion.impl.shape import shape as get_shape
 from torch_tensorrt.dynamo.conversion.impl.slice.base import slice
+from torch_tensorrt.dynamo.utils import DYNAMIC_DIM
 from torch_tensorrt.fx.converters.converter_utils import (
     has_dynamic_shape,
     prepend_ones,
@@ -90,27 +92,39 @@ def expand(
     # After the above padding, the shape and tensor rank must be equal
     assert len(input_t.shape) == shape_rank
 
-    # -1 denotes taking the shape from the original input tensor
-    shape = tuple(
-        [input_t.shape[i] if shape[i] == -1 else shape[i] for i in range(shape_rank)]
-    )
+    shape_t = []
+    for i in range(shape_rank):
+        if shape[i] == -1:
+            shape_t.append(
+                get_shape(ctx, target, source_ir, name + f"_shape_dim{i}", input_t, i)
+            )
+        else:
+            shape_t.append(shape[i])
 
     # Establish the desired output shape, strides, and starting indices
     input_tensor_shape = tuple(input_t.shape)
     start = tuple([0] * shape_rank)
-    stride = tuple(
-        [int(i == o) for i, o in zip(input_tensor_shape, shape)]
-    )  # stride == 1 if dimensions match, 0 otherwise
 
-    shape_ = shape
+    # TODO: Revisit stride calculation. stride[dim]=0 implies that dimension is being broadcasted.
+    # stride should be 1 for all non-broadcasted dims
+    stride = []
+    for i, o in zip(input_tensor_shape, shape_t):
+        # If the shape has ITensor, we treat it as a reshape dim instead of a broadcasted dim
+        # shape_t cannot have -1. If the input at this dimension has a shape of -1, set the stride to 1. This indicates that the input is dynamic and does not imply broadcasting at that specific dimension.
+        if isinstance(i, int) and isinstance(o, int) and i != DYNAMIC_DIM:
+            stride.append(int(i == o))
+        else:
+            stride.append(1)
+
+    shape_ = shape_t
     # Handle dynamic shapes case where shape has dynamic dimension
-    if any(isinstance(ele, TRTTensor) for ele in shape):
+    if any(isinstance(ele, TRTTensor) for ele in shape_t):
         shape_ = cat(
             ctx,
             target,
             source_ir,
             name + "_shape_concat",
-            shape,
+            shape_t,
             0,
             cast_dtype=trt.int32,
         )
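The stride rule above is the heart of this change: a stride of 0 tells TensorRT's slice layer to reread the same element along that dimension (a broadcast), while dynamic input dims and ITensor-valued target dims conservatively get stride 1. A minimal standalone sketch of that rule, not the converter itself, assuming DYNAMIC_DIM == -1 as in torch_tensorrt.dynamo.utils and using a placeholder class in place of a TensorRT ITensor:

# Standalone sketch of the new stride rule (assumption: DYNAMIC_DIM == -1).
DYNAMIC_DIM = -1

class FakeITensor:  # stand-in for a TensorRT ITensor carrying a runtime dim
    pass

def expand_strides(input_shape, target_shape):
    strides = []
    for i, o in zip(input_shape, target_shape):
        if isinstance(i, int) and isinstance(o, int) and i != DYNAMIC_DIM:
            strides.append(int(i == o))  # 0 => broadcast this static dim
        else:
            strides.append(1)  # dynamic dim or ITensor target: no broadcast
    return strides

print(expand_strides((2, 1, 5, 5), (2, 3, 5, 5)))               # [1, 0, 1, 1]
print(expand_strides((2, -1, 5, 5), (2, FakeITensor(), 5, 5)))  # [1, 1, 1, 1]

As the TODO in the diff notes, this is conservative: stride 0 marks a broadcast, and anything that cannot be proven static defaults to stride 1 rather than being treated as broadcast.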

py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py

Lines changed: 0 additions & 1 deletion
@@ -166,7 +166,6 @@
     aten.clamp_min,
     aten.clamp_max,
     aten.linalg_vector_norm,
-    aten.full,
     aten.repeat,
 }
 torch_disabled_decompositions: Set[Union[OpOverload, OpOverloadPacket]] = {
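Dropping aten.full from this decomposition set means graphs handed to torch_tensorrt keep the aten.full node intact for the converter to handle, rather than having PyTorch lower it away first. A quick way to observe the node in an exported graph (an illustrative sketch using torch.export, not code from this commit; whether the node survives later lowering depends on the decomposition table applied):

import torch

class FullModule(torch.nn.Module):
    def forward(self, x):
        # torch.full traces to an aten.full.default node in the ATen graph
        return x + torch.full((2, 3), 3.0)

ep = torch.export.export(FullModule(), (torch.randn(2, 3),))
print(ep.graph)  # contains torch.ops.aten.full.default unless later decomposed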

tests/py/dynamo/conversion/test_expand_aten.py

Lines changed: 36 additions & 8 deletions
@@ -2,31 +2,59 @@
 import torch.nn as nn
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt import Input
 
 from .harness import DispatchTestCase
 
 
 class TestExpandConverter(DispatchTestCase):
     @parameterized.expand(
         [
-            ("2d_dim", (2, 3), (2, 1)),
-            ("3d_dim", (2, 3, 4), (2, 1, 1)),
-            ("4d_dim", (2, 3, 4, 5), (2, 1, 1, 1)),
-            ("keep_dim", (2, 3, -1, -1), (2, 1, 5, 5)),
-            ("different_ranks", (2, 3, -1, -1), (1, 5, 7)),
+            ("2d_dim", (2, 1), (2, 3)),
+            ("3d_dim", (2, 1, 1), (2, 3, 4)),
+            ("4d_dim", (2, 1, 1, 1), (2, 3, 4, 5)),
+            ("keep_dim", (2, 1, 5, 5), (2, 3, -1, -1)),
+            ("different_ranks", (1, 5, 7), (2, 3, -1, -1)),
         ]
     )
-    def test_expand(self, _, sizes, init_size):
+    def test_expand(self, _, input_shape, expanded_shape):
         class Expand(nn.Module):
             def forward(self, x):
-                return torch.ops.aten.expand.default(x, sizes)
+                return torch.ops.aten.expand.default(x, expanded_shape)
 
-        inputs = [torch.randn(*init_size)]
+        inputs = [torch.randn(*input_shape)]
         self.run_test(
             Expand(),
             inputs,
         )
 
+    @parameterized.expand(
+        [
+            ("2d_dim", (2, 1), (4, 1), (6, 1), (-1, 3)),
+            ("3d_dim", (2, 1, 1), (4, 1, 1), (6, 1, 1), (-1, 3, 4)),
+            ("4d_dim", (1, 1, 1, 1), (3, 1, 1, 1), (5, 1, 1, 1), (-1, 2, 3, 6)),
+            ("keep_dim", (2, 1, 5, 5), (4, 1, 5, 5), (6, 1, 5, 5), (-1, 3, -1, -1)),
+            ("different_ranks", (1, 2, 1), (1, 2, 1), (2, 2, 1), (2, -1, -1, -1)),
+        ]
+    )
+    def test_expand_dynamic(self, _, min_shape, opt_shape, max_shape, expanded_shape):
+        class ExpandDynamic(nn.Module):
+            def forward(self, x):
+                return torch.ops.aten.expand.default(x, expanded_shape)
+
+        input_specs = [
+            Input(
+                dtype=torch.float32,
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            ExpandDynamic(),
+            input_specs,
+        )
+
 
 if __name__ == "__main__":
     run_tests()
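The new test_expand_dynamic cases drive the converter through run_test_with_dynamic_shape using min/opt/max input profiles. Outside the test harness, the same scenario can be sketched against the public compile API (hypothetical module and shapes, assuming a working torch_tensorrt install with CUDA available):

import torch
import torch_tensorrt

class ExpandDynamic(torch.nn.Module):
    def forward(self, x):
        return torch.ops.aten.expand.default(x, (-1, 3, -1, -1))

# Dynamic first dimension, bounded by the min/opt/max profile
spec = torch_tensorrt.Input(
    min_shape=(2, 1, 5, 5),
    opt_shape=(4, 1, 5, 5),
    max_shape=(6, 1, 5, 5),
    dtype=torch.float32,
)
trt_mod = torch_tensorrt.compile(
    ExpandDynamic().eval().cuda(), ir="dynamo", inputs=[spec]
)
out = trt_mod(torch.randn(4, 1, 5, 5).cuda())
print(out.shape)  # expected: torch.Size([4, 3, 5, 5])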
