
Dynamic Conv1d + W2L #2976


Closed · wants to merge 1 commit
33 changes: 30 additions & 3 deletions backends/xnnpack/operators/op_squeeze.py
@@ -7,7 +7,6 @@
 from typing import cast, Dict

 import torch
-from executorch.backends.transforms import get_shape
 from executorch.backends.xnnpack.operators.node_visitor import (
     NodeVisitor,
     register_node_visitor,
@@ -53,7 +52,21 @@ def define_node(
             "val" in input_node.meta,
             "Missing val in tensor metadata for input when serializing XNNStaticReshape node",
         )
-        new_shape = get_shape(input_node)[:-1]
+        dynamic_shape = node.meta["val"].shape
+        new_shape = []
+
+        num_dynamic_dims = 0
+        for dim in dynamic_shape:
+            if isinstance(dim, torch.SymInt):
+                num_dynamic_dims += 1
+                new_shape.append(0)
+            else:
+                new_shape.append(dim)
+
+        check_or_raise(
+            num_dynamic_dims <= 1,
+            "XNNPACK reshape only supports 1 dynamic dimension. This may occur when ",
+        )

         ser_node = XNode(
             xnode_union=XNNStaticReshape(
@@ -101,7 +114,21 @@ def define_node(
             "val" in input_node.meta,
             "Missing val in tensor metadata for input when serializing XNNStaticReshape node",
         )
-        new_shape = get_shape(input_node) + [1]
+        dynamic_shape = node.meta["val"].shape
+        new_shape = []
+
+        num_dynamic_dims = 0
+        for dim in dynamic_shape:
+            if isinstance(dim, torch.SymInt):
+                num_dynamic_dims += 1
+                new_shape.append(0)
+            else:
+                new_shape.append(dim)
+
+        check_or_raise(
+            num_dynamic_dims <= 1,
+            "XNNPACK reshape only supports 1 dynamic dimension. This may occur when ",
+        )

         ser_node = XNode(
             xnode_union=XNNStaticReshape(
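
The core of the op_squeeze.py change: `get_shape(input_node)` only ever returned concrete integers, so squeeze/unsqueeze nodes serialized a fully static shape. With dynamic shapes, `node.meta["val"].shape` can contain `torch.SymInt` entries for dimensions left symbolic at export time; the visitor maps each one to 0, which the XNNPACK static reshape treats as a wildcard dimension, and rejects tensors with more than one symbolic dim. A minimal standalone sketch of that mapping (the helper name `to_xnnpack_reshape_shape` is ours, not ExecuTorch's):

```python
import torch

def to_xnnpack_reshape_shape(shape) -> list[int]:
    # Mirror of the loop in op_squeeze.py: symbolic (dynamic) dims become 0,
    # XNNPACK's wildcard; concrete dims pass through unchanged.
    new_shape, num_dynamic_dims = [], 0
    for dim in shape:
        if isinstance(dim, torch.SymInt):
            num_dynamic_dims += 1
            new_shape.append(0)
        else:
            new_shape.append(int(dim))
    if num_dynamic_dims > 1:
        raise ValueError("XNNPACK reshape only supports 1 dynamic dimension.")
    return new_shape

# A fully static shape is unchanged; a torch.SymInt batch dim (as produced by
# torch.export with a Dim spec) would map to 0 in the serialized shape.
assert to_xnnpack_reshape_shape([10, 1, 700]) == [10, 1, 700]
```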
12 changes: 7 additions & 5 deletions backends/xnnpack/test/models/w2l.py
@@ -15,13 +15,15 @@ class TestW2L(unittest.TestCase):
     batch_size = 10
     input_frames = 700
     vocab_size = 4096
+    num_features = 1
     wav2letter = models.Wav2Letter(num_classes=vocab_size).eval()

-    model_inputs = (torch.randn(batch_size, 1, input_frames),)
+    model_inputs = (torch.randn(batch_size, num_features, input_frames),)
+    dynamic_shape = ({0: torch.export.Dim("batch", min=2, max=10)},)

     def test_fp32_w2l(self):
         (
-            Tester(self.wav2letter, self.model_inputs)
+            Tester(self.wav2letter, self.model_inputs, self.dynamic_shape)
             .export()
             .to_edge()
             .partition()
@@ -34,12 +36,12 @@ def test_fp32_w2l(self):
             .check(["torch.ops.higher_order.executorch_call_delegate"])
             .to_executorch()
             .serialize()
-            .run_method_and_compare_outputs()
+            .run_method_and_compare_outputs(num_runs=10)
         )

     def test_qs8_w2l(self):
         (
-            Tester(self.wav2letter.eval(), self.model_inputs)
+            Tester(self.wav2letter.eval(), self.model_inputs, self.dynamic_shape)
             .quantize()
             .export()
             .to_edge()
@@ -53,5 +55,5 @@ def test_qs8_w2l(self):
             .check(["torch.ops.higher_order.executorch_call_delegate"])
             .to_executorch()
             .serialize()
-            .run_method_and_compare_outputs()
+            .run_method_and_compare_outputs(num_runs=10)
         )
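
The w2l.py test now passes a `dynamic_shapes` spec as the Tester's third positional argument, using the standard `torch.export` convention: one dict per model input, mapping dimension index to a `torch.export.Dim` with bounds. `num_runs=10` then exercises the delegate repeatedly, presumably so different batch sizes get sampled. Outside the test harness, the same spec looks like this (a sketch, assuming the test's `models` is `torchaudio.models`):

```python
import torch
from torchaudio import models  # assumption: the `models` import used by w2l.py

wav2letter = models.Wav2Letter(num_classes=4096).eval()
example_inputs = (torch.randn(10, 1, 700),)  # (batch, num_features, input_frames)

# Dim 0 of the single input may vary in [2, 10] at runtime; other dims stay static.
dynamic_shapes = ({0: torch.export.Dim("batch", min=2, max=10)},)

ep = torch.export.export(wav2letter, example_inputs, dynamic_shapes=dynamic_shapes)

# The batch dim is now a torch.SymInt in the graph's tensor metadata, which is
# exactly what the op_squeeze.py visitor above detects and serializes as 0.
out = ep.module()(torch.randn(4, 1, 700))  # any batch size in [2, 10] works
```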
50 changes: 38 additions & 12 deletions backends/xnnpack/test/ops/conv1d.py
@@ -81,9 +81,15 @@ def forward(self, x):
             z = torch.add(y, z)
             return z

-    def _test_conv1d(self, module, inputs, conv_count, quantized=False):
+    def _test_conv1d(
+        self, module, inputs, conv_count, quantized=False, dynamic_shape=None
+    ):
         (
-            (Tester(module, inputs).quantize() if quantized else Tester(module, inputs))
+            (
+                Tester(module, inputs, dynamic_shape).quantize()
+                if quantized
+                else Tester(module, inputs)
+            )
             .export()
             .check_count({"torch.ops.aten.convolution.default": conv_count})
             .to_edge()
@@ -101,21 +107,41 @@ def _test_conv1d(self, module, inputs, conv_count, quantized=False):
         )

     def test_fp16_conv1d(self):
-        inputs = (torch.randn(1, 2, 4).to(torch.float16),)
-        self._test_conv1d(self.Conv1d(dtype=torch.float16), inputs, conv_count=1)
+        inputs = (torch.randn(2, 2, 4).to(torch.float16),)
+        dynamic_shapes = ({0: torch.export.Dim("batch", min=2, max=10)},)
+        self._test_conv1d(
+            self.Conv1d(dtype=torch.float16),
+            inputs,
+            conv_count=1,
+            dynamic_shape=dynamic_shapes,
+        )

     def test_fp32_conv1d(self):
-        inputs = (torch.randn(1, 2, 4),)
-        self._test_conv1d(self.Conv1d(), inputs, 1)
+        inputs = (torch.randn(2, 2, 4),)
+        dynamic_shapes = ({0: torch.export.Dim("batch", min=2, max=10)},)
+        self._test_conv1d(self.Conv1d(), inputs, 1, dynamic_shape=dynamic_shapes)

     def test_fp32_conv1d_batchnorm_seq(self):
-        inputs = (torch.randn(1, 2, 4),)
-        self._test_conv1d(self.Conv1dBatchNormSequential(), inputs, 2)
+        inputs = (torch.randn(2, 2, 4),)
+        dynamic_shapes = ({0: torch.export.Dim("batch", min=2, max=10)},)
+        self._test_conv1d(
+            self.Conv1dBatchNormSequential(), inputs, 2, dynamic_shape=dynamic_shapes
+        )

     def test_qs8_conv1d(self):
-        inputs = (torch.randn(1, 2, 4),)
-        self._test_conv1d(self.Conv1d(), inputs, 1, quantized=True)
+        inputs = (torch.randn(2, 2, 4),)
+        dynamic_shapes = ({0: torch.export.Dim("batch", min=2, max=10)},)
+        self._test_conv1d(
+            self.Conv1d(), inputs, 1, quantized=True, dynamic_shape=dynamic_shapes
+        )

     def test_qs8_conv1d_batchnorm_seq(self):
-        inputs = (torch.randn(1, 2, 4),)
-        self._test_conv1d(self.Conv1dBatchNormSequential(), inputs, 2, quantized=True)
+        inputs = (torch.randn(2, 2, 4),)
+        dynamic_shapes = ({0: torch.export.Dim("batch", min=2, max=10)},)
+        self._test_conv1d(
+            self.Conv1dBatchNormSequential(),
+            inputs,
+            2,
+            quantized=True,
+            dynamic_shape=dynamic_shapes,
+        )
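
The conv1d tests follow the same pattern, with one detail worth noting: every example batch size moved from 1 to 2, likely because `torch.export` specializes dimensions whose example value is 0 or 1, so a dim marked dynamic needs an example size of at least 2 (matching the Dim's `min` here). A sketch of the export step these tests exercise, using a stand-in module (the test's own `Conv1d` class is defined earlier in conv1d.py with its own configuration):

```python
import torch

class Conv1d(torch.nn.Module):
    # Stand-in with channels chosen to match the (batch, 2, 4) test inputs.
    def __init__(self, dtype: torch.dtype = torch.float32):
        super().__init__()
        self.conv = torch.nn.Conv1d(2, 2, kernel_size=3).to(dtype)

    def forward(self, x):
        return self.conv(x)

inputs = (torch.randn(2, 2, 4),)  # batch of 2: dynamic dims can't be exported at size 1
dynamic_shapes = ({0: torch.export.Dim("batch", min=2, max=10)},)
ep = torch.export.export(Conv1d(), inputs, dynamic_shapes=dynamic_shapes)

# The exported program accepts any batch size within [2, 10]:
out = ep.module()(torch.randn(7, 2, 4))
print(out.shape)  # torch.Size([7, 2, 2])
```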