Skip to content

Commit 0dfd1c0

Browse files
authored
Do not offload transposed convolutions with non-zero output padding
Differential Revision: D69578428 Pull Request resolved: #8462
1 parent 556a1a2 commit 0dfd1c0

File tree

2 files changed

+51
-0
lines changed

2 files changed

+51
-0
lines changed

backends/xnnpack/partition/config/gemm_configs.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -334,6 +334,18 @@ def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
334334
is_transpose = node.args[6]
335335
groups = cast(int, node.args[8])
336336

337+
# XNNPack does not support non-zero output padding in transposed
338+
# convolutions.
339+
if is_transpose and any(
340+
out_pad != 0 for out_pad in cast(List[int], node.args[7])
341+
):
342+
why(
343+
node,
344+
"XNNPACK does not support transposed convolutions with "
345+
"non-zero output padding",
346+
)
347+
return False
348+
337349
if (
338350
is_transpose
339351
and weight_quant_params is not None

backends/xnnpack/test/ops/test_conv2d.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -657,3 +657,42 @@ def get_inputs(self):
657657
quant_config=None,
658658
conv_count=1,
659659
)
660+
661+
def test_padded_output_tconv(self):
    """Verify that a transposed conv with non-zero output_padding is NOT
    delegated to XNNPACK.

    XNNPACK does not support non-zero output padding in transposed
    convolutions, so the partitioner must leave the op in the edge dialect
    (aten.convolution.default) instead of wrapping it in a delegate call.
    The model still runs via the portable path and its outputs are compared
    against eager mode.
    """

    class TConv2d(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # output_padding=(0, 1) is the key property: any non-zero
            # component must block XNNPACK offload.
            self.conv = torch.nn.ConvTranspose2d(
                in_channels=2,
                out_channels=1,
                kernel_size=(3, 3),
                stride=(2, 2),
                padding=(1, 1),
                output_padding=(0, 1),
                dilation=(1, 1),
                groups=1,
                bias=True,
            ).to(torch.float)

        def forward(self, x):
            return self.conv(x)

    m = TConv2d()
    inputs = (torch.randn(1, 2, 8, 8),)
    tester = Tester(m.eval(), inputs)

    conv_count: int = 1
    op = "torch.ops.aten.conv_transpose2d"

    # Export, confirm exactly one tconv op, then run partitioning/lowering.
    (tester.export().check_count({op: conv_count}).to_edge_transform_and_lower())

    # tconv should not be offloaded to XNNPACK, since non-zero output
    # padding is unsupported there: the op must remain as an edge-dialect
    # aten.convolution with no delegate call emitted.
    (
        tester.check(
            ["executorch_exir_dialects_edge__ops_aten_convolution_default"]
        )
        .check_not(["torch.ops.higher_order.executorch_call_delegate"])
        .to_executorch()
        .serialize()
        .run_method_and_compare_outputs(qtol=1)
    )

0 commit comments

Comments
 (0)