Skip to content

conv1d fix #411

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 17 additions & 5 deletions backends/xnnpack/passes/conv1d_unsqueeze_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,23 @@ def call(self, graph_module: torch.fx.GraphModule):
kernel_param_4d = torch.nn.Parameter(
data=kernel_param_3d.data.contiguous().unsqueeze(dim=-1)
)
setattr(
kernel_node.graph.owning_module,
kernel_node.target,
kernel_param_4d,
)

if torch._export.utils.is_param(self.exported_program, kernel_node):
parameter_name = (
self.exported_program.graph_signature.inputs_to_parameters[
kernel_node.name
]
)
self.exported_program.state_dict[
parameter_name
] = kernel_param_4d
kernel_node.meta["val"] = kernel_param_4d.data.contiguous()
else:
setattr(
kernel_node.graph.owning_module,
kernel_node.target,
kernel_param_4d,
)

# (b) Extend stride, padding, and dilation for extra dim
node.args = (
Expand Down
1 change: 1 addition & 0 deletions backends/xnnpack/test/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ python_unittest(
name = "test_xnnpack_ops",
srcs = [
"ops/add.py",
"ops/conv1d.py",
],
deps = [
"//caffe2:torch",
Expand Down
58 changes: 58 additions & 0 deletions backends/xnnpack/test/ops/conv1d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch

from executorch.backends.xnnpack.test.tester import Tester


class TestConv1d(unittest.TestCase):
class Conv1d(torch.nn.Module):
def __init__(self):
groups = 1
stride = [2]
padding = [1]
dilation = [1]
in_channels = 2
out_channels = 1
kernel_size = (3,)

super().__init__()

self.conv1d = torch.nn.Conv1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
groups=groups,
dilation=dilation,
bias=True,
)

def forward(self, x):
return self.conv1d(x)

def test_conv1d(self):
inputs = (torch.randn(1, 2, 4),)
(
Tester(self.Conv1d(), inputs)
.export()
.check_count({"torch.ops.aten.convolution.default": 1})
.to_edge()
.check_count(
{"executorch_exir_dialects_edge__ops_aten_convolution_default": 1}
)
.partition()
.check_not(["executorch_exir_dialects_edge__ops_aten_convolution_default"])
.check_count({"torch.ops.executorch_call_delegate": 1})
.to_executorch()
.serialize()
.run_method()
.compare_outputs()
)
2 changes: 1 addition & 1 deletion examples/backend/xnnpack_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@
edge = export_to_edge(
model,
example_inputs,
capture_config=CaptureConfig(enable_aot=True, _unlift=True),
capture_config=CaptureConfig(enable_aot=True),
edge_compile_config=EdgeCompileConfig(
_check_ir_validity=False if args.quantize else True,
),
Expand Down