Commit 7a2d885
[ExecuTorch][XNNPACK] Don't partition 3d and transposed convs (#4796)
A little bug got past the partitioner: 3d and transposed convolutions were being partitioned. We expand the scope of the ConvConfig's check_constraints to also fail when the convolution is either transposed or 3d.

Co-authored-by: Max Ren <[email protected]>

Pull Request resolved: #4763
1 parent 4a27a53 commit 7a2d885

1 file changed: +18 −1 lines changed

backends/xnnpack/partition/config/gemm_configs.py

Lines changed: 18 additions & 1 deletion

@@ -5,7 +5,7 @@
 # LICENSE file in the root directory of this source tree.

 from itertools import chain
-from typing import List, Optional, Tuple
+from typing import cast, List, Optional, Tuple

 import torch
 from executorch.backends.xnnpack.partition.config.xnnpack_config import (
@@ -267,6 +267,23 @@ def __init__(self):
             fused_acts=["relu.default", "hardtanh.default"],
         )

+    def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
+        """
+        Currently we have no support for convolution 3d and transposed convolution
+        """
+        if not super().check_constraints(node, ep):
+            return False
+
+        conv_stride = cast(List[int], node.args[3])
+        if len(conv_stride) > 2:
+            return False  # Only support 1D + 2D Conv
+
+        transposed = cast(bool, node.args[6])
+        if transposed:
+            return False  # Currently don't support transposed conv
+
+        return True
+
     def supported_precision_types(self):
         return [
             ConfigPrecisionType.FP32,
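
For context, after decomposition a convolution appears in the graph as aten.convolution.default, whose schema is convolution(input, weight, bias, stride, padding, dilation, transposed, output_padding, groups); that is why the check reads node.args[3] (the stride list) and node.args[6] (the transposed flag). Below is a minimal standalone sketch of the same constraint logic; is_supported_conv is a hypothetical helper written for illustration and is not part of the commit.

from typing import List, cast

import torch


def is_supported_conv(node: torch.fx.Node) -> bool:
    # Hypothetical helper mirroring the new constraint: reject any
    # convolution node that is 3d or transposed.
    if node.target != torch.ops.aten.convolution.default:
        return False
    stride = cast(List[int], node.args[3])  # a 3d conv carries a length-3 stride
    if len(stride) > 2:
        return False
    if cast(bool, node.args[6]):  # the transposed flag from the schema
        return False
    return True


# Quick check (exact op names can vary by PyTorch version): after decomposing
# to core ATen, an nn.Conv3d shows up as aten.convolution.default with a
# length-3 stride, so the sketch rejects it.
ep = torch.export.export(
    torch.nn.Conv3d(2, 4, 3), (torch.randn(1, 2, 8, 8, 8),)
).run_decompositions()
convs = [n for n in ep.graph.nodes if n.target == torch.ops.aten.convolution.default]
print([is_supported_conv(n) for n in convs])  # expected: [False]

A regular nn.Conv2d exported the same way would pass both checks, so it remains eligible for XNNPACK delegation.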
