Skip to content

Commit cfa6756

Browse files
mcr229facebook-github-bot
authored andcommitted
Dynamic Conv1d + W2L (#2976)
Summary: Pull Request resolved: #2976 Conv1d uses static reshape operator, in order to convert 3d tensor to 4d tensor so xnnpack can operate using conv2d. For dynamism, reshape only accepts a single dynamic dimension, which is denoted as dynamic with a dim of 0. Reviewed By: digantdesai Differential Revision: D55815092
1 parent 39642b9 commit cfa6756

File tree

3 files changed

+75
-19
lines changed

3 files changed

+75
-19
lines changed

backends/xnnpack/operators/op_squeeze.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,21 @@ def define_node(
5353
"val" in input_node.meta,
5454
"Missing val in tensor metadata for input when serializing XNNStaticReshape node",
5555
)
56-
new_shape = get_shape(input_node)[:-1]
56+
dynamic_shape = node.meta["val"].shape
57+
new_shape = []
58+
59+
num_dynamic_dims = 0
60+
for dim in dynamic_shape:
61+
if isinstance(dim, torch.SymInt):
62+
num_dynamic_dims += 1
63+
new_shape.append(0)
64+
else:
65+
new_shape.append(dim)
66+
67+
check_or_raise(
68+
num_dynamic_dims <= 1,
69+
"XNNPACK reshape only supports 1 dynamic dimension. This may occur when ",
70+
)
5771

5872
ser_node = XNode(
5973
xnode_union=XNNStaticReshape(
@@ -101,7 +115,21 @@ def define_node(
101115
"val" in input_node.meta,
102116
"Missing val in tensor metadata for input when serializing XNNStaticReshape node",
103117
)
104-
new_shape = get_shape(input_node) + [1]
118+
dynamic_shape = node.meta["val"].shape
119+
new_shape = []
120+
121+
num_dynamic_dims = 0
122+
for dim in dynamic_shape:
123+
if isinstance(dim, torch.SymInt):
124+
num_dynamic_dims += 1
125+
new_shape.append(0)
126+
else:
127+
new_shape.append(dim)
128+
129+
check_or_raise(
130+
num_dynamic_dims <= 1,
131+
"XNNPACK reshape only supports 1 dynamic dimension. This may occur when ",
132+
)
105133

106134
ser_node = XNode(
107135
xnode_union=XNNStaticReshape(

backends/xnnpack/test/models/w2l.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,15 @@ class TestW2L(unittest.TestCase):
1515
batch_size = 10
1616
input_frames = 700
1717
vocab_size = 4096
18+
num_features = 1
1819
wav2letter = models.Wav2Letter(num_classes=vocab_size).eval()
1920

20-
model_inputs = (torch.randn(batch_size, 1, input_frames),)
21+
model_inputs = (torch.randn(batch_size, num_features, input_frames),)
22+
dynamic_shape = ({0: torch.export.Dim("batch", min=2, max=10)},)
2123

2224
def test_fp32_w2l(self):
2325
(
24-
Tester(self.wav2letter, self.model_inputs)
26+
Tester(self.wav2letter, self.model_inputs, self.dynamic_shape)
2527
.export()
2628
.to_edge()
2729
.partition()
@@ -34,12 +36,12 @@ def test_fp32_w2l(self):
3436
.check(["torch.ops.higher_order.executorch_call_delegate"])
3537
.to_executorch()
3638
.serialize()
37-
.run_method_and_compare_outputs()
39+
.run_method_and_compare_outputs(num_runs=10)
3840
)
3941

4042
def test_qs8_w2l(self):
4143
(
42-
Tester(self.wav2letter.eval(), self.model_inputs)
44+
Tester(self.wav2letter.eval(), self.model_inputs, self.dynamic_shape)
4345
.quantize()
4446
.export()
4547
.to_edge()
@@ -53,5 +55,5 @@ def test_qs8_w2l(self):
5355
.check(["torch.ops.higher_order.executorch_call_delegate"])
5456
.to_executorch()
5557
.serialize()
56-
.run_method_and_compare_outputs()
58+
.run_method_and_compare_outputs(num_runs=10)
5759
)

backends/xnnpack/test/ops/conv1d.py

Lines changed: 38 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -81,9 +81,15 @@ def forward(self, x):
8181
z = torch.add(y, z)
8282
return z
8383

84-
def _test_conv1d(self, module, inputs, conv_count, quantized=False):
84+
def _test_conv1d(
85+
self, module, inputs, conv_count, quantized=False, dynamic_shape=None
86+
):
8587
(
86-
(Tester(module, inputs).quantize() if quantized else Tester(module, inputs))
88+
(
89+
Tester(module, inputs, dynamic_shape).quantize()
90+
if quantized
91+
else Tester(module, inputs)
92+
)
8793
.export()
8894
.check_count({"torch.ops.aten.convolution.default": conv_count})
8995
.to_edge()
@@ -101,21 +107,41 @@ def _test_conv1d(self, module, inputs, conv_count, quantized=False):
101107
)
102108

103109
def test_fp16_conv1d(self):
104-
inputs = (torch.randn(1, 2, 4).to(torch.float16),)
105-
self._test_conv1d(self.Conv1d(dtype=torch.float16), inputs, conv_count=1)
110+
inputs = (torch.randn(2, 2, 4).to(torch.float16),)
111+
dynamic_shapes = ({0: torch.export.Dim("batch", min=2, max=10)},)
112+
self._test_conv1d(
113+
self.Conv1d(dtype=torch.float16),
114+
inputs,
115+
conv_count=1,
116+
dynamic_shape=dynamic_shapes,
117+
)
106118

107119
def test_fp32_conv1d(self):
108-
inputs = (torch.randn(1, 2, 4),)
109-
self._test_conv1d(self.Conv1d(), inputs, 1)
120+
inputs = (torch.randn(2, 2, 4),)
121+
dynamic_shapes = ({0: torch.export.Dim("batch", min=2, max=10)},)
122+
self._test_conv1d(self.Conv1d(), inputs, 1, dynamic_shape=dynamic_shapes)
110123

111124
def test_fp32_conv1d_batchnorm_seq(self):
112-
inputs = (torch.randn(1, 2, 4),)
113-
self._test_conv1d(self.Conv1dBatchNormSequential(), inputs, 2)
125+
inputs = (torch.randn(2, 2, 4),)
126+
dynamic_shapes = ({0: torch.export.Dim("batch", min=2, max=10)},)
127+
self._test_conv1d(
128+
self.Conv1dBatchNormSequential(), inputs, 2, dynamic_shape=dynamic_shapes
129+
)
114130

115131
def test_qs8_conv1d(self):
116-
inputs = (torch.randn(1, 2, 4),)
117-
self._test_conv1d(self.Conv1d(), inputs, 1, quantized=True)
132+
inputs = (torch.randn(2, 2, 4),)
133+
dynamic_shapes = ({0: torch.export.Dim("batch", min=2, max=10)},)
134+
self._test_conv1d(
135+
self.Conv1d(), inputs, 1, quantized=True, dynamic_shape=dynamic_shapes
136+
)
118137

119138
def test_qs8_conv1d_batchnorm_seq(self):
120-
inputs = (torch.randn(1, 2, 4),)
121-
self._test_conv1d(self.Conv1dBatchNormSequential(), inputs, 2, quantized=True)
139+
inputs = (torch.randn(2, 2, 4),)
140+
dynamic_shapes = ({0: torch.export.Dim("batch", min=2, max=10)},)
141+
self._test_conv1d(
142+
self.Conv1dBatchNormSequential(),
143+
inputs,
144+
2,
145+
quantized=True,
146+
dynamic_shape=dynamic_shapes,
147+
)

0 commit comments

Comments
 (0)