chore: remove aten.full decomposition #2954

Merged · 3 commits · Jun 27, 2024
34 changes: 24 additions & 10 deletions py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py
@@ -14,7 +9 @@
get_trt_tensor,
)
from torch_tensorrt.dynamo.conversion.impl.cat import cat
from torch_tensorrt.dynamo.conversion.impl.shape import shape as get_shape
from torch_tensorrt.dynamo.conversion.impl.slice.base import slice
from torch_tensorrt.dynamo.utils import DYNAMIC_DIM
from torch_tensorrt.fx.converters.converter_utils import (
has_dynamic_shape,
prepend_ones,
@@ -90,27 +92,39 @@ def expand(
# After the above padding, the shape and tensor rank must be equal
assert len(input_t.shape) == shape_rank

# -1 denotes taking the shape from the original input tensor
shape = tuple(
[input_t.shape[i] if shape[i] == -1 else shape[i] for i in range(shape_rank)]
)
shape_t = []
for i in range(shape_rank):
if shape[i] == -1:
shape_t.append(
get_shape(ctx, target, source_ir, name + f"_shape_dim{i}", input_t, i)
)
else:
shape_t.append(shape[i])

# Establish the desired output shape, strides, and starting indices
input_tensor_shape = tuple(input_t.shape)
start = tuple([0] * shape_rank)
stride = tuple(
[int(i == o) for i, o in zip(input_tensor_shape, shape)]
) # stride == 1 if dimensions match, 0 otherwise

shape_ = shape
# TODO: Revisit stride calculation. stride[dim]=0 implies that dimension is being broadcasted.
# stride should be 1 for all non-broadcasted dims
stride = []
for i, o in zip(input_tensor_shape, shape_t):
# If the target shape entry is an ITensor, treat it as a reshape dim rather than a broadcast dim
# shape_t cannot contain -1; if the input dim is -1 (dynamic), set the stride to 1, since a dynamic input dim does not imply broadcasting at that dimension.
if isinstance(i, int) and isinstance(o, int) and i != DYNAMIC_DIM:
stride.append(int(i == o))
else:
stride.append(1)

shape_ = shape_t
# Handle dynamic shapes case where shape has dynamic dimension
if any(isinstance(ele, TRTTensor) for ele in shape):
if any(isinstance(ele, TRTTensor) for ele in shape_t):
shape_ = cat(
ctx,
target,
source_ir,
name + "_shape_concat",
shape,
shape_t,
0,
cast_dtype=trt.int32,
)
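For intuition, here is a minimal, framework-free sketch of the stride rule introduced in this hunk (assuming DYNAMIC_DIM is the -1 sentinel from torch_tensorrt.dynamo.utils; expand_strides is a hypothetical helper, not part of the library): a dimension is broadcast (stride 0) only when both the input and target extents are static ints, while dynamic or tensor-valued entries fall back to stride 1.

```python
# Sketch only: DYNAMIC_DIM = -1 mirrors torch_tensorrt.dynamo.utils,
# and plain ints stand in for static shape values.
DYNAMIC_DIM = -1

def expand_strides(input_shape, target_shape):
    strides = []
    for i, o in zip(input_shape, target_shape):
        if isinstance(i, int) and isinstance(o, int) and i != DYNAMIC_DIM:
            # Static dims: stride 1 when sizes already match, 0 when broadcasting.
            strides.append(int(i == o))
        else:
            # Dynamic input dim (or runtime shape tensor): do not broadcast here.
            strides.append(1)
    return strides

print(expand_strides((2, 1, 5, 5), (2, 3, 5, 5)))      # [1, 0, 1, 1]
print(expand_strides((2, DYNAMIC_DIM, 5), (2, 3, 5)))  # [1, 1, 1]
```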
1 change: 0 additions & 1 deletion py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py
@@ -166,7 +166,6 @@
aten.clamp_min,
aten.clamp_max,
aten.linalg_vector_norm,
aten.full,
aten.repeat,
}
torch_disabled_decompositions: Set[Union[OpOverload, OpOverloadPacket]] = {
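For context, dropping aten.full from torch_enabled_decompositions means full ops are no longer rewritten into simpler ops during lowering and are instead passed through for torch_tensorrt to convert directly. A rough usage sketch (FillModel is a hypothetical example module; this assumes a working CUDA/TensorRT environment):

```python
import torch
import torch_tensorrt

class FillModel(torch.nn.Module):  # hypothetical example module
    def forward(self, x):
        # torch.full traces to aten.full; with the decomposition disabled,
        # the op is kept intact for torch_tensorrt to handle.
        return x + torch.full((2, 3), 1.5, device=x.device)

model = FillModel().eval().cuda()
inputs = [torch.randn(2, 3).cuda()]
trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
print(trt_model(*inputs))
```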
44 changes: 36 additions & 8 deletions tests/py/dynamo/conversion/test_expand_aten.py
@@ -2,31 +2,59 @@
import torch.nn as nn
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt import Input

from .harness import DispatchTestCase


class TestExpandConverter(DispatchTestCase):
@parameterized.expand(
[
("2d_dim", (2, 3), (2, 1)),
("3d_dim", (2, 3, 4), (2, 1, 1)),
("4d_dim", (2, 3, 4, 5), (2, 1, 1, 1)),
("keep_dim", (2, 3, -1, -1), (2, 1, 5, 5)),
("different_ranks", (2, 3, -1, -1), (1, 5, 7)),
("2d_dim", (2, 1), (2, 3)),
("3d_dim", (2, 1, 1), (2, 3, 4)),
("4d_dim", (2, 1, 1, 1), (2, 3, 4, 5)),
("keep_dim", (2, 1, 5, 5), (2, 3, -1, -1)),
("different_ranks", (1, 5, 7), (2, 3, -1, -1)),
]
)
def test_expand(self, _, sizes, init_size):
def test_expand(self, _, input_shape, expanded_shape):
class Expand(nn.Module):
def forward(self, x):
return torch.ops.aten.expand.default(x, sizes)
return torch.ops.aten.expand.default(x, expanded_shape)

inputs = [torch.randn(*init_size)]
inputs = [torch.randn(*input_shape)]
self.run_test(
Expand(),
inputs,
)

@parameterized.expand(
[
("2d_dim", (2, 1), (4, 1), (6, 1), (-1, 3)),
("3d_dim", (2, 1, 1), (4, 1, 1), (6, 1, 1), (-1, 3, 4)),
("4d_dim", (1, 1, 1, 1), (3, 1, 1, 1), (5, 1, 1, 1), (-1, 2, 3, 6)),
("keep_dim", (2, 1, 5, 5), (4, 1, 5, 5), (6, 1, 5, 5), (-1, 3, -1, -1)),
("different_ranks", (1, 2, 1), (1, 2, 1), (2, 2, 1), (2, -1, -1, -1)),
]
)
def test_expand_dynamic(self, _, min_shape, opt_shape, max_shape, expanded_shape):
class ExpandDynamic(nn.Module):
def forward(self, x):
return torch.ops.aten.expand.default(x, expanded_shape)

input_specs = [
Input(
dtype=torch.float32,
min_shape=min_shape,
opt_shape=opt_shape,
max_shape=max_shape,
),
]
self.run_test_with_dynamic_shape(
ExpandDynamic(),
input_specs,
)


if __name__ == "__main__":
run_tests()
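As a side note, a small eager-mode sketch of the aten.expand semantics the new cases rely on: -1 in the target shape keeps the corresponding input dimension, and lower-rank inputs are broadcast against the higher-rank target shape.

```python
import torch

# Mirrors the "keep_dim" case: (2, 1, 5, 5) expanded to (2, 3, -1, -1).
x = torch.randn(2, 1, 5, 5)
y = torch.ops.aten.expand.default(x, (2, 3, -1, -1))
print(y.shape)  # torch.Size([2, 3, 5, 5])

# Mirrors the "different_ranks" case: rank-3 input against a rank-4 target.
z = torch.ops.aten.expand.default(torch.randn(1, 5, 7), (2, 3, -1, -1))
print(z.shape)  # torch.Size([2, 3, 5, 7])
```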