Commit a694541

pssrawat authored and facebook-github-bot committed
Accept default padding value for torch.constant_pad_nd (#7469)
Summary: The XNNPACK delegation for the pad op assumes that the pad value is always present. However, torch.constant_pad_nd defaults to a padding value of 0.0 when the value is not given in the op. When the arg is absent, delegation fails at `padding_value = cast(float, node.args[2])` with `IndexError: tuple index out of range`. This diff defaults the padding value to 0.0 if the arg is absent from the torch.constant_pad_nd op.

Differential Revision: D67756862
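For context, a minimal sketch of the behavior at issue (the tensors here are illustrative, not from the commit): torch.constant_pad_nd accepts an optional fill value that defaults to 0.0, so an exported node may carry only two args.

```python
import torch

x = torch.ones(2, 2)

# Explicit fill value: the exported node's args include all three entries.
a = torch.constant_pad_nd(x, (1, 1), 5.0)

# Omitted fill value: the op pads with 0.0, and the exported node's args
# tuple holds only (input, pad) -- indexing args[2] then raises IndexError.
b = torch.constant_pad_nd(x, (1, 1))
assert torch.equal(b, torch.constant_pad_nd(x, (1, 1), 0.0))
```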
1 parent: 3ef78ee

File tree: 2 files changed (+47 −1 lines)

backends/xnnpack/operators/op_static_constant_pad.py
Lines changed: 4 additions & 1 deletion

@@ -116,11 +116,14 @@ def define_node(
         pre_paddings = all_paddings[-2::-2]  # even index elements in reverse order
         post_paddings = all_paddings[::-2]  # odd index elements in reverse order
 
+        # the padding value, which defaults to 0.0
+        padding_value = cast(float, node.args[2]) if len(node.args) > 2 else 0.0
+
         ser_node = XNode(
             xnode_union=XNNStaticConstantPad(
                 pre_paddings=pre_paddings,
                 post_paddings=post_paddings,
-                padding_value=cast(float, node.args[2]),
+                padding_value=padding_value,
                 input_id=input_id,
                 output_id=output_id,
                 flags=0,
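As a side note, a small illustration of the slicing in the context lines above, using hypothetical pad values not taken from the diff: constant_pad_nd orders paddings from the last dimension inward, and the two reverse-stride slices split them into the outermost-first pre/post lists XNNPACK expects.

```python
# Hypothetical pad tuple for a 3-d pad: (before, after) pairs for dims -1, -2, -3.
all_paddings = [1, 2, 3, 4, 5, 6]

pre_paddings = all_paddings[-2::-2]   # even-index elements reversed -> [5, 3, 1]
post_paddings = all_paddings[::-2]    # odd-index elements reversed  -> [6, 4, 2]
```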

backends/xnnpack/test/ops/test_static_constant_pad.py
Lines changed: 43 additions & 0 deletions

@@ -114,6 +114,49 @@ def test_fp32_static_constant_pad_functional(self):
         )
         self._test_static_constant_pad_functional(inputs)
 
+    def test_constant_pad_nd(self):
+        class ConstantPad(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x, y, z):
+                pad_6 = (1, 2, 3, 4, 5, 6)
+                pad_4 = (1, 2, 3, 4)
+                pad_2 = (1, 2)
+                a = torch.constant_pad_nd(
+                    input=x,
+                    pad=pad_6
+                )
+                b = torch.constant_pad_nd(
+                    input=y,
+                    pad=pad_4
+                )
+                c = torch.constant_pad_nd(
+                    input=z,
+                    pad=pad_2
+                )
+
+                return (a + a, b + b, c + c)
+
+        inputs = (
+            torch.randn(size=(5, 4, 3, 2)),
+            torch.randn(size=(5, 3, 2)),
+            torch.randn(size=(4, 3)),
+        )
+        (
+            Tester(ConstantPad(), inputs)
+            .export()
+            .check_count({"torch.ops.aten.constant_pad_nd.default": 3})
+            .to_edge_transform_and_lower()
+            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
+            .check_not(
+                ["executorch_exir_dialects_edge__ops_aten_constant_pad_nd_default"]
+            )
+            .to_executorch()
+            .serialize()
+            .run_method_and_compare_outputs()
+        )
+
     def test_qs8_static_constant_pad_functional(self):
         class Pad(torch.nn.Module):
             def __init__(self):
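The new test exercises the absent-value path end to end: it exports three constant_pad_nd calls with no fill value, checks that all three lower into a single XNNPACK delegate with no leftover edge-dialect pad ops, and compares serialized-runtime outputs against eager. Assuming these tests are invoked with the standard pytest runner (the commit does not say), it can be run on its own with `python -m pytest backends/xnnpack/test/ops/test_static_constant_pad.py -k test_constant_pad_nd`.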
