Skip to content

Commit 922fd11

Browse files
committed
WIP: need to fix the bug that output values are incorrect sometimes. Try using 'interpolate' instead of 'pad'
1 parent e642f86 commit 922fd11

File tree

3 files changed

+169
-82
lines changed

3 files changed

+169
-82
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2184,6 +2184,27 @@ def aten_ops_avg_pool(
21842184
)
21852185

21862186

2187+
@dynamo_tensorrt_converter(torch.ops.aten.adaptive_avg_pool2d.default)
2188+
@dynamo_tensorrt_converter(torch.ops.aten._adaptive_avg_pool2d.default)
2189+
@dynamo_tensorrt_converter(torch.ops.aten.adaptive_avg_pool3d.default)
2190+
@dynamo_tensorrt_converter(torch.ops.aten._adaptive_avg_pool3d.default)
2191+
def aten_ops_adaptive_avg_poolNd(
2192+
ctx: ConversionContext,
2193+
target: Target,
2194+
args: Tuple[Argument, ...],
2195+
kwargs: Dict[str, Argument],
2196+
name: str,
2197+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
2198+
return impl.pool.adaptive_avg_poolNd(
2199+
ctx,
2200+
target,
2201+
source_ir=SourceIR.ATEN,
2202+
name=name,
2203+
input=args[0],
2204+
output_size=args[1],
2205+
)
2206+
2207+
21872208
def max_pool_param_validator(pool_node: Node) -> bool:
21882209
dilation = args_bounds_check(pool_node.args, 4, 1)
21892210
ceil_mode = args_bounds_check(pool_node.args, 5, False)

py/torch_tensorrt/dynamo/conversion/impl/pool.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from typing import Optional, Sequence, Union
22

33
import tensorrt as trt
4+
import torch_tensorrt.dynamo.conversion.impl as impl
45
from torch.fx.node import Target
56
from torch_tensorrt.dynamo._SourceIR import SourceIR
67
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
@@ -104,3 +105,65 @@ def max_poolNd(
104105

105106
set_layer_name(pool_layer, target, name, source_ir)
106107
return pool_layer.get_output(0)
108+
109+
110+
def adaptive_avg_poolNd(
111+
ctx: ConversionContext,
112+
target: Union[Target, str],
113+
source_ir: Optional[SourceIR],
114+
name: str,
115+
input: TRTTensor,
116+
output_size: Sequence[int],
117+
) -> TRTTensor:
118+
input_rank = len(input.shape)
119+
if input_rank == 3:
120+
input = impl.shuffle.reshape(ctx, target, source_ir, f"{name}_reshape", input, (1, *input.shape))
121+
122+
extend_len = len(output_size)
123+
124+
# pad the input based on output_size if the dim of output is larger than input
125+
pad = []
126+
input_shape = input.shape
127+
for i in range(1, extend_len + 1):
128+
input_dim = input_shape[-i]
129+
output_dim = output_size[-i]
130+
diff = output_dim - input_dim
131+
if diff > 0:
132+
if diff % 2 == 0:
133+
pad.append(diff // 2)
134+
pad.append(diff // 2)
135+
else:
136+
pad.append(diff // 2 + 1)
137+
pad.append(diff // 2 + 1)
138+
else:
139+
pad.append(0)
140+
pad.append(0)
141+
142+
input = impl.pad.replication_padNd(
143+
ctx,
144+
target,
145+
source_ir,
146+
f"{name}_replication_padNd",
147+
input,
148+
pad,
149+
)
150+
151+
stride = tuple(
152+
input.shape[-extend_len + i] // output_size[i] for i in range(extend_len)
153+
)
154+
kernel_size = tuple(
155+
input.shape[-extend_len + i] - (output_size[i] - 1) * stride[i]
156+
for i in range(extend_len)
157+
)
158+
layer = ctx.net.add_pooling_nd(
159+
input=input, type=trt.PoolingType.AVERAGE, window_size=kernel_size
160+
)
161+
layer.stride_nd = stride
162+
set_layer_name(layer, target, f"{name}_pooling_{extend_len}d", source_ir)
163+
164+
output = layer.get_output(0)
165+
166+
if input_rank == 3:
167+
output = impl.shuffle.reshape(ctx, target, source_ir, f"{name}_reshape_back", output, (*output.shape[1:],))
168+
169+
return output

tests/py/dynamo/conversion/test_adaptive_avgpool_aten.py

Lines changed: 85 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -3,107 +3,110 @@
33
from torch.testing._internal.common_utils import run_tests
44
from torch_tensorrt import Input
55

6-
from .harness import DispatchTestCase
6+
from harness import DispatchTestCase
77

88

99
class TestAdaptiveAvgPoolConverter(DispatchTestCase):
1010
@parameterized.expand(
1111
[
12-
((64, 64),),
13-
((128, 64),),
14-
# (64,), This case has been there in previous code but it isn't a valid pytorch code.
15-
]
16-
)
17-
def test_adaptive_avgpool(
18-
self,
19-
output_size,
20-
):
21-
class TestModule(torch.nn.Module):
22-
def __init__(self):
23-
super().__init__()
24-
self.pool = torch.nn.AdaptiveAvgPool2d(output_size)
25-
26-
def forward(self, x):
27-
return self.pool(x)
28-
29-
inputs = [torch.randn(1, 3, 256, 256)]
30-
self.run_test(
31-
TestModule(),
32-
inputs,
33-
use_dynamo_tracer=True,
34-
)
35-
36-
def test_adaptive_avgpool_with_dynamic_shape(self):
37-
class TestModule(torch.nn.Module):
38-
def __init__(self):
39-
super().__init__()
40-
self.pool = torch.nn.AdaptiveAvgPool2d((64, 64))
41-
42-
def forward(self, x):
43-
return self.pool(x)
44-
45-
input_specs = [
46-
Input(
47-
shape=(-1, -1, 256, 256),
48-
dtype=torch.float32,
49-
shape_ranges=[((1, 1, 256, 256), (3, 3, 256, 256), (5, 5, 256, 256))],
12+
# 3d input
13+
# (
14+
# (1, 2, 3),
15+
# (1, 2),
16+
# ),
17+
# (
18+
# (1, 2, 3),
19+
# (2, 3),
20+
# ),
21+
# (
22+
# (1, 2, 8),
23+
# (4, 4),
24+
# ),
25+
(
26+
(2, 8, 16),
27+
(4, 10),
5028
),
51-
]
52-
self.run_test_with_dynamic_shape(
53-
TestModule(), input_specs, use_dynamo_tracer=True
54-
)
55-
56-
@parameterized.expand(
57-
[
58-
((16, 16, 16),),
59-
((32, 16, 4),),
60-
(32,),
29+
# (
30+
# (2, 8, 16),
31+
# (8, 8),
32+
# ),
33+
# # 4d input
34+
(
35+
(1, 1, 2, 3),
36+
(2, 8),
37+
),
38+
# (
39+
# (3, 2, 3, 2),
40+
# (1, 5),
41+
# ),
42+
# (
43+
# (4, 2, 2, 8),
44+
# (5, 2),
45+
# ),
46+
(
47+
(3, 2, 3, 2),
48+
(6, 4),
49+
),
50+
# (
51+
# (1, 2, 3, 2),
52+
# (2, 2),
53+
# ),
54+
# (
55+
# (2, 2, 32, 16),
56+
# (8, 8),
57+
# ),
58+
# (
59+
# (2, 2, 32, 32),
60+
# (31, 16),
61+
# ),
62+
# (
63+
# (1, 1, 64, 64),
64+
# (64, 16),
65+
# ),
6166
]
6267
)
63-
def test_adaptive_avgpool3d(
68+
def test_adaptive_avg_pool2d(
6469
self,
70+
input_shape,
6571
output_size,
6672
):
6773
class TestModule(torch.nn.Module):
68-
def __init__(self):
69-
super().__init__()
70-
self.pool = torch.nn.AdaptiveAvgPool3d(output_size)
71-
7274
def forward(self, x):
73-
return self.pool(x)
75+
return torch.ops.aten.adaptive_avg_pool2d.default(x, output_size)
7476

75-
inputs = [torch.randn(1, 3, 32, 64, 64)]
77+
inputs = [torch.randn(input_shape)]
7678
self.run_test(
7779
TestModule(),
7880
inputs,
79-
use_dynamo_tracer=True,
80-
)
81-
82-
def test_adaptive_avgpool3d_with_dynamic_shape(self):
83-
class TestModule(torch.nn.Module):
84-
def __init__(self):
85-
super().__init__()
86-
self.pool = torch.nn.AdaptiveAvgPool3d((16, 16, 16))
87-
88-
def forward(self, x):
89-
return self.pool(x)
90-
91-
input_specs = [
92-
Input(
93-
shape=(-1, -1, 32, 64, 64),
94-
dtype=torch.float32,
95-
shape_ranges=[
96-
((1, 1, 32, 64, 64), (3, 3, 32, 64, 64), (5, 5, 32, 64, 64))
97-
],
98-
),
99-
]
100-
self.run_test_with_dynamic_shape(
101-
TestModule(),
102-
input_specs,
103-
use_dynamo_tracer=True,
81+
# use_dynamo_tracer=True,
82+
enable_passes=True,
10483
)
10584

106-
# Testing with shape(-1, -1, -1, -1) results into error: "AdaptiveAvgPool2d and AdaptiveAvgPool3d currently doesn't support dynamic shapes for last two dims."
85+
# @parameterized.expand(
86+
# [
87+
# ((16, 16, 16),),
88+
# ((32, 16, 4),),
89+
# (32,),
90+
# ]
91+
# )
92+
# def test_adaptive_avgpool3d(
93+
# self,
94+
# output_size,
95+
# ):
96+
# class TestModule(torch.nn.Module):
97+
# def __init__(self):
98+
# super().__init__()
99+
# self.pool = torch.nn.AdaptiveAvgPool3d(output_size)
100+
101+
# def forward(self, x):
102+
# return self.pool(x)
103+
104+
# inputs = [torch.randn(1, 3, 32, 64, 64)]
105+
# self.run_test(
106+
# TestModule(),
107+
# inputs,
108+
# use_dynamo_tracer=True,
109+
# )
107110

108111

109112
if __name__ == "__main__":

0 commit comments

Comments
 (0)