
Commit 5930e96

zewenli98 authored and peri044 committed

feat: support adaptive_avg_pool1d dynamo converter (#2614)

1 parent 815751b commit 5930e96

File tree

3 files changed: +141 -83 lines changed


py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 18 additions & 0 deletions
@@ -2200,6 +2200,24 @@ def aten_ops_avg_pool(
     )
 
 
+@dynamo_tensorrt_converter(torch.ops.aten.adaptive_avg_pool1d.default)
+def aten_ops_adaptive_avg_pool(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.pool.adaptive_avg_pool1d(
+        ctx,
+        target,
+        source_ir=SourceIR.ATEN,
+        name=name,
+        input=args[0],
+        output_size=args[1],
+    )
+
+
 def max_pool_param_validator(pool_node: Node) -> bool:
     dilation = args_bounds_check(pool_node.args, 4, 1)
     ceil_mode = args_bounds_check(pool_node.args, 5, False)
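
The registration follows the existing converter pattern: the decorator maps torch.ops.aten.adaptive_avg_pool1d.default to the new impl.pool.adaptive_avg_pool1d helper, with args[0] carrying the input tensor and args[1] the output size (matching the ATen schema adaptive_avg_pool1d(Tensor self, int[1] output_size) -> Tensor). A minimal eager-mode sketch of the op this converter now covers (shapes are illustrative):

```python
import torch

# Eager-mode view of the op handled by the new converter:
# args[0] is the input tensor, args[1] the output size.
x = torch.randn(2, 3, 8)  # (N, C, L_in)
y = torch.ops.aten.adaptive_avg_pool1d.default(x, [4])
print(y.shape)  # torch.Size([2, 3, 4])
```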

py/torch_tensorrt/dynamo/conversion/impl/pool.py

Lines changed: 66 additions & 1 deletion
@@ -1,6 +1,8 @@
-from typing import Optional, Sequence, Union
+import math
+from typing import Dict, Optional, Sequence, Union
 
 import tensorrt as trt
+import torch_tensorrt.dynamo.conversion.impl as impl
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
@@ -104,3 +106,66 @@ def max_poolNd(
 
     set_layer_name(pool_layer, target, name, source_ir)
     return pool_layer.get_output(0)
+
+
+def adaptive_avg_pool1d(
+    ctx: ConversionContext,
+    target: Union[Target, str],
+    source_ir: Optional[SourceIR],
+    name: str,
+    input: TRTTensor,
+    output_size: Union[int, Sequence[int]],
+) -> TRTTensor:
+    def start_index(idx: int, out_dim: int, in_dim: int) -> int:
+        """Calculate the start index of each pooling window"""
+        return math.floor((float(idx) * float(in_dim)) / out_dim)
+
+    def end_index(idx: int, out_dim: int, in_dim: int) -> int:
+        """Calculate the end index of each pooling window"""
+        return math.ceil((float(idx + 1) * float(in_dim)) / out_dim)
+
+    in_dim = input.shape[-1]
+    out_dim = output_size if isinstance(output_size, int) else output_size[0]
+    output_list = []
+
+    # cache {index: slice} to avoid emitting repeated slice ops
+    idx_slice_map: Dict[int, TRTTensor] = {}
+    # iterate over each output position
+    for i in range(out_dim):
+        # calculate the start and end index of each pooling window
+        start = start_index(i, out_dim, in_dim)
+        end = end_index(i, out_dim, in_dim)
+
+        # slice the input tensor from the start to the end index; the result is the window to be pooled
+        slices = []
+        for j in range(start, end):
+            if j in idx_slice_map:
+                slice = idx_slice_map[j]
+            else:
+                slice = impl.select.select(
+                    ctx, target, source_ir, f"{name}_select_{j}", input, -1, j
+                )
+                slice = impl.shuffle.reshape(
+                    ctx,
+                    target,
+                    source_ir,
+                    f"{name}_reshape_{i}_{j}",
+                    slice,
+                    (*slice.shape, 1),
+                )
+                idx_slice_map[j] = slice
+
+            slices.append(slice)
+
+        slices = impl.cat.cat(
+            ctx, target, source_ir, f"{name}_slices_cat_{i}", slices, dim=-1
+        )
+        # calculate the mean of the slices (the average pooling output) and append it to the output list
+        output_list.append(
+            impl.reduce.mean(
+                ctx, target, source_ir, f"{name}_sum_{i}", slices, dim=-1, keepdim=True
+            )
+        )
+
+    output = impl.cat.cat(ctx, target, source_ir, f"{name}_cat", output_list, dim=-1)
+    return output
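
Since each output element may pool a window of a different width, the converter builds the op from per-index select, reshape, cat, and mean layers rather than a single fixed-kernel TensorRT pooling layer. A minimal sketch of the same window arithmetic in plain PyTorch, checked against torch.nn.functional.adaptive_avg_pool1d (the names and shapes here are illustrative, not part of the commit):

```python
import math

import torch
import torch.nn.functional as F

def start_index(idx: int, out_dim: int, in_dim: int) -> int:
    # first input position pooled into output element idx
    return math.floor(idx * in_dim / out_dim)

def end_index(idx: int, out_dim: int, in_dim: int) -> int:
    # one past the last input position pooled into output element idx
    return math.ceil((idx + 1) * in_dim / out_dim)

x = torch.randn(2, 3, 10)
out_dim, in_dim = 4, x.shape[-1]
# mean over each (possibly uneven) window, then stack along the last dim
windows = [
    x[..., start_index(i, out_dim, in_dim) : end_index(i, out_dim, in_dim)].mean(-1)
    for i in range(out_dim)
]
manual = torch.stack(windows, dim=-1)
assert torch.allclose(manual, F.adaptive_avg_pool1d(x, out_dim))
```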

tests/py/dynamo/conversion/test_adaptive_avgpool_aten.py

Lines changed: 57 additions & 82 deletions
@@ -9,102 +9,77 @@
 class TestAdaptiveAvgPoolConverter(DispatchTestCase):
     @parameterized.expand(
         [
-            ((64, 64),),
-            ((128, 64),),
-            # (64,), This case has been there in previous code but it isn't a valid pytorch code.
-        ]
-    )
-    def test_adaptive_avgpool(
-        self,
-        output_size,
-    ):
-        class TestModule(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.pool = torch.nn.AdaptiveAvgPool2d(output_size)
-
-            def forward(self, x):
-                return self.pool(x)
-
-        inputs = [torch.randn(1, 3, 256, 256)]
-        self.run_test(
-            TestModule(),
-            inputs,
-            use_dynamo_tracer=True,
-        )
-
-    def test_adaptive_avgpool_with_dynamic_shape(self):
-        class TestModule(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.pool = torch.nn.AdaptiveAvgPool2d((64, 64))
-
-            def forward(self, x):
-                return self.pool(x)
-
-        input_specs = [
-            Input(
-                shape=(-1, -1, 256, 256),
-                dtype=torch.float32,
-                shape_ranges=[((1, 1, 256, 256), (3, 3, 256, 256), (5, 5, 256, 256))],
+            (
+                (2, 3),
+                2,
+            ),
+            (
+                (2, 8),
+                8,
+            ),
+            (
+                (1, 2, 3),
+                2,
+            ),
+            (
+                (2, 2, 8),
+                16,
+            ),
+            (
+                (2, 3),
+                (1,),
+            ),
+            (
+                (2, 3),
+                (2,),
+            ),
+            (
+                (2, 8),
+                (4,),
+            ),
+            (
+                (2, 8),
+                (16,),
+            ),
+            (
+                (2, 3, 1),
+                (1,),
+            ),
+            (
+                (2, 3, 2),
+                (2,),
+            ),
+            (
+                (2, 3, 4),
+                (4,),
+            ),
+            (
+                (2, 2, 32),
+                (31,),
+            ),
+            (
+                (2, 2, 32),
+                (64,),
             ),
-        ]
-        self.run_test_with_dynamic_shape(
-            TestModule(), input_specs, use_dynamo_tracer=True
-        )
-
-    @parameterized.expand(
-        [
-            ((16, 16, 16),),
-            ((32, 16, 4),),
-            (32,),
         ]
     )
-    def test_adaptive_avgpool3d(
+    def test_adaptive_avg_pool1d(
         self,
+        input_shape,
         output_size,
     ):
         class TestModule(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.pool = torch.nn.AdaptiveAvgPool3d(output_size)
-
             def forward(self, x):
-                return self.pool(x)
+                return torch.ops.aten.adaptive_avg_pool1d.default(x, output_size)
 
-        inputs = [torch.randn(1, 3, 32, 64, 64)]
+        inputs = [torch.randn(input_shape)]
         self.run_test(
             TestModule(),
             inputs,
-            use_dynamo_tracer=True,
+            # use_dynamo_tracer=True,
+            enable_passes=True,
         )
 
-    def test_adaptive_avgpool3d_with_dynamic_shape(self):
-        class TestModule(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.pool = torch.nn.AdaptiveAvgPool3d((16, 16, 16))
-
-            def forward(self, x):
-                return self.pool(x)
-
-        input_specs = [
-            Input(
-                shape=(-1, -1, 32, 64, 64),
-                dtype=torch.float32,
-                shape_ranges=[
-                    ((1, 1, 32, 64, 64), (3, 3, 32, 64, 64), (5, 5, 32, 64, 64))
-                ],
-            ),
-        ]
-        self.run_test_with_dynamic_shape(
-            TestModule(),
-            input_specs,
-            use_dynamo_tracer=True,
-        )
-
-    # Testing with shape(-1, -1, -1, -1) results into error: "AdaptiveAvgPool2d and AdaptiveAvgPool3d currently doesn't support dynamic shapes for last two dims."
-
 
 if __name__ == "__main__":
     run_tests()
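
For context, a hypothetical end-to-end sketch (not part of this commit) of exercising the new converter through the dynamo frontend; torch_tensorrt.compile with ir="dynamo" is the public API, while the module and shapes below are made up for illustration:

```python
import torch
import torch_tensorrt

class Pool1d(torch.nn.Module):
    def forward(self, x):
        # expected to lower to torch.ops.aten.adaptive_avg_pool1d.default
        return torch.nn.functional.adaptive_avg_pool1d(x, 4)

model = Pool1d().eval().cuda()
inputs = [torch.randn(2, 3, 16).cuda()]
trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
print(trt_model(*inputs).shape)  # torch.Size([2, 3, 4])
```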
