Skip to content

Commit 0398f48

Browse files
committed
Add support for aten.pixel_unshuffle dynamo converter
1 parent 7f14221 commit 0398f48

File tree

3 files changed

+96
-0
lines changed

3 files changed

+96
-0
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2319,6 +2319,29 @@ def aten_ops_pixel_shuffle(
23192319
)
23202320

23212321

2322+
@dynamo_tensorrt_converter(torch.ops.aten.pixel_unshuffle.default)
2323+
@enforce_tensor_types(
2324+
{
2325+
0: (TRTTensor,),
2326+
}
2327+
)
2328+
def aten_ops_pixel_unshuffle(
2329+
ctx: ConversionContext,
2330+
target: Target,
2331+
args: Tuple[Argument, ...],
2332+
kwargs: Dict[str, Argument],
2333+
name: str,
2334+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
2335+
return impl.shuffle.pixel_unshuffle(
2336+
ctx,
2337+
target,
2338+
SourceIR.ATEN,
2339+
name,
2340+
args[0],
2341+
args[1],
2342+
)
2343+
2344+
23222345
@enforce_tensor_types({0: (TRTTensor,)})
23232346
@dynamo_tensorrt_converter(torch.ops.aten.argmax.default)
23242347
def aten_ops_argmax(

py/torch_tensorrt/dynamo/conversion/impl/shuffle.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,3 +60,47 @@ def pixel_shuffle(
6060
permuted_tensor,
6161
shape[:-3] + (out_channels, out_height, out_width),
6262
)
63+
64+
65+
def pixel_unshuffle(
66+
ctx: ConversionContext,
67+
target: Union[Target, str],
68+
source_ir: Optional[SourceIR],
69+
name: str,
70+
input: TRTTensor,
71+
downscale_factor: int,
72+
) -> TRTTensor:
73+
shape = input.shape
74+
in_channels, in_height, in_width = shape[-3:]
75+
out_channels = in_channels * (downscale_factor**2)
76+
out_height = in_height // downscale_factor
77+
out_width = in_width // downscale_factor
78+
new_shape = shape[:-3] + (
79+
in_channels,
80+
out_height,
81+
downscale_factor,
82+
out_width,
83+
downscale_factor,
84+
)
85+
reshaped_tensor = reshape(
86+
ctx, target, source_ir, f"{name}_reshape1", input, new_shape
87+
)
88+
rank = len(new_shape)
89+
permute_shape = tuple(range(rank - 5)) + (
90+
rank - 5, # in_channels
91+
rank - 3, # downscale_factor
92+
rank - 1, # downscale_factor
93+
rank - 4, # out_height
94+
rank - 2, # out_width
95+
)
96+
permuted_tensor = impl.permutation.permute(
97+
ctx, target, source_ir, f"{name}_permute", reshaped_tensor, permute_shape
98+
)
99+
return reshape(
100+
ctx,
101+
target,
102+
source_ir,
103+
f"{name}_reshape2",
104+
permuted_tensor,
105+
shape[:-3] + (out_channels, out_height, out_width),
106+
)
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
import torch
2+
from parameterized import parameterized
3+
from torch.testing._internal.common_utils import run_tests
4+
5+
from .harness import DispatchTestCase
6+
7+
8+
class TestPixelUnshuffleConverter(DispatchTestCase):
9+
@parameterized.expand(
10+
[
11+
((1, 1, 1), 1),
12+
((1, 1, 12, 12), 3),
13+
((2, 3, 4, 25, 30), 5),
14+
]
15+
)
16+
def test_pixel_unshuffle(self, shape, downscale_factor):
17+
class PixelUnshuffle(torch.nn.Module):
18+
def forward(self, x):
19+
return torch.ops.aten.pixel_unshuffle.default(x, downscale_factor)
20+
21+
inputs = [torch.randn(shape)]
22+
self.run_test(
23+
PixelUnshuffle(),
24+
inputs,
25+
)
26+
27+
28+
if __name__ == "__main__":
29+
run_tests()

0 commit comments

Comments
 (0)