Skip to content

Commit e66cdaf

Browse files
authored
[Xnnpack] Accept default padding value for torch.constant_pad_nd
Differential Revision: D67756862 Pull Request resolved: #7469
1 parent c86b39d commit e66cdaf

File tree

2 files changed

+38
-1
lines changed

2 files changed

+38
-1
lines changed

backends/xnnpack/operators/op_static_constant_pad.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,11 +116,14 @@ def define_node(
116116
pre_paddings = all_paddings[-2::-2] # even index elements in reverse order
117117
post_paddings = all_paddings[::-2] # odd index elements in reverse order
118118

119+
# the padding value, which defaults to 0.0
120+
padding_value = cast(float, node.args[2]) if len(node.args) > 2 else 0.0
121+
119122
ser_node = XNode(
120123
xnode_union=XNNStaticConstantPad(
121124
pre_paddings=pre_paddings,
122125
post_paddings=post_paddings,
123-
padding_value=cast(float, node.args[2]),
126+
padding_value=padding_value,
124127
input_id=input_id,
125128
output_id=output_id,
126129
flags=0,

backends/xnnpack/test/ops/test_static_constant_pad.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,40 @@ def test_fp32_static_constant_pad_functional(self):
114114
)
115115
self._test_static_constant_pad_functional(inputs)
116116

117+
def test_constant_pad_nd(self):
118+
class ConstantPad(torch.nn.Module):
119+
def __init__(self):
120+
super().__init__()
121+
122+
def forward(self, x, y, z):
123+
pad_6 = (1, 2, 3, 4, 5, 6)
124+
pad_4 = (1, 2, 3, 4)
125+
pad_2 = (1, 2)
126+
a = torch.constant_pad_nd(input=x, pad=pad_6)
127+
b = torch.constant_pad_nd(input=y, pad=pad_4)
128+
c = torch.constant_pad_nd(input=z, pad=pad_2)
129+
130+
return (a + a, b + b, c + c)
131+
132+
inputs = (
133+
torch.randn(size=(5, 4, 3, 2)),
134+
torch.randn(size=(5, 3, 2)),
135+
torch.randn(size=(4, 3)),
136+
)
137+
(
138+
Tester(ConstantPad(), inputs)
139+
.export()
140+
.check_count({"torch.ops.aten.constant_pad_nd.default": 3})
141+
.to_edge_transform_and_lower()
142+
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
143+
.check_not(
144+
["executorch_exir_dialects_edge__ops_aten_constant_pad_nd_default"]
145+
)
146+
.to_executorch()
147+
.serialize()
148+
.run_method_and_compare_outputs()
149+
)
150+
117151
def test_qs8_static_constant_pad_functional(self):
118152
class Pad(torch.nn.Module):
119153
def __init__(self):

0 commit comments

Comments
 (0)