Commit 7b375fe

mcr229 authored and facebook-github-bot committed
Dynamic Conv1d + W2L (#2976)
Summary: Pull Request resolved: #2976

Conv1d uses the static reshape operator to convert its 3d input tensor to a 4d tensor so that XNNPACK can run it with conv2d. For dynamism, the reshape accepts only a single dynamic dimension, which is denoted in the new shape with a dim of 0.

Reviewed By: digantdesai, kirklandsign

Differential Revision: D55815092

fbshipit-source-id: a3c96bc5c86c130291c1d54f8174a6ff5d25a6b8

1 parent 15f141b · commit 7b375fe
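The key idea: XNNPACK has no conv1d kernel, so the 3d input is reshaped to 4d, run through conv2d, and squeezed back; when the batch dimension is dynamic, the serialized static reshape encodes that one dimension as 0. A minimal functional sketch of the equivalence (the helper name is invented for illustration; the actual lowering lives in the XNNPACK serializer below):

import torch
import torch.nn.functional as F

def conv1d_via_conv2d(x, weight4d, bias):
    # x: (N, C, L); weight4d: a conv1d weight (C_out, C_in, K) unsqueezed to (C_out, C_in, K, 1)
    n, c, l = x.shape
    x4 = x.reshape(n, c, l, 1)   # serialized as new_shape = [0, c, l, 1] when n is dynamic
    y4 = F.conv2d(x4, weight4d, bias)
    return y4.squeeze(-1)        # back to (N, C_out, L_out)

x = torch.randn(4, 2, 8)
w = torch.randn(5, 2, 3)         # conv1d weight
b = torch.zeros(5)
print(torch.allclose(F.conv1d(x, w, b), conv1d_via_conv2d(x, w.unsqueeze(-1), b), atol=1e-5))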

3 files changed: 75 additions (+), 20 deletions (−)

backends/xnnpack/operators/op_squeeze.py

Lines changed: 30 additions & 3 deletions
@@ -7,7 +7,6 @@
 from typing import cast, Dict

 import torch
-from executorch.backends.transforms import get_shape
 from executorch.backends.xnnpack.operators.node_visitor import (
     NodeVisitor,
     register_node_visitor,
@@ -53,7 +52,21 @@ def define_node(
             "val" in input_node.meta,
             "Missing val in tensor metadata for input when serializing XNNStaticReshape node",
         )
-        new_shape = get_shape(input_node)[:-1]
+        dynamic_shape = node.meta["val"].shape
+        new_shape = []
+
+        num_dynamic_dims = 0
+        for dim in dynamic_shape:
+            if isinstance(dim, torch.SymInt):
+                num_dynamic_dims += 1
+                new_shape.append(0)
+            else:
+                new_shape.append(dim)
+
+        check_or_raise(
+            num_dynamic_dims <= 1,
+            "XNNPACK reshape only supports 1 dynamic dimension. This may occur when ",
+        )

         ser_node = XNode(
             xnode_union=XNNStaticReshape(
@@ -101,7 +114,21 @@ def define_node(
             "val" in input_node.meta,
             "Missing val in tensor metadata for input when serializing XNNStaticReshape node",
         )
-        new_shape = get_shape(input_node) + [1]
+        dynamic_shape = node.meta["val"].shape
+        new_shape = []
+
+        num_dynamic_dims = 0
+        for dim in dynamic_shape:
+            if isinstance(dim, torch.SymInt):
+                num_dynamic_dims += 1
+                new_shape.append(0)
+            else:
+                new_shape.append(dim)
+
+        check_or_raise(
+            num_dynamic_dims <= 1,
+            "XNNPACK reshape only supports 1 dynamic dimension. This may occur when ",
+        )

         ser_node = XNode(
             xnode_union=XNNStaticReshape(
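For context, the torch.SymInt branch above fires for dimensions marked dynamic at export time. A small self-contained sketch (module and variable names invented for illustration) of how a dynamic batch dim surfaces as a SymInt in node.meta["val"].shape:

import torch

class Unsqueeze(torch.nn.Module):
    def forward(self, x):
        return x.unsqueeze(-1)  # (N, C, L) -> (N, C, L, 1)

batch = torch.export.Dim("batch", min=2, max=10)
ep = torch.export.export(
    Unsqueeze(), (torch.randn(4, 2, 8),), dynamic_shapes=({0: batch},)
)
for node in ep.graph.nodes:
    if node.op == "call_function":
        # the dynamic batch shows up as a torch.SymInt, e.g. torch.Size([s0, 2, 8, 1])
        print(node.target, node.meta["val"].shape)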

backends/xnnpack/test/models/w2l.py

Lines changed: 7 additions & 5 deletions
@@ -15,13 +15,15 @@ class TestW2L(unittest.TestCase):
     batch_size = 10
     input_frames = 700
     vocab_size = 4096
+    num_features = 1
     wav2letter = models.Wav2Letter(num_classes=vocab_size).eval()

-    model_inputs = (torch.randn(batch_size, 1, input_frames),)
+    model_inputs = (torch.randn(batch_size, num_features, input_frames),)
+    dynamic_shape = ({0: torch.export.Dim("batch", min=2, max=10)},)

     def test_fp32_w2l(self):
         (
-            Tester(self.wav2letter, self.model_inputs)
+            Tester(self.wav2letter, self.model_inputs, self.dynamic_shape)
             .export()
             .to_edge()
             .partition()
@@ -34,12 +36,12 @@ def test_fp32_w2l(self):
             .check(["torch.ops.higher_order.executorch_call_delegate"])
             .to_executorch()
             .serialize()
-            .run_method_and_compare_outputs()
+            .run_method_and_compare_outputs(num_runs=10)
         )

     def test_qs8_w2l(self):
         (
-            Tester(self.wav2letter.eval(), self.model_inputs)
+            Tester(self.wav2letter.eval(), self.model_inputs, self.dynamic_shape)
             .quantize()
             .export()
             .to_edge()
@@ -53,5 +55,5 @@ def test_qs8_w2l(self):
             .check(["torch.ops.higher_order.executorch_call_delegate"])
             .to_executorch()
             .serialize()
-            .run_method_and_compare_outputs()
+            .run_method_and_compare_outputs(num_runs=10)
         )
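With the dynamic batch declared, run_method_and_compare_outputs(num_runs=10) lets the tester exercise several batch sizes instead of one. A rough stand-alone equivalent of the export step (assuming torchaudio provides the Wav2Letter used here and a recent PyTorch with ExportedProgram.module()):

import torch
from torchaudio import models

wav2letter = models.Wav2Letter(num_classes=4096).eval()
batch = torch.export.Dim("batch", min=2, max=10)
ep = torch.export.export(
    wav2letter, (torch.randn(10, 1, 700),), dynamic_shapes=({0: batch},)
)
# the exported program now accepts any batch size in [2, 10]
for n in (2, 7, 10):
    print(n, ep.module()(torch.randn(n, 1, 700)).shape)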

backends/xnnpack/test/ops/conv1d.py

Lines changed: 38 additions & 12 deletions
@@ -81,9 +81,15 @@ def forward(self, x):
         z = torch.add(y, z)
         return z

-    def _test_conv1d(self, module, inputs, conv_count, quantized=False):
+    def _test_conv1d(
+        self, module, inputs, conv_count, quantized=False, dynamic_shape=None
+    ):
         (
-            (Tester(module, inputs).quantize() if quantized else Tester(module, inputs))
+            (
+                Tester(module, inputs, dynamic_shape).quantize()
+                if quantized
+                else Tester(module, inputs)
+            )
             .export()
             .check_count({"torch.ops.aten.convolution.default": conv_count})
             .to_edge()
@@ -101,21 +107,41 @@ def _test_conv1d(self, module, inputs, conv_count, quantized=False):
         )

     def test_fp16_conv1d(self):
-        inputs = (torch.randn(1, 2, 4).to(torch.float16),)
-        self._test_conv1d(self.Conv1d(dtype=torch.float16), inputs, conv_count=1)
+        inputs = (torch.randn(2, 2, 4).to(torch.float16),)
+        dynamic_shapes = ({0: torch.export.Dim("batch", min=2, max=10)},)
+        self._test_conv1d(
+            self.Conv1d(dtype=torch.float16),
+            inputs,
+            conv_count=1,
+            dynamic_shape=dynamic_shapes,
+        )

     def test_fp32_conv1d(self):
-        inputs = (torch.randn(1, 2, 4),)
-        self._test_conv1d(self.Conv1d(), inputs, 1)
+        inputs = (torch.randn(2, 2, 4),)
+        dynamic_shapes = ({0: torch.export.Dim("batch", min=2, max=10)},)
+        self._test_conv1d(self.Conv1d(), inputs, 1, dynamic_shape=dynamic_shapes)

     def test_fp32_conv1d_batchnorm_seq(self):
-        inputs = (torch.randn(1, 2, 4),)
-        self._test_conv1d(self.Conv1dBatchNormSequential(), inputs, 2)
+        inputs = (torch.randn(2, 2, 4),)
+        dynamic_shapes = ({0: torch.export.Dim("batch", min=2, max=10)},)
+        self._test_conv1d(
+            self.Conv1dBatchNormSequential(), inputs, 2, dynamic_shape=dynamic_shapes
+        )

     def test_qs8_conv1d(self):
-        inputs = (torch.randn(1, 2, 4),)
-        self._test_conv1d(self.Conv1d(), inputs, 1, quantized=True)
+        inputs = (torch.randn(2, 2, 4),)
+        dynamic_shapes = ({0: torch.export.Dim("batch", min=2, max=10)},)
+        self._test_conv1d(
+            self.Conv1d(), inputs, 1, quantized=True, dynamic_shape=dynamic_shapes
+        )

     def test_qs8_conv1d_batchnorm_seq(self):
-        inputs = (torch.randn(1, 2, 4),)
-        self._test_conv1d(self.Conv1dBatchNormSequential(), inputs, 2, quantized=True)
+        inputs = (torch.randn(2, 2, 4),)
+        dynamic_shapes = ({0: torch.export.Dim("batch", min=2, max=10)},)
+        self._test_conv1d(
+            self.Conv1dBatchNormSequential(),
+            inputs,
+            2,
+            quantized=True,
+            dynamic_shape=dynamic_shapes,
+        )
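Note that the example inputs above moved from batch 1 to batch 2. This is presumably because torch.export validates the example input against the declared range (min=2) and specializes size-0/1 dimensions rather than keeping them symbolic. A quick sketch of the failure mode (toy module, current torch.export behavior assumed):

import torch

class Double(torch.nn.Module):
    def forward(self, x):
        return x * 2

batch = torch.export.Dim("batch", min=2, max=10)
try:
    # an example batch of 1 falls outside [2, 10] and cannot stay dynamic
    torch.export.export(Double(), (torch.randn(1, 2, 4),), dynamic_shapes=({0: batch},))
except Exception as e:
    print(type(e).__name__)
# an example batch inside the range exports cleanly
torch.export.export(Double(), (torch.randn(2, 2, 4),), dynamic_shapes=({0: batch},))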
