Skip to content

Commit 623d2f7

Browse files
committed
feat: expose IResizeLayer in dynamo
1 parent 46b39e0 commit 623d2f7

File tree

4 files changed

+207
-0
lines changed

4 files changed

+207
-0
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2463,3 +2463,45 @@ def aten_ops_pad(
24632463
mode=args_bounds_check(args, 2, "constant"),
24642464
value=args_bounds_check(args, 3, None),
24652465
)
2466+
2467+
2468+
@dynamo_tensorrt_converter(torch.ops.aten.upsample_nearest2d.vec)
2469+
def upsample_nearest2d(
2470+
ctx: ConversionContext,
2471+
target: Target,
2472+
args: Tuple[Argument, ...],
2473+
kwargs: Dict[str, Argument],
2474+
name: str,
2475+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
2476+
return impl.upsample.upsample(
2477+
ctx,
2478+
target,
2479+
SourceIR.ATEN,
2480+
name,
2481+
input=args[0],
2482+
out_shape=args_bounds_check(args, 1),
2483+
scale_factors=args_bounds_check(args, 2),
2484+
resize_mode="nearest",
2485+
align_corners=False,
2486+
)
2487+
2488+
2489+
@dynamo_tensorrt_converter(torch.ops.aten.upsample_bilinear2d.vec)
2490+
def upsample_bilinear2d(
2491+
ctx: ConversionContext,
2492+
target: Target,
2493+
args: Tuple[Argument, ...],
2494+
kwargs: Dict[str, Argument],
2495+
name: str,
2496+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
2497+
return impl.upsample.upsample(
2498+
ctx,
2499+
target,
2500+
SourceIR.ATEN,
2501+
name,
2502+
input=args[0],
2503+
out_shape=args_bounds_check(args, 1),
2504+
scale_factors=args_bounds_check(args, 3),
2505+
resize_mode="bilinear",
2506+
align_corners=args_bounds_check(args, 2),
2507+
)

py/torch_tensorrt/dynamo/conversion/impl/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,4 +28,5 @@
2828
topk,
2929
unary,
3030
unsqueeze,
31+
upsample,
3132
)
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
from typing import Optional, Sequence
2+
3+
import tensorrt as trt
4+
from torch.fx.node import Target
5+
from torch_tensorrt.dynamo._SourceIR import SourceIR
6+
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
7+
from torch_tensorrt.fx.converters.converter_utils import set_layer_name
8+
from torch_tensorrt.fx.types import TRTTensor
9+
10+
11+
def upsample(
12+
ctx: ConversionContext,
13+
target: Target,
14+
source_ir: Optional[SourceIR],
15+
name: str,
16+
input: TRTTensor,
17+
out_shape: Optional[Sequence[int]],
18+
scale_factors: Optional[Sequence[float]],
19+
resize_mode: str,
20+
align_corners: bool,
21+
) -> TRTTensor:
22+
resize_layer = ctx.net.add_resize(input)
23+
# output size calculation
24+
# Pytorch assumes that one of out_shape/scale_factor is None
25+
# Pytorch assumes that dimensions match for out_shape/scale factor
26+
if out_shape is not None:
27+
resize_layer.shape = list(input.shape)[:2] + list(out_shape)
28+
elif scale_factors is not None:
29+
resize_layer.scales = [1.0, 1.0] + list(scale_factors)
30+
else:
31+
raise RuntimeError(
32+
f"At least one of out_shape and scale_factors should be specified."
33+
)
34+
35+
# interpolate mode
36+
if resize_mode == "nearest" or None:
37+
resize_layer.resize_mode = trt.ResizeMode.NEAREST
38+
elif resize_mode == "bilinear":
39+
resize_layer.resize_mode = trt.ResizeMode.LINEAR
40+
if align_corners is None or not align_corners:
41+
raise RuntimeError(
42+
f"Interpolation works differently is align_corners is False for {resize_mode} mode in PyTorch and TensorRT."
43+
)
44+
else:
45+
raise RuntimeError(
46+
f"Interpolation mode is {resize_mode} which is not supported by TensorRT."
47+
)
48+
49+
if resize_mode == "nearest":
50+
resize_layer.coordinate_transformation = (
51+
trt.ResizeCoordinateTransformation.ASYMMETRIC
52+
)
53+
elif resize_mode == "bilinear":
54+
# align corners
55+
if align_corners is not None and align_corners:
56+
resize_layer.coordinate_transformation = (
57+
trt.ResizeCoordinateTransformation.ALIGN_CORNERS
58+
)
59+
else:
60+
resize_layer.coordinate_transformation = (
61+
trt.ResizeCoordinateTransformation.ASYMMETRIC
62+
)
63+
64+
set_layer_name(resize_layer, target, name, source_ir)
65+
66+
out = resize_layer.get_output(0)
67+
return out
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
import torch
2+
from parameterized import parameterized
3+
from torch.testing._internal.common_utils import run_tests
4+
5+
from .harness import DispatchTestCase
6+
7+
8+
class TestUpsampleConverter(DispatchTestCase):
9+
# test case for nearest upsample, using output_size, scale_factors is disabled here
10+
@parameterized.expand(
11+
[
12+
("upsample_nearest2d.vec_outshape_0", (2, 2), (4, 4)),
13+
("upsample_nearest2d.vec_outshape_1", (2, 2), (5, 5)),
14+
]
15+
)
16+
def test_upsample_nearest_output_shape(self, _, input_shape, output_shape):
17+
class Upsample(torch.nn.Module):
18+
def __init__(self):
19+
super().__init__()
20+
21+
def forward(self, input):
22+
return torch.ops.aten.upsample_nearest2d.vec(input, output_shape, None)
23+
24+
input = [torch.randn([1, 1] + list(input_shape))]
25+
self.run_test(Upsample(), input)
26+
27+
# test case for nearest upsample, using scale_factors, output_size is disabled here
28+
@parameterized.expand(
29+
[
30+
("upsample_nearest2d.vec_scale_0", (2, 2), (2, 2)),
31+
("upsample_nearest2d.vec_scale_1", (2, 2), (1.5, 1.5)),
32+
]
33+
)
34+
def test_upsample_nearest_scale_factor(self, _, input_shape, scale_factor):
35+
class Upsample(torch.nn.Module):
36+
def __init__(self):
37+
super().__init__()
38+
39+
def forward(self, input):
40+
return torch.ops.aten.upsample_nearest2d.vec(input, None, scale_factor)
41+
42+
input = [torch.randn([1, 1] + list(input_shape))]
43+
self.run_test(Upsample(), input)
44+
45+
# test case for bilinear upsample, using output_size, scale_factors is disabled here
46+
@parameterized.expand(
47+
[
48+
("upsample_bilinear2d.vec_outshape_0", (2, 2), (4, 4), True),
49+
("upsample_bilinear2d.vec_outshape_1", (2, 2), (5, 5), True),
50+
]
51+
)
52+
def test_upsample_bilinear_output_shape(
53+
self, _, input_shape, output_shape, align_corners
54+
):
55+
class Upsample(torch.nn.Module):
56+
def __init__(self):
57+
super().__init__()
58+
59+
def forward(self, input):
60+
return torch.ops.aten.upsample_bilinear2d.vec(
61+
input,
62+
output_shape,
63+
align_corners,
64+
None,
65+
)
66+
67+
input = [torch.randn([1, 1] + list(input_shape))]
68+
self.run_test(Upsample(), input)
69+
70+
# test case for bilinear upsample, using scale_factors, output_shape is disabled here
71+
@parameterized.expand(
72+
[
73+
("upsample_bilinear2d.vec_scale_0", (2, 2), (2, 2), True),
74+
("upsample_bilinear2d.vec_scale_1", (2, 2), (1.5, 1.5), True),
75+
]
76+
)
77+
def test_upsample_bilinear_scale_factors(
78+
self, _, input_shape, scale_factors, align_corners
79+
):
80+
class Upsample(torch.nn.Module):
81+
def __init__(self):
82+
super().__init__()
83+
84+
def forward(self, input):
85+
return torch.ops.aten.upsample_bilinear2d.vec(
86+
input,
87+
None,
88+
align_corners,
89+
scale_factors,
90+
)
91+
92+
input = [torch.randn([1, 1] + list(input_shape))]
93+
self.run_test(Upsample(), input)
94+
95+
96+
if __name__ == "__main__":
97+
run_tests()

0 commit comments

Comments
 (0)