Skip to content

Commit 891c2ef

Browse files
authored
feat: support 1D, 2D, and 3D avg and max pooling dynamo converters (#2317)
1 parent 0d402fb commit 891c2ef

File tree

4 files changed

+428
-0
lines changed

4 files changed

+428
-0
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1413,3 +1413,90 @@ def aten_ops_linear(
14131413
weight=args[1],
14141414
bias=args_bounds_check(args, 2, None),
14151415
)
1416+
1417+
1418+
def avg_pool_param_validator(pool_node: Node) -> bool:
    """Capability check for avg_pool converters.

    Rejects (returns False for) any avg_pool node that specifies ceil_mode
    or divisor_override, since the converter does not support them yet;
    such nodes fall back to the default runtime.
    """
    ceil_mode = args_bounds_check(pool_node.args, 4, False)
    if ceil_mode is not False:
        _LOGGER.debug(
            f"Currently we don't support specifying ceil_mode, got ceil_mode={ceil_mode}."
        )
        return False

    divisor_override = args_bounds_check(pool_node.args, 6)
    if divisor_override is not None:
        _LOGGER.debug(
            f"Currently we don't support divisor_override, got divisor_override={divisor_override}."
        )
        return False

    return True
1436+
1437+
# Note: AvgPool1d uses avg_pool2d as it converts to 2D first.
@dynamo_tensorrt_converter(torch.ops.aten.avg_pool1d.default, capability_validator=avg_pool_param_validator)  # type: ignore[misc]
@dynamo_tensorrt_converter(torch.ops.aten.avg_pool2d.default, capability_validator=avg_pool_param_validator)  # type: ignore[misc]
@dynamo_tensorrt_converter(torch.ops.aten.avg_pool3d.default, capability_validator=avg_pool_param_validator)  # type: ignore[misc]
def aten_ops_avg_pool(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Convert aten avg_pool1d/2d/3d nodes into a TRT average-pooling layer.

    Positional args follow the aten schema: (input, kernel_size, stride,
    padding, ceil_mode, count_include_pad, divisor_override); trailing
    arguments may be absent, in which case the schema defaults are used.
    """
    pool_input = args[0]
    window_size = args[1]
    return impl.pool.avg_poolNd(
        network,
        target,
        source_ir=SourceIR.ATEN,
        name=name,
        input=pool_input,
        kernel_size=window_size,
        stride=args_bounds_check(args, 2, replacement=[]),
        padding=args_bounds_check(args, 3, replacement=0),
        ceil_mode=args_bounds_check(args, 4, replacement=False),
        count_include_pad=args_bounds_check(args, 5, replacement=True),
        divisor_override=args_bounds_check(args, 6, replacement=None),
    )
1462+
1463+
def max_pool_param_validator(pool_node: Node) -> bool:
    """Capability check for max_pool converters.

    Rejects (returns False for) any max_pool node that uses dilation or
    ceil_mode, since the converter does not support them yet; such nodes
    fall back to the default runtime.
    """
    dilation = args_bounds_check(pool_node.args, 4, 1)
    ceil_mode = args_bounds_check(pool_node.args, 5, False)

    # The aten schemas declare dilation as int[N], so traced graphs commonly
    # pass a sequence such as [1, 1] rather than the scalar default 1.
    # Treat an all-ones sequence the same as "no dilation" instead of
    # rejecting it with `dilation != 1`.
    if isinstance(dilation, (list, tuple)):
        has_dilation = any(d != 1 for d in dilation)
    else:
        has_dilation = dilation != 1

    if has_dilation:
        _LOGGER.debug(f"Currently we don't support dilation, got dilation={dilation}.")
        return False

    if ceil_mode is not False:
        _LOGGER.debug(
            f"Currently we don't support specifying ceil_mode, got ceil_mode={ceil_mode}."
        )
        return False

    return True
1480+
# Note: MaxPool1d uses max_pool2d as it converts to 2D first.
@dynamo_tensorrt_converter(torch.ops.aten.max_pool1d.default, capability_validator=max_pool_param_validator)  # type: ignore[misc]
@dynamo_tensorrt_converter(torch.ops.aten.max_pool2d.default, capability_validator=max_pool_param_validator)  # type: ignore[misc]
@dynamo_tensorrt_converter(torch.ops.aten.max_pool3d.default, capability_validator=max_pool_param_validator)  # type: ignore[misc]
def aten_ops_max_pool(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Convert aten max_pool1d/2d/3d nodes into a TRT max-pooling layer.

    Positional args follow the aten schema: (input, kernel_size, stride,
    padding, dilation, ceil_mode); trailing arguments may be absent, in
    which case the schema defaults are used.
    """
    pool_input = args[0]
    window_size = args[1]
    return impl.pool.max_poolNd(
        network,
        target,
        source_ir=SourceIR.ATEN,
        name=name,
        input=pool_input,
        kernel_size=window_size,
        stride=args_bounds_check(args, 2, replacement=[]),
        padding=args_bounds_check(args, 3, replacement=0),
        dilation=args_bounds_check(args, 4, replacement=1),
        ceil_mode=args_bounds_check(args, 5, replacement=False),
    )

py/torch_tensorrt/dynamo/conversion/impl/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
matmul,
1313
normalization,
1414
permutation,
15+
pool,
1516
reduce,
1617
select,
1718
shape,
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
from typing import Optional, Sequence, Union
2+
3+
import tensorrt as trt
4+
from torch.fx.node import Target
5+
from torch_tensorrt.dynamo._SourceIR import SourceIR
6+
from torch_tensorrt.dynamo.conversion.converter_utils import extend_attr_to_tuple
7+
from torch_tensorrt.fx.converters.converter_utils import (
8+
has_dynamic_shape,
9+
set_layer_name,
10+
)
11+
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
12+
13+
14+
def avg_poolNd(
    network: TRTNetwork,
    target: Union[Target, str],
    source_ir: Optional[SourceIR],
    name: str,
    input: TRTTensor,
    kernel_size: Sequence[int],
    stride: Union[int, Sequence[int]],
    padding: Union[int, Sequence[int]] = 0,
    ceil_mode: bool = False,
    count_include_pad: bool = True,
    divisor_override: Optional[int] = None,
) -> TRTTensor:
    """Add an N-dimensional average-pooling layer to the TRT network.

    The pooling dimensionality is taken from len(kernel_size). An empty
    stride (``[]``) means "stride equals kernel size", matching the aten
    default. Raises RuntimeError for ceil_mode / divisor_override, which
    are not supported yet.

    Returns the TRT tensor produced by the pooling layer.
    """
    if has_dynamic_shape(input.shape):
        # TRT pooling requires a static channel dimension.
        assert input.shape[1] != -1, "Channel dim can't be dynamic for pooling."

    if ceil_mode is not False:
        raise RuntimeError("ceil_mode is not yet supported!")

    if divisor_override is not None:
        raise RuntimeError("divisor_override is not yet supported!")

    dim = len(kernel_size)
    kernel_size = extend_attr_to_tuple(kernel_size, dim)
    # Empty stride means the window advances by its own size (aten default).
    stride = kernel_size if stride == [] else extend_attr_to_tuple(stride, dim)
    padding = extend_attr_to_tuple(padding, dim)

    pool_layer = network.add_pooling_nd(
        input=input,
        type=trt.PoolingType.AVERAGE,
        window_size=kernel_size,
    )
    pool_layer.stride_nd = stride
    pool_layer.padding_nd = padding
    # TRT flag is the inverse of PyTorch's count_include_pad.
    pool_layer.average_count_excludes_padding = not count_include_pad

    set_layer_name(pool_layer, target, name, source_ir)
    return pool_layer.get_output(0)
60+
61+
62+
def max_poolNd(
    network: TRTNetwork,
    target: Union[Target, str],
    source_ir: Optional[SourceIR],
    name: str,
    input: TRTTensor,
    kernel_size: Sequence[int],
    stride: Union[int, Sequence[int]],
    padding: Union[int, Sequence[int]] = 0,
    dilation: Union[int, Sequence[int]] = 1,
    ceil_mode: bool = False,
) -> TRTTensor:
    """Add an N-dimensional max-pooling layer to the TRT network.

    The pooling dimensionality is taken from len(kernel_size). An empty
    stride (``[]``) means "stride equals kernel size", matching the aten
    default. Raises RuntimeError for dilation / ceil_mode, which are not
    supported yet.

    Returns the TRT tensor produced by the pooling layer.
    """
    if has_dynamic_shape(input.shape):
        # TRT pooling requires a static channel dimension.
        assert input.shape[1] != -1, "Channel dim can't be dynamic for pooling."

    # The aten schemas declare dilation as int[N], so traced graphs commonly
    # pass a sequence such as [1, 1] rather than the scalar default 1.
    # Treat an all-ones sequence the same as "no dilation" instead of
    # rejecting it with `dilation != 1`.
    if isinstance(dilation, (list, tuple)):
        has_dilation = any(d != 1 for d in dilation)
    else:
        has_dilation = dilation != 1
    if has_dilation:
        raise RuntimeError("dilation is not yet supported!")

    if ceil_mode is not False:
        raise RuntimeError("ceil_mode is not yet supported!")

    dim = len(kernel_size)
    kernel_size = extend_attr_to_tuple(kernel_size, dim)
    # Empty stride means the window advances by its own size (aten default).
    stride = kernel_size if stride == [] else extend_attr_to_tuple(stride, dim)
    padding = extend_attr_to_tuple(padding, dim)

    pool_layer = network.add_pooling_nd(
        input=input,
        type=trt.PoolingType.MAX,
        window_size=kernel_size,
    )
    pool_layer.stride_nd = stride
    pool_layer.padding_nd = padding

    set_layer_name(pool_layer, target, name, source_ir)
    return pool_layer.get_output(0)

0 commit comments

Comments
 (0)