Skip to content

Commit 96b83ce

Browse files
kirklandsignfacebook-github-bot
authored andcommitted
conv1d fix (#411)
Summary: Pull Request resolved: #411 Fix conv1d to handle both unlifted=True and unlifted=False. When unlifted is False, we need to update the node for the parameter. `setattr` doesn't work there. In a follow-up diff, we add that to torch._export.utils Reviewed By: digantdesai Differential Revision: D49428868 fbshipit-source-id: 74ae7afa6782720059ffac0b505aed47ad79a74c
1 parent fedc04c commit 96b83ce

File tree

4 files changed

+77
-6
lines changed

4 files changed

+77
-6
lines changed

backends/xnnpack/passes/conv1d_unsqueeze_pass.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -63,11 +63,23 @@ def call(self, graph_module: torch.fx.GraphModule):
6363
kernel_param_4d = torch.nn.Parameter(
6464
data=kernel_param_3d.data.contiguous().unsqueeze(dim=-1)
6565
)
66-
setattr(
67-
kernel_node.graph.owning_module,
68-
kernel_node.target,
69-
kernel_param_4d,
70-
)
66+
67+
if torch._export.utils.is_param(self.exported_program, kernel_node):
68+
parameter_name = (
69+
self.exported_program.graph_signature.inputs_to_parameters[
70+
kernel_node.name
71+
]
72+
)
73+
self.exported_program.state_dict[
74+
parameter_name
75+
] = kernel_param_4d
76+
kernel_node.meta["val"] = kernel_param_4d.data.contiguous()
77+
else:
78+
setattr(
79+
kernel_node.graph.owning_module,
80+
kernel_node.target,
81+
kernel_param_4d,
82+
)
7183

7284
# (b) Extend stride, padding, and dilation for extra dim
7385
node.args = (

backends/xnnpack/test/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ python_unittest(
111111
name = "test_xnnpack_ops",
112112
srcs = [
113113
"ops/add.py",
114+
"ops/conv1d.py",
114115
],
115116
deps = [
116117
"//caffe2:torch",

backends/xnnpack/test/ops/conv1d.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import unittest
8+
9+
import torch
10+
11+
from executorch.backends.xnnpack.test.tester import Tester
12+
13+
14+
class TestConv1d(unittest.TestCase):
15+
class Conv1d(torch.nn.Module):
16+
def __init__(self):
17+
groups = 1
18+
stride = [2]
19+
padding = [1]
20+
dilation = [1]
21+
in_channels = 2
22+
out_channels = 1
23+
kernel_size = (3,)
24+
25+
super().__init__()
26+
27+
self.conv1d = torch.nn.Conv1d(
28+
in_channels=in_channels,
29+
out_channels=out_channels,
30+
kernel_size=kernel_size,
31+
stride=stride,
32+
padding=padding,
33+
groups=groups,
34+
dilation=dilation,
35+
bias=True,
36+
)
37+
38+
def forward(self, x):
39+
return self.conv1d(x)
40+
41+
def test_conv1d(self):
42+
inputs = (torch.randn(1, 2, 4),)
43+
(
44+
Tester(self.Conv1d(), inputs)
45+
.export()
46+
.check_count({"torch.ops.aten.convolution.default": 1})
47+
.to_edge()
48+
.check_count(
49+
{"executorch_exir_dialects_edge__ops_aten_convolution_default": 1}
50+
)
51+
.partition()
52+
.check_not(["executorch_exir_dialects_edge__ops_aten_convolution_default"])
53+
.check_count({"torch.ops.executorch_call_delegate": 1})
54+
.to_executorch()
55+
.serialize()
56+
.run_method()
57+
.compare_outputs()
58+
)

examples/backend/xnnpack_examples.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@
8282
edge = export_to_edge(
8383
model,
8484
example_inputs,
85-
capture_config=CaptureConfig(enable_aot=True, _unlift=True),
85+
capture_config=CaptureConfig(enable_aot=True),
8686
edge_compile_config=EdgeCompileConfig(
8787
_check_ir_validity=False if args.quantize else True,
8888
),

0 commit comments

Comments
 (0)