Commit d375d10

feat: support flatten and reshape via shuffle_layer (#2354)

1 parent 6d59a14

6 files changed (+128 -11 lines)

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 24 additions & 0 deletions
@@ -1588,3 +1588,27 @@ def tensorrt_scaled_dot_product_attention(
     return impl.attention.scaled_dot_product_attention(
         ctx, target, SourceIR.TORCHTRT_LOWERED, name, args[0], args[1], args[2]
     )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.reshape.default)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.view.default)  # type: ignore[misc]
+@enforce_tensor_types(
+    {
+        0: (TRTTensor,),
+    }
+)  # type: ignore[misc]
+def aten_ops_reshape(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.shuffle.reshape(
+        ctx,
+        target,
+        SourceIR.ATEN,
+        name,
+        input=args[0],
+        shape=args[1],
+    )
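Note: for context, a minimal end-to-end sketch of how this converter is exercised (the module, shapes, and compile call below are illustrative, not part of the commit, and assume a CUDA build of Torch-TensorRT with the dynamo frontend available):

import torch
import torch_tensorrt

class ReshapeModule(torch.nn.Module):  # hypothetical example module
    def forward(self, x):
        # lowers to torch.ops.aten.view.default / aten.reshape.default,
        # which now dispatch to aten_ops_reshape above
        return torch.reshape(x, (1, 10, -1))

model = ReshapeModule().eval().cuda()
inputs = [torch.randn(1, 2, 10, device="cuda")]
trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
print(trt_model(*inputs).shape)  # torch.Size([1, 10, 2])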

py/torch_tensorrt/dynamo/conversion/converter_utils.py

Lines changed: 32 additions & 0 deletions
@@ -511,3 +511,35 @@ def to_numpy(
     raise AssertionError(
         f"to_numpy can only be called on None, bool, int, float, np.ndarray, or torch.Tensor, got: {value}"
     )
+
+
+def flatten_dims(
+    input: Sequence[Union[TRTTensor, torch.Tensor, np.ndarray]],
+    start_dim: int,
+    end_dim: int,
+) -> Tuple[int, ...]:
+    """
+    Given an input and the start and end indices of a dimension range,
+    return the new shape with that range flattened into a single dimension.
+
+    Args:
+        input (Sequence[Union[TRTTensor, torch.Tensor, np.ndarray]]):
+            an input value waiting to be flattened
+        start_dim (int): the first dim to flatten
+        end_dim (int): the last dim to flatten (this dim is included)
+
+    Returns:
+        Tuple[int, ...]: new_shape
+    """
+    shape = input.shape
+    dim_size = len(shape)
+    start_dim = get_positive_dim(start_dim, dim_size)
+    end_dim = get_positive_dim(end_dim, dim_size)
+
+    num_elements = 1
+    for i in range(start_dim, end_dim + 1):
+        num_elements *= shape[i]
+
+    new_shape = tuple(shape[:start_dim]) + (num_elements,) + tuple(shape[end_dim + 1 :])
+
+    return new_shape
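Note: a quick illustration of what flatten_dims returns (the cases mirror the parameterized tests added in this commit; the helper only reads the input's .shape, so any array-like works):

import numpy as np

x = np.random.randn(2, 3, 4, 5)
print(flatten_dims(x, 1, 2))    # (2, 12, 5) -- dims 1..2 collapse to 3 * 4
print(flatten_dims(x, 0, -2))   # (24, 5)    -- negative dims are normalized first
print(flatten_dims(x, -4, -1))  # (120,)     -- flattens everything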

py/torch_tensorrt/dynamo/conversion/impl/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -17,6 +17,7 @@
     reduce,
     select,
     shape,
+    shuffle,
     slice,
     split,
     squeeze,

py/torch_tensorrt/dynamo/conversion/impl/shuffle.py (new file)

Lines changed: 21 additions & 0 deletions

@@ -0,0 +1,21 @@
+from typing import Optional, Sequence, Union
+
+from torch.fx.node import Target
+from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
+from torch_tensorrt.dynamo.conversion.converter_utils import SourceIR
+from torch_tensorrt.fx.converters.converter_utils import set_layer_name
+from torch_tensorrt.fx.types import TRTTensor
+
+
+def reshape(
+    ctx: ConversionContext,
+    target: Union[Target, str],
+    source_ir: Optional[SourceIR],
+    name: str,
+    input: TRTTensor,
+    shape: Sequence[int],
+) -> TRTTensor:
+    layer = ctx.net.add_shuffle(input)
+    layer.reshape_dims = tuple(shape)
+    set_layer_name(layer, target, name, source_ir)
+    return layer.get_output(0)
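Note: at the TensorRT level this wraps an IShuffleLayer whose reshape_dims performs the reshape. A minimal sketch against the raw tensorrt Python API (assumes an existing INetworkDefinition `network` and an input ITensor `x`; not code from the commit):

# x: ITensor of shape (2, 3, 4) registered on `network`
layer = network.add_shuffle(x)
layer.reshape_dims = (2, 12)  # at most one entry may be -1 (inferred from the rest)
out = layer.get_output(0)     # ITensor of shape (2, 12)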

tests/py/dynamo/conversion/test_converter_utils.py

Lines changed: 39 additions & 1 deletion
@@ -1,7 +1,11 @@
 import numpy as np
 import torch
+from parameterized import parameterized
 from torch.testing._internal.common_utils import TestCase, run_tests
-from torch_tensorrt.dynamo.conversion.converter_utils import enforce_tensor_types
+from torch_tensorrt.dynamo.conversion.converter_utils import (
+    enforce_tensor_types,
+    flatten_dims,
+)
 from torch_tensorrt.fx.types import TRTTensor
 
 from ..testing_utilities import DECIMALS_OF_AGREEMENT, lower_graph_testing
@@ -37,5 +41,39 @@ def test_invalid_invocation_type(self):
             enforce_tensor_types({0: (int, bool)})
 
 
+class TestFlattenDimsEnforcement(TestCase):
+    @parameterized.expand(
+        [
+            ((1, 2), 0, 0, (1, 2)),
+            ((1, 2), 0, 1, (2,)),
+            ((2, 3, 4), 1, 2, (2, 12)),
+            ((2, 3, 4), 0, 1, (6, 4)),
+            ((2, 3, 4), -3, 2, (24,)),
+            ((2, 3, 4, 5), 0, -2, (24, 5)),
+            ((2, 3, 4, 5), -4, -1, (120,)),
+        ]
+    )
+    def test_numpy_array(self, input_shape, start_dim, end_dim, true_shape):
+        inputs = np.random.randn(*input_shape)
+        new_shape = flatten_dims(inputs, start_dim, end_dim)
+        self.assertEqual(new_shape, true_shape)
+
+    @parameterized.expand(
+        [
+            ((1, 2), 0, 0, (1, 2)),
+            ((1, 2), 0, 1, (2,)),
+            ((2, 3, 4), 1, 2, (2, 12)),
+            ((2, 3, 4), 0, 1, (6, 4)),
+            ((2, 3, 4), -3, 2, (24,)),
+            ((2, 3, 4, 5), 0, -2, (24, 5)),
+            ((2, 3, 4, 5), -4, -1, (120,)),
+        ]
+    )
+    def test_torch_tensor(self, input_shape, start_dim, end_dim, true_shape):
+        inputs = torch.randn(input_shape)
+        new_shape = flatten_dims(inputs, start_dim, end_dim)
+        self.assertEqual(new_shape, true_shape)
+
+
 if __name__ == "__main__":
     run_tests()

tests/py/dynamo/conversion/test_reshape_aten.py

Lines changed: 11 additions & 10 deletions
@@ -12,6 +12,8 @@
 class TestReshapeConverter(DispatchTestCase):
     @parameterized.expand(
         [
+            ((-1,),),
+            ((20,),),
             ((1, 20),),
             ((1, 10, -1),),
         ]
@@ -21,22 +23,22 @@ class TestReshapeConverter(DispatchTestCase):
         "Shape tensor supported well in TensorRT 8.5 and later",
     )
     def test_reshape(self, target_shape):
-        class TestModule(torch.nn.Module):
-            def __init__(self, target_shape):
+        class Reshape(torch.nn.Module):
+            def __init__(self):
                 super().__init__()
-                self.target_shape = target_shape
 
             def forward(self, x):
-                return torch.ops.aten.view.default(x, self.target_shape)
+                return torch.ops.aten.view.default(x, target_shape)
 
         inputs = [torch.randn(1, 2, 10)]
         self.run_test(
-            TestModule(target_shape),
+            Reshape(),
             inputs,
         )
 
     @parameterized.expand(
         [
+            ((-1,),),
             ((-1, 10),),
             ((-1, 5),),
             ((2, 2, -1),),
@@ -47,13 +49,12 @@ def forward(self, x):
         "Shape tensor supported well in TensorRT 8.5 and later",
     )
     def test_reshape_with_dynamic_shape(self, target_shape):
-        class TestModule(torch.nn.Module):
-            def __init__(self, target_shape):
+        class Reshape(torch.nn.Module):
+            def __init__(self):
                 super().__init__()
-                self.target_shape = target_shape
 
             def forward(self, x):
-                return torch.ops.aten.view.default(x, self.target_shape)
+                return torch.ops.aten.view.default(x, target_shape)
 
         input_specs = [
             Input(
@@ -63,7 +64,7 @@ def forward(self, x):
             ),
         ]
         self.run_test_with_dynamic_shape(
-            TestModule(target_shape),
+            Reshape(),
             input_specs,
         )
 
