Skip to content

Commit 648772c

Browse files
authored
feat: dynamic support for pixel_suffle and pixel_unshuffle (#3044)
1 parent 655ed6b commit 648772c

File tree

4 files changed

+245
-34
lines changed

4 files changed

+245
-34
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2801,7 +2801,9 @@ def aten_ops_reshape(
28012801
)
28022802

28032803

2804-
@dynamo_tensorrt_converter(torch.ops.aten.pixel_shuffle.default)
2804+
@dynamo_tensorrt_converter(
2805+
torch.ops.aten.pixel_shuffle.default, supports_dynamic_shapes=True
2806+
)
28052807
@enforce_tensor_types(
28062808
{
28072809
0: (TRTTensor,),
@@ -2824,7 +2826,9 @@ def aten_ops_pixel_shuffle(
28242826
)
28252827

28262828

2827-
@dynamo_tensorrt_converter(torch.ops.aten.pixel_unshuffle.default)
2829+
@dynamo_tensorrt_converter(
2830+
torch.ops.aten.pixel_unshuffle.default, supports_dynamic_shapes=True
2831+
)
28282832
@enforce_tensor_types(
28292833
{
28302834
0: (TRTTensor,),

py/torch_tensorrt/dynamo/conversion/impl/shuffle.py

Lines changed: 173 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from typing import Optional, Sequence, Union
22

33
import numpy as np
4+
import tensorrt as trt
45
import torch_tensorrt.dynamo.conversion.impl as impl
56
from torch.fx.node import Target
67
from torch_tensorrt import _enums
@@ -12,10 +13,9 @@
1213
get_trt_tensor,
1314
set_layer_name,
1415
)
16+
from torch_tensorrt.dynamo.conversion.impl.shape import get_shape_with_dynamic_shape
1517
from torch_tensorrt.fx.types import TRTTensor
1618

17-
import tensorrt as trt
18-
1919

2020
def reshape(
2121
ctx: ConversionContext,
@@ -61,35 +61,106 @@ def pixel_shuffle(
6161
input: TRTTensor,
6262
upscale_factor: int,
6363
) -> TRTTensor:
64-
shape = input.shape
65-
in_channels, in_height, in_width = shape[-3:]
66-
out_channels = in_channels // (upscale_factor**2)
67-
out_height = in_height * upscale_factor
68-
out_width = in_width * upscale_factor
69-
new_shape = shape[:-3] + (
70-
out_channels,
64+
# Get input shape tensor
65+
input_shape_tensor = get_shape_with_dynamic_shape(
66+
ctx,
67+
target,
68+
source_ir,
69+
name + "_shape",
70+
input.shape,
71+
input,
72+
)
73+
74+
# Extract in_channels, in_height, and in_width from the input shape tensor
75+
in_channels_tensor = ctx.net.add_slice(
76+
input_shape_tensor, start=(len(input.shape) - 3,), shape=(1,), stride=(1,)
77+
).get_output(0)
78+
in_height_tensor = ctx.net.add_slice(
79+
input_shape_tensor, start=(len(input.shape) - 2,), shape=(1,), stride=(1,)
80+
).get_output(0)
81+
in_width_tensor = ctx.net.add_slice(
82+
input_shape_tensor, start=(len(input.shape) - 1,), shape=(1,), stride=(1,)
83+
).get_output(0)
84+
85+
# Calculate out_channels, out_height, and out_width as tensors
86+
upscale_factor_sq = upscale_factor * upscale_factor
87+
upscale_factor_tensor = get_trt_tensor(
88+
ctx, upscale_factor, f"{name}_upscale_factor"
89+
)
90+
upscale_factor_sq_tensor = get_trt_tensor(
91+
ctx, upscale_factor_sq, f"{name}_upscale_factor_sq"
92+
)
93+
94+
out_channels_tensor = impl.elementwise.floor_divide(
95+
ctx,
96+
target,
97+
source_ir,
98+
f"{name}_out_channels_tensor",
99+
in_channels_tensor,
100+
upscale_factor_sq_tensor,
101+
)
102+
out_height_tensor = impl.elementwise.mul(
103+
ctx,
104+
target,
105+
source_ir,
106+
f"{name}_out_height_tensor",
107+
in_height_tensor,
71108
upscale_factor,
109+
)
110+
out_width_tensor = impl.elementwise.mul(
111+
ctx,
112+
target,
113+
source_ir,
114+
f"{name}_out_width_tensor",
115+
in_width_tensor,
72116
upscale_factor,
73-
in_height,
74-
in_width,
75117
)
118+
119+
# Construct new shape tensor
120+
new_shape_tensors = [
121+
ctx.net.add_slice(
122+
input_shape_tensor, start=(i,), shape=(1,), stride=(1,)
123+
).get_output(0)
124+
for i in range(len(input.shape) - 3)
125+
]
126+
new_shape_tensors += [
127+
out_channels_tensor,
128+
upscale_factor_tensor,
129+
upscale_factor_tensor,
130+
in_height_tensor,
131+
in_width_tensor,
132+
]
133+
134+
# Reshape tensor
76135
reshaped_tensor = reshape(
77-
ctx, target, source_ir, f"{name}_reshape1", input, new_shape
136+
ctx, target, source_ir, f"{name}_reshape", input, new_shape_tensors
78137
)
79-
rank = len(shape)
138+
139+
# Permute shape
140+
rank = len(input.shape)
80141
permute_shape = list(range(rank))
81142
permute_shape.insert(-2, rank)
82143
permute_shape.insert(-1, rank + 1)
83144
permuted_tensor = impl.permutation.permute(
84145
ctx, target, source_ir, f"{name}_permute", reshaped_tensor, permute_shape
85146
)
147+
148+
# Construct output shape tensor
149+
out_shape_tensors = [
150+
ctx.net.add_slice(
151+
input_shape_tensor, start=(i,), shape=(1,), stride=(1,)
152+
).get_output(0)
153+
for i in range(len(input.shape) - 3)
154+
]
155+
out_shape_tensors += [out_channels_tensor, out_height_tensor, out_width_tensor]
156+
86157
return reshape(
87158
ctx,
88159
target,
89160
source_ir,
90-
f"{name}_reshape2",
161+
f"{name}_reshape_out",
91162
permuted_tensor,
92-
shape[:-3] + (out_channels, out_height, out_width),
163+
out_shape_tensors,
93164
)
94165

95166

@@ -101,39 +172,109 @@ def pixel_unshuffle(
101172
input: TRTTensor,
102173
downscale_factor: int,
103174
) -> TRTTensor:
104-
shape = input.shape
105-
in_channels, in_height, in_width = shape[-3:]
106-
out_channels = in_channels * (downscale_factor**2)
107-
out_height = in_height // downscale_factor
108-
out_width = in_width // downscale_factor
109-
new_shape = shape[:-3] + (
110-
in_channels,
111-
out_height,
112-
downscale_factor,
113-
out_width,
114-
downscale_factor,
175+
# Get input shape tensor
176+
input_shape_tensor = get_shape_with_dynamic_shape(
177+
ctx,
178+
target,
179+
source_ir,
180+
name + "_shape",
181+
input.shape,
182+
input,
183+
)
184+
185+
# Extract in_channels, in_height, and in_width from the input shape tensor
186+
in_channels_tensor = ctx.net.add_slice(
187+
input_shape_tensor, start=(len(input.shape) - 3,), shape=(1,), stride=(1,)
188+
).get_output(0)
189+
in_height_tensor = ctx.net.add_slice(
190+
input_shape_tensor, start=(len(input.shape) - 2,), shape=(1,), stride=(1,)
191+
).get_output(0)
192+
in_width_tensor = ctx.net.add_slice(
193+
input_shape_tensor, start=(len(input.shape) - 1,), shape=(1,), stride=(1,)
194+
).get_output(0)
195+
196+
# Calculate out_channels, out_height, and out_width as tensors
197+
downscale_factor_sq = downscale_factor * downscale_factor
198+
downscale_factor_tensor = get_trt_tensor(
199+
ctx, downscale_factor, f"{name}_downscale_factor"
200+
)
201+
downscale_factor_sq_tensor = get_trt_tensor(
202+
ctx, downscale_factor_sq, f"{name}_downscale_factor_sq"
115203
)
204+
205+
out_channels_tensor = impl.elementwise.mul(
206+
ctx,
207+
target,
208+
source_ir,
209+
f"{name}_out_channels_tensor",
210+
in_channels_tensor,
211+
downscale_factor_sq_tensor,
212+
)
213+
out_height_tensor = impl.elementwise.floor_divide(
214+
ctx,
215+
target,
216+
source_ir,
217+
f"{name}_out_height_tensor",
218+
in_height_tensor,
219+
downscale_factor_tensor,
220+
)
221+
out_width_tensor = impl.elementwise.floor_divide(
222+
ctx,
223+
target,
224+
source_ir,
225+
f"{name}_out_width_tensor",
226+
in_width_tensor,
227+
downscale_factor_tensor,
228+
)
229+
230+
# Construct new shape tensor
231+
new_shape_tensors = [
232+
ctx.net.add_slice(
233+
input_shape_tensor, start=(i,), shape=(1,), stride=(1,)
234+
).get_output(0)
235+
for i in range(len(input.shape) - 3)
236+
]
237+
new_shape_tensors += [
238+
in_channels_tensor,
239+
out_height_tensor,
240+
downscale_factor_tensor,
241+
out_width_tensor,
242+
downscale_factor_tensor,
243+
]
244+
116245
reshaped_tensor = reshape(
117-
ctx, target, source_ir, f"{name}_reshape1", input, new_shape
246+
ctx, target, source_ir, f"{name}_reshape", input, new_shape_tensors
118247
)
119-
rank = len(new_shape)
120-
permute_shape = tuple(range(rank - 5)) + (
248+
249+
# Permute shape
250+
rank = len(new_shape_tensors)
251+
permute_shape = list(range(rank - 5)) + [
121252
rank - 5, # in_channels
122253
rank - 3, # downscale_factor
123254
rank - 1, # downscale_factor
124255
rank - 4, # out_height
125256
rank - 2, # out_width
126-
)
257+
]
127258
permuted_tensor = impl.permutation.permute(
128259
ctx, target, source_ir, f"{name}_permute", reshaped_tensor, permute_shape
129260
)
261+
262+
# Construct output shape tensor
263+
out_shape_tensors = [
264+
ctx.net.add_slice(
265+
input_shape_tensor, start=(i,), shape=(1,), stride=(1,)
266+
).get_output(0)
267+
for i in range(len(input.shape) - 3)
268+
]
269+
out_shape_tensors += [out_channels_tensor, out_height_tensor, out_width_tensor]
270+
130271
return reshape(
131272
ctx,
132273
target,
133274
source_ir,
134-
f"{name}_reshape2",
275+
f"{name}_reshape_out",
135276
permuted_tensor,
136-
shape[:-3] + (out_channels, out_height, out_width),
277+
out_shape_tensors,
137278
)
138279

139280

tests/py/dynamo/conversion/test_pixel_shuffle_aten.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import torch
22
from parameterized import parameterized
33
from torch.testing._internal.common_utils import run_tests
4+
from torch_tensorrt import Input
45

56
from .harness import DispatchTestCase
67

@@ -26,6 +27,38 @@ def forward(self, x):
2627
inputs,
2728
)
2829

30+
@parameterized.expand(
31+
[
32+
(
33+
(1, 1, 1),
34+
(2, 2, 2),
35+
(3, 3, 3),
36+
torch.float,
37+
1,
38+
),
39+
]
40+
)
41+
def test_dynamic_shape_pixel_shuffle(
42+
self, min_shape, opt_shape, max_shape, type, upscale_factor
43+
):
44+
class PixelShuffle(torch.nn.Module):
45+
def __init__(self):
46+
super().__init__()
47+
48+
def forward(self, x):
49+
return torch.ops.aten.pixel_shuffle.default(x, upscale_factor)
50+
51+
input_specs = [
52+
Input(
53+
min_shape=min_shape,
54+
opt_shape=opt_shape,
55+
max_shape=max_shape,
56+
dtype=type,
57+
),
58+
]
59+
60+
self.run_test_with_dynamic_shape(PixelShuffle(), input_specs)
61+
2962

3063
if __name__ == "__main__":
3164
run_tests()

tests/py/dynamo/conversion/test_pixel_unshuffle_aten.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import torch
22
from parameterized import parameterized
33
from torch.testing._internal.common_utils import run_tests
4+
from torch_tensorrt import Input
45

56
from .harness import DispatchTestCase
67

@@ -24,6 +25,38 @@ def forward(self, x):
2425
inputs,
2526
)
2627

28+
@parameterized.expand(
29+
[
30+
(
31+
(1, 1, 1),
32+
(2, 2, 2),
33+
(3, 3, 3),
34+
torch.float,
35+
1,
36+
),
37+
]
38+
)
39+
def test_dynamic_shape_pixel_unshuffle(
40+
self, min_shape, opt_shape, max_shape, type, upscale_factor
41+
):
42+
class PixelUnshuffle(torch.nn.Module):
43+
def __init__(self):
44+
super().__init__()
45+
46+
def forward(self, x):
47+
return torch.ops.aten.pixel_unshuffle.default(x, upscale_factor)
48+
49+
input_specs = [
50+
Input(
51+
min_shape=min_shape,
52+
opt_shape=opt_shape,
53+
max_shape=max_shape,
54+
dtype=type,
55+
),
56+
]
57+
58+
self.run_test_with_dynamic_shape(PixelUnshuffle(), input_specs)
59+
2760

2861
if __name__ == "__main__":
2962
run_tests()

0 commit comments

Comments
 (0)