Skip to content

Commit 593ff44

Browse files
authored
feat: support aten.pixel_shuffle dynamo converter (#2596)
1 parent eb62a3d commit 593ff44

File tree

3 files changed

+95
-0
lines changed

3 files changed

+95
-0
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2278,6 +2278,29 @@ def aten_ops_reshape(
22782278
)
22792279

22802280

2281+
@dynamo_tensorrt_converter(torch.ops.aten.pixel_shuffle.default)
2282+
@enforce_tensor_types(
2283+
{
2284+
0: (TRTTensor,),
2285+
}
2286+
)
2287+
def aten_ops_pixel_shuffle(
2288+
ctx: ConversionContext,
2289+
target: Target,
2290+
args: Tuple[Argument, ...],
2291+
kwargs: Dict[str, Argument],
2292+
name: str,
2293+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
2294+
return impl.shuffle.pixel_shuffle(
2295+
ctx,
2296+
target,
2297+
SourceIR.ATEN,
2298+
name,
2299+
args[0],
2300+
args[1],
2301+
)
2302+
2303+
22812304
@enforce_tensor_types({0: (TRTTensor,)})
22822305
@dynamo_tensorrt_converter(torch.ops.aten.argmax.default)
22832306
def aten_ops_argmax(

py/torch_tensorrt/dynamo/conversion/impl/shuffle.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from typing import Optional, Sequence, Union
22

3+
import torch_tensorrt.dynamo.conversion.impl as impl
34
from torch.fx.node import Target
45
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
56
from torch_tensorrt.dynamo.conversion.converter_utils import SourceIR
@@ -19,3 +20,43 @@ def reshape(
1920
layer.reshape_dims = tuple(shape)
2021
set_layer_name(layer, target, name, source_ir)
2122
return layer.get_output(0)
23+
24+
25+
def pixel_shuffle(
26+
ctx: ConversionContext,
27+
target: Union[Target, str],
28+
source_ir: Optional[SourceIR],
29+
name: str,
30+
input: TRTTensor,
31+
upscale_factor: int,
32+
) -> TRTTensor:
33+
shape = input.shape
34+
in_channels, in_height, in_width = shape[-3:]
35+
out_channels = in_channels // (upscale_factor**2)
36+
out_height = in_height * upscale_factor
37+
out_width = in_width * upscale_factor
38+
new_shape = shape[:-3] + (
39+
out_channels,
40+
upscale_factor,
41+
upscale_factor,
42+
in_height,
43+
in_width,
44+
)
45+
reshaped_tensor = reshape(
46+
ctx, target, source_ir, f"{name}_reshape1", input, new_shape
47+
)
48+
rank = len(shape)
49+
permute_shape = list(range(rank))
50+
permute_shape.insert(-2, rank)
51+
permute_shape.insert(-1, rank + 1)
52+
permuted_tensor = impl.permutation.permute(
53+
ctx, target, source_ir, f"{name}_permute", reshaped_tensor, permute_shape
54+
)
55+
return reshape(
56+
ctx,
57+
target,
58+
source_ir,
59+
f"{name}_reshape2",
60+
permuted_tensor,
61+
shape[:-3] + (out_channels, out_height, out_width),
62+
)
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import torch
2+
from parameterized import parameterized
3+
from torch.testing._internal.common_utils import run_tests
4+
5+
from .harness import DispatchTestCase
6+
7+
8+
class TestPixelShuffleConverter(DispatchTestCase):
9+
@parameterized.expand(
10+
[
11+
((1, 1, 1), 1),
12+
((12, 3, 4), 2),
13+
((1, 9, 4, 4), 3),
14+
((2, 32, 2, 3), 4),
15+
((1, 10, 36, 2, 4), 6),
16+
]
17+
)
18+
def test_pixel_shuffle(self, shape, upscale_factor):
19+
class PixelShuffle(torch.nn.Module):
20+
def forward(self, x):
21+
return torch.ops.aten.pixel_shuffle.default(x, upscale_factor)
22+
23+
inputs = [torch.randn(shape)]
24+
self.run_test(
25+
PixelShuffle(),
26+
inputs,
27+
)
28+
29+
30+
if __name__ == "__main__":
31+
run_tests()

0 commit comments

Comments
 (0)