feat: support adaptive_avg_pool1d dynamo converter #2614


Merged
18 changes: 18 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -2184,6 +2184,24 @@ def aten_ops_avg_pool(
     )
 
 
+@dynamo_tensorrt_converter(torch.ops.aten.adaptive_avg_pool1d.default)
+def aten_ops_adaptive_avg_pool(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.pool.adaptive_avg_pool1d(
+        ctx,
+        target,
+        source_ir=SourceIR.ATEN,
+        name=name,
+        input=args[0],
+        output_size=args[1],
+    )
+
+
 def max_pool_param_validator(pool_node: Node) -> bool:
     dilation = args_bounds_check(pool_node.args, 4, 1)
     ceil_mode = args_bounds_check(pool_node.args, 5, False)
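The converter above only dispatches to the shared implementation in `impl/pool.py` below; what it must reproduce is the eager behavior of the ATen op. A minimal eager-mode sketch for reference (the shapes here are illustrative, not taken from the PR):

```python
import torch

# adaptive_avg_pool1d picks window sizes so that any input length maps to the
# requested output length; here an input of length 8 is pooled down to 3.
x = torch.arange(8, dtype=torch.float32).reshape(1, 1, 8)
out = torch.ops.aten.adaptive_avg_pool1d.default(x, [3])
print(out.shape)  # torch.Size([1, 1, 3])

# the module-level API computes the same thing
assert torch.allclose(torch.nn.AdaptiveAvgPool1d(3)(x), out)
```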
67 changes: 66 additions & 1 deletion py/torch_tensorrt/dynamo/conversion/impl/pool.py
@@ -1,6 +1,8 @@
-from typing import Optional, Sequence, Union
+import math
+from typing import Dict, Optional, Sequence, Union
 
 import tensorrt as trt
+import torch_tensorrt.dynamo.conversion.impl as impl
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
@@ -104,3 +106,66 @@ def max_poolNd(
 
     set_layer_name(pool_layer, target, name, source_ir)
     return pool_layer.get_output(0)
+
+
+def adaptive_avg_pool1d(
+    ctx: ConversionContext,
+    target: Union[Target, str],
+    source_ir: Optional[SourceIR],
+    name: str,
+    input: TRTTensor,
+    output_size: Union[int, Sequence[int]],
+) -> TRTTensor:
+    def start_index(idx: int, out_dim: int, in_dim: int) -> int:
+        """Calculate the start index of each pooling window"""
+        return math.floor((float(idx) * float(in_dim)) / out_dim)
+
+    def end_index(idx: int, out_dim: int, in_dim: int) -> int:
+        """Calculate the end index of each pooling window"""
+        return math.ceil((float(idx + 1) * float(in_dim)) / out_dim)
+
+    in_dim = input.shape[-1]
+    out_dim = output_size if isinstance(output_size, int) else output_size[0]
+    output_list = []
+
+    # cache {index: slice} to avoid repeating select/reshape ops for overlapping windows
+    idx_slice_map: Dict[int, TRTTensor] = {}
+    # iterate over each output position
+    for i in range(out_dim):
+        # compute the start and end index of the pooling window for this position
+        start = start_index(i, out_dim, in_dim)
+        end = end_index(i, out_dim, in_dim)
+
+        # gather the input elements from start to end; together they form the pooling window
+        slices = []
+        for j in range(start, end):
+            if j in idx_slice_map:
+                slice = idx_slice_map[j]
+            else:
+                slice = impl.select.select(
+                    ctx, target, source_ir, f"{name}_select_{j}", input, -1, j
+                )
+                slice = impl.shuffle.reshape(
+                    ctx,
+                    target,
+                    source_ir,
+                    f"{name}_reshape_{i}_{j}",
+                    slice,
+                    (*slice.shape, 1),
+                )
+                idx_slice_map[j] = slice
+
+            slices.append(slice)
+
+        slices = impl.cat.cat(
+            ctx, target, source_ir, f"{name}_slices_cat_{i}", slices, dim=-1
+        )
+        # take the mean over the window (the average-pooling output) and append it
+        output_list.append(
+            impl.reduce.mean(
+                ctx, target, source_ir, f"{name}_mean_{i}", slices, dim=-1, keepdim=True
+            )
+        )
+
+    output = impl.cat.cat(ctx, target, source_ir, f"{name}_cat", output_list, dim=-1)
+    return output
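The `start_index`/`end_index` helpers implement the window bounds PyTorch uses for adaptive pooling: output position `i` averages input indices `[floor(i * L / out_dim), ceil((i + 1) * L / out_dim))`, so adjacent windows may overlap, which is exactly what the `idx_slice_map` cache exploits. A plain-PyTorch sketch of the same decomposition, for reference only (not part of the PR):

```python
import math
import torch

def adaptive_avg_pool1d_ref(x: torch.Tensor, out_dim: int) -> torch.Tensor:
    # mirrors the converter: one mean per output position over the window
    # [floor(i * L / out_dim), ceil((i + 1) * L / out_dim))
    L = x.shape[-1]
    cols = []
    for i in range(out_dim):
        start = math.floor(i * L / out_dim)
        end = math.ceil((i + 1) * L / out_dim)
        cols.append(x[..., start:end].mean(dim=-1, keepdim=True))
    return torch.cat(cols, dim=-1)

# e.g. L=5, out_dim=3 -> windows [0,2), [1,4), [3,5): indices 1 and 3 are reused
x = torch.randn(2, 3, 5)
assert torch.allclose(
    adaptive_avg_pool1d_ref(x, 3),
    torch.nn.functional.adaptive_avg_pool1d(x, 3),
)
```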
139 changes: 57 additions & 82 deletions tests/py/dynamo/conversion/test_adaptive_avgpool_aten.py
@@ -9,102 +9,77 @@
 class TestAdaptiveAvgPoolConverter(DispatchTestCase):
     @parameterized.expand(
         [
-            ((64, 64),),
-            ((128, 64),),
-            # (64,), This case was in the previous code, but it is not valid PyTorch code.
-        ]
-    )
-    def test_adaptive_avgpool(
-        self,
-        output_size,
-    ):
-        class TestModule(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.pool = torch.nn.AdaptiveAvgPool2d(output_size)
-
-            def forward(self, x):
-                return self.pool(x)
-
-        inputs = [torch.randn(1, 3, 256, 256)]
-        self.run_test(
-            TestModule(),
-            inputs,
-            use_dynamo_tracer=True,
-        )
-
-    def test_adaptive_avgpool_with_dynamic_shape(self):
-        class TestModule(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.pool = torch.nn.AdaptiveAvgPool2d((64, 64))
-
-            def forward(self, x):
-                return self.pool(x)
-
-        input_specs = [
-            Input(
-                shape=(-1, -1, 256, 256),
-                dtype=torch.float32,
-                shape_ranges=[((1, 1, 256, 256), (3, 3, 256, 256), (5, 5, 256, 256))],
-            ),
-        ]
-        self.run_test_with_dynamic_shape(
-            TestModule(), input_specs, use_dynamo_tracer=True
-        )
-
-    @parameterized.expand(
-        [
-            ((16, 16, 16),),
-            ((32, 16, 4),),
-            (32,),
+            (
+                (2, 3),
+                2,
+            ),
+            (
+                (2, 8),
+                8,
+            ),
+            (
+                (1, 2, 3),
+                2,
+            ),
+            (
+                (2, 2, 8),
+                16,
+            ),
+            (
+                (2, 3),
+                (1,),
+            ),
+            (
+                (2, 3),
+                (2,),
+            ),
+            (
+                (2, 8),
+                (4,),
+            ),
+            (
+                (2, 8),
+                (16,),
+            ),
+            (
+                (2, 3, 1),
+                (1,),
+            ),
+            (
+                (2, 3, 2),
+                (2,),
+            ),
+            (
+                (2, 3, 4),
+                (4,),
+            ),
+            (
+                (2, 2, 32),
+                (31,),
+            ),
+            (
+                (2, 2, 32),
+                (64,),
+            ),
         ]
     )
-    def test_adaptive_avgpool3d(
+    def test_adaptive_avg_pool1d(
         self,
+        input_shape,
         output_size,
     ):
         class TestModule(torch.nn.Module):
             def __init__(self):
                 super().__init__()
-                self.pool = torch.nn.AdaptiveAvgPool3d(output_size)
 
             def forward(self, x):
-                return self.pool(x)
+                return torch.ops.aten.adaptive_avg_pool1d.default(x, output_size)
 
-        inputs = [torch.randn(1, 3, 32, 64, 64)]
+        inputs = [torch.randn(input_shape)]
         self.run_test(
             TestModule(),
             inputs,
-            use_dynamo_tracer=True,
+            # use_dynamo_tracer=True,
+            enable_passes=True,
         )
 
-    def test_adaptive_avgpool3d_with_dynamic_shape(self):
-        class TestModule(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.pool = torch.nn.AdaptiveAvgPool3d((16, 16, 16))
-
-            def forward(self, x):
-                return self.pool(x)
-
-        input_specs = [
-            Input(
-                shape=(-1, -1, 32, 64, 64),
-                dtype=torch.float32,
-                shape_ranges=[
-                    ((1, 1, 32, 64, 64), (3, 3, 32, 64, 64), (5, 5, 32, 64, 64))
-                ],
-            ),
-        ]
-        self.run_test_with_dynamic_shape(
-            TestModule(),
-            input_specs,
-            use_dynamo_tracer=True,
-        )
-
-    # Testing with shape (-1, -1, -1, -1) results in the error: "AdaptiveAvgPool2d and AdaptiveAvgPool3d currently doesn't support dynamic shapes for last two dims."
-
 
 if __name__ == "__main__":
     run_tests()
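Beyond the unit tests, the new converter can be exercised end to end through the dynamo path. A rough sketch of such a check; the module, shapes, and compile options here are illustrative assumptions, not taken from the PR:

```python
import torch
import torch_tensorrt

class Pool1d(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.adaptive_avg_pool1d(x, 4)

model = Pool1d().eval().cuda()
inputs = [torch.randn(2, 3, 32).cuda()]

# ir="dynamo" routes compilation through the dynamo converter registry, where
# aten.adaptive_avg_pool1d.default is now supported; min_block_size=1 lets this
# single-op graph be converted to TensorRT rather than left in PyTorch
trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs, min_block_size=1)
print(torch.max(torch.abs(trt_model(*inputs) - model(*inputs))))  # expect ~0
```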