
Commit bf59da6

mcr229 authored and facebook-github-bot committed
dynamic qd8-fc test with 2 batch dims (#2441)
Summary:
Pull Request resolved: #2441

Adds the first dynamic-input test: DQ Linear where the inputs have rank 3 (two batch dims).

Reviewed By: digantdesai, kirklandsign

Differential Revision: D54665767

fbshipit-source-id: 3c6c7eb0a10b32f390effeb9ae88b74df21e823f
1 parent 65be9b4 commit bf59da6
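
For context, here is a minimal standalone sketch (not part of this commit) of the pattern the new test exercises: exporting a Linear whose rank-3 input has two dynamic leading batch dims via torch.export.Dim. The module, sizes, and bounds are illustrative; the test itself drives this through the XNNPACK Tester harness.

import torch

# Hedged sketch: a rank-3 input (two leading batch dims) into a Linear, with
# both batch dims tied to one symbolic Dim, mirroring the new test's setup.
linear = torch.nn.Linear(4, 5).eval()
inputs = (torch.randn(2, 2, 4),)

batch_dim = torch.export.Dim("batch", max=100)
dynamic_shapes = ({0: batch_dim, 1: batch_dim},)  # dims 0 and 1 vary together

exported = torch.export.export(linear, inputs, dynamic_shapes=dynamic_shapes)
print(exported)  # ExportedProgram with symbolic batch dims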

File tree

1 file changed: +129 −70 lines changed

backends/xnnpack/test/ops/linear.py

Lines changed: 129 additions & 70 deletions
@@ -26,23 +26,27 @@
 class TestLinear(unittest.TestCase):
     def test_fp16_linear(self):
         for use_bias in (True, False):
-            self._test_linear(
-                lambda in_size, out_size: torch.nn.Linear(
-                    in_size, out_size, bias=use_bias  # noqa
-                ),
-                uses_bias=use_bias,
-                dtype=torch.float16,
-                atol=5e-2,
-            )
+            for num_batch_dims in range(1, 3):
+                self._test_linear(
+                    lambda in_size, out_size: torch.nn.Linear(
+                        in_size, out_size, bias=use_bias  # noqa
+                    ),
+                    num_batch_dims=num_batch_dims,
+                    uses_bias=use_bias,
+                    dtype=torch.float16,
+                    atol=5e-2,
+                )

     def test_fp32_linear(self):
         for use_bias in (True, False):
-            self._test_linear(
-                lambda in_size, out_size: torch.nn.Linear(
-                    in_size, out_size, bias=use_bias  # noqa
-                ),
-                uses_bias=use_bias,
-            )
+            for num_batch_dims in range(1, 3):
+                self._test_linear(
+                    lambda in_size, out_size: torch.nn.Linear(
+                        in_size, out_size, bias=use_bias  # noqa
+                    ),
+                    uses_bias=use_bias,
+                    num_batch_dims=num_batch_dims,
+                )

     def test_fp32_addmm(self):
         """
@@ -63,24 +67,71 @@ def forward(self, x):
             uses_bias=True,
         )

+    def test_fp32_linear_fused_relu(self):
+        class LinearReluModule(torch.nn.Module):
+            def __init__(self, in_size, out_size, use_bias):
+                super().__init__()
+                self.linear = torch.nn.Linear(in_size, out_size, bias=use_bias)
+
+            def forward(self, x):
+                return torch.nn.functional.relu(self.linear(x))
+
+        for use_bias in (True, False):
+            for num_batch_dims in range(1, 3):
+                self._test_linear(
+                    lambda in_size, out_size: LinearReluModule(
+                        in_size,
+                        out_size,
+                        use_bias,  # noqa
+                    ),
+                    uses_bias=use_bias,
+                    num_batch_dims=num_batch_dims,
+                )
+
+    def test_qs8_linear_fused_relu(self):
+        class LinearReluModule(torch.nn.Module):
+            def __init__(self, in_size, out_size, use_bias):
+                super().__init__()
+                self.linear = torch.nn.Linear(in_size, out_size, bias=use_bias)
+
+            def forward(self, x):
+                return torch.nn.functional.relu(self.linear(x))
+
+        for use_bias in (True, False):
+            for num_batch_dims in range(1, 3):
+                self._test_linear(
+                    lambda in_size, out_size: LinearReluModule(
+                        in_size,
+                        out_size,
+                        use_bias,  # noqa
+                    ),
+                    num_batch_dims=num_batch_dims,
+                    uses_bias=use_bias,
+                    quant=True,
+                )
+
     def test_qs8_linear(self):
         for use_bias in (True, False):
-            self._test_linear(
-                lambda in_size, out_size: torch.nn.Linear(
-                    in_size, out_size, bias=use_bias  # noqa
-                ),
-                uses_bias=use_bias,
-            )
+            for num_batch_dims in range(1, 3):
+                self._test_linear(
+                    lambda in_size, out_size: torch.nn.Linear(
+                        in_size, out_size, bias=use_bias  # noqa
+                    ),
+                    uses_bias=use_bias,
+                    num_batch_dims=num_batch_dims,
+                )

     @unittest.skip("XNNPACK currently only supports per-channel dynamic quantization.")
     def test_qd8_per_tensor_linear(self):
         for uses_bias in (False, True):
             inputs = (torch.randn(2, 4),)
             module = torch.nn.Linear(4, 5, bias=uses_bias)
+            dynamic_shapes = ({0: torch.export.Dim("batch", max=100)},)

             self._test_dqlinear(
                 module,
                 inputs,
+                dynamic_shapes=dynamic_shapes,
                 is_per_channel=False,
                 uses_bias=uses_bias,
             )
@@ -93,6 +144,7 @@ def test_qd8_per_channel_linear(self):
             self._test_dqlinear(
                 module,
                 inputs,
+                dynamic_shapes=({0: torch.export.Dim("batch", max=100)},),
                 is_per_channel=True,
                 uses_bias=uses_bias,
             )
@@ -114,7 +166,7 @@ def test_qd8_per_channel_4w_linear(self):
         qconfig = self._get_4b_dqconfig()
         input_channels = [2, 63]
         output_channels = [1, 8, 127]
-        batches = [1, 2]
+        batches = [2, 2]
         use_bias = [False, True]

         for bs, bias, ipc, opc in product(
@@ -129,13 +181,14 @@ def test_qd8_per_channel_4w_linear(self):
             self._test_dqlinear(
                 module,
                 inputs,
+                dynamic_shapes=({0: torch.export.Dim("batch", max=100)},),
                 is_per_channel=True,
                 uses_bias=bias,
                 qconfig=qconfig,
             )

     def test_qd8_per_channel_linear_parallel(self):
-        in_size = 1
+        in_size = 2
         input_size = 4
         output_size = 5

@@ -165,17 +218,39 @@ def forward(self, x, y):
             torch.rand(in_size, input_size, dtype=torch.float),
             torch.rand(in_size, input_size, dtype=torch.float),
         )
+        batch_dim = torch.export.Dim("batch", max=100)
+        dynamic_shapes = ({0: batch_dim}, {0: batch_dim})

         self._test_dqlinear(
             ParallelLinear(),
             inputs,
+            dynamic_shapes=dynamic_shapes,
             linear_count=2,
             is_per_channel=True,
             uses_bias=True,
         )

+    def test_qd8_per_channel_linear_with_two_batch(self):
+        in_size = 2
+        input_size = 4
+        output_size = 5
+
+        linear = torch.nn.Linear(input_size, output_size)
+        inputs = (torch.randn(2, in_size, input_size, dtype=torch.float),)
+        batch_dim = torch.export.Dim("batch", max=100)
+        dynamic_shapes = ({0: batch_dim, 1: batch_dim},)
+
+        self._test_dqlinear(
+            linear,
+            inputs,
+            dynamic_shapes=dynamic_shapes,
+            linear_count=1,
+            is_per_channel=True,
+            uses_bias=True,
+        )
+
     def test_qd8_per_channel_linear_sequential(self):
-        in_size = 1
+        in_size = 2
         input_size = 4
         intermediate_size = 5
         output_size = 3
@@ -203,17 +278,20 @@ def forward(self, x):
                 return b

         inputs = (torch.rand(in_size, input_size, dtype=torch.float),)
+        dynamic_shapes = ({0: torch.export.Dim("batch", max=100)},)

         self._test_dqlinear(
             LinearSequential(),
             inputs,
+            dynamic_shapes=dynamic_shapes,
             linear_count=2,
             is_per_channel=True,
             uses_bias=True,
+            atol=1e-1,
         )

     def test_qd8_per_channel_linear_parellel_and_sequential(self):
-        in_size = 1
+        in_size = 2
         input_size = 4
         intermediate_size = 5
         output_size = 3
@@ -252,50 +330,21 @@ def forward(self, x, y):
             torch.rand(in_size, input_size, dtype=torch.float),
             torch.rand(in_size, input_size, dtype=torch.float),
         )
+        dynamic_shapes = (
+            {0: torch.export.Dim("batch", max=100)},
+            {0: torch.export.Dim("batch2", max=100)},
+        )

         self._test_dqlinear(
-            LinearModule(), inputs, linear_count=3, is_per_channel=True, uses_bias=True
+            LinearModule(),
+            inputs,
+            dynamic_shapes=dynamic_shapes,
+            linear_count=3,
+            is_per_channel=True,
+            uses_bias=True,
+            atol=1e-1,
         )

-    def test_fp32_linear_fused_relu(self):
-        class LinearReluModule(torch.nn.Module):
-            def __init__(self, in_size, out_size, use_bias):
-                super().__init__()
-                self.linear = torch.nn.Linear(in_size, out_size, bias=use_bias)
-
-            def forward(self, x):
-                return torch.nn.functional.relu(self.linear(x))
-
-        for use_bias in (True, False):
-            self._test_linear(
-                lambda in_size, out_size: LinearReluModule(
-                    in_size,
-                    out_size,
-                    use_bias,  # noqa
-                ),
-                uses_bias=use_bias,
-            )
-
-    def test_qs8_linear_fused_relu(self):
-        class LinearReluModule(torch.nn.Module):
-            def __init__(self, in_size, out_size, use_bias):
-                super().__init__()
-                self.linear = torch.nn.Linear(in_size, out_size, bias=use_bias)
-
-            def forward(self, x):
-                return torch.nn.functional.relu(self.linear(x))
-
-        for use_bias in (True, False):
-            self._test_linear(
-                lambda in_size, out_size: LinearReluModule(
-                    in_size,
-                    out_size,
-                    use_bias,  # noqa
-                ),
-                uses_bias=use_bias,
-                quant=True,
-            )
-
 class ManualDQLinear(torch.nn.Module):
     def __init__(
         self,
@@ -676,6 +725,7 @@ def _test_linear(
         self,
         make_module,
         uses_bias,
+        num_batch_dims=1,
         quant=False,
         dtype: torch.dtype = torch.float,
         atol=1e-03,
@@ -692,7 +742,7 @@ def _test_linear(
             )
         )

-        in_sizes = [1, 4, 4]
+        in_sizes = [3, 4, 4]
         input_sizes = [4, 37, 17]
         output_sizes = [4, 17, 37]

@@ -704,11 +754,19 @@
             in_size = int(in_sizes[i])
             input_size = int(input_sizes[i])
             output_size = int(output_sizes[i])
+            input_shape = [in_size] * num_batch_dims + [input_size]
+            print(f"Testing input_shape {input_shape} with {output_size} out_channels")

             module = make_module(input_size, output_size).eval().to(dtype)
-            inputs = (torch.randn(in_size, input_size).to(dtype),)
+            inputs = (torch.randn(input_shape).to(dtype),)
+            dynamic_shape = {}
+            for i in range(num_batch_dims):
+                dynamic_shape[i] = torch.export.Dim(f"batch{i}", min=2, max=in_size)
+
+            dynamic_shape = (dynamic_shape,)
+            print(dynamic_shape)

-            tester = Tester(module, inputs)
+            tester = Tester(module, inputs, dynamic_shapes=dynamic_shape)

             if quant:
                 tester.quantize()
@@ -736,10 +794,12 @@ def _test_dqlinear(
         self,
         module,
         inputs,
+        dynamic_shapes,
         linear_count=1,
         is_per_channel=False,
         uses_bias=False,
         qconfig: Optional[QuantizationConfig] = None,
+        atol=5e-02,
     ):
         aten_op, edge_op = (
             (
@@ -758,13 +818,12 @@ def _test_dqlinear(
             is_dynamic=True,
         )

-        tester = Tester(module, inputs)
+        tester = Tester(module, inputs, dynamic_shapes=dynamic_shapes)
         tester.quantize(Quantize(quantization_config=quant_config))

         tester.export()
         tester.check_count({aten_op: linear_count})
         tester.check(["torch.ops.quantized_decomposed"])
-        tester.dump_artifact()
         tester.to_edge()
         tester.check_count({edge_op: linear_count})

@@ -776,4 +835,4 @@ def _test_dqlinear(

         tester.to_executorch()
         tester.serialize()
-        tester.run_method_and_compare_outputs(atol=5e-02)
+        tester.run_method_and_compare_outputs(atol=atol)
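
The updated _test_linear helper builds one torch.export.Dim per leading batch dimension before handing the inputs to the Tester. A rough standalone equivalent of that construction, assuming plain torch.export and illustrative names and bounds (make_dynamic_shapes is not a helper from this file):

import torch

def make_dynamic_shapes(num_batch_dims: int, max_batch: int):
    # One symbolic Dim per leading batch dim; the trailing feature dim stays
    # static. Wrapped in a one-element tuple because the module takes a single
    # positional input.
    return (
        {
            i: torch.export.Dim(f"batch{i}", min=2, max=max_batch)
            for i in range(num_batch_dims)
        },
    )

module = torch.nn.Linear(4, 5).eval()
inputs = (torch.randn(3, 3, 4),)  # two batch dims of size 3, feature dim 4
ep = torch.export.export(module, inputs, dynamic_shapes=make_dynamic_shapes(2, 100))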
