Commit 46b4ddc

mcr229 authored and facebook-github-bot committed
dynamic qd8-fc test with 2 batch dims
Summary: Adds the first dynamic-input test: DQ Linear where its inputs have rank 3, i.e. two leading batch dimensions.

Reviewed By: digantdesai

Differential Revision: D54665767
1 parent 00e076c commit 46b4ddc
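For context, the new coverage centers on torch.export dynamic shapes with more than one leading batch dimension. The snippet below is a minimal standalone sketch of that export pattern, not part of this commit; it mirrors the `torch.export.Dim` usage added in the test file, and the use of `torch.export.export` here is purely illustrative.

import torch

# Rank-3 input: two leading batch dimensions, then the feature dimension.
linear = torch.nn.Linear(4, 5)
example_inputs = (torch.randn(2, 2, 4),)

# Reusing one Dim for both leading dims constrains them to stay equal,
# matching the dynamic_shapes spec used by the new two-batch-dim test.
batch_dim = torch.export.Dim("batch", max=100)
dynamic_shapes = ({0: batch_dim, 1: batch_dim},)

exported = torch.export.export(linear, example_inputs, dynamic_shapes=dynamic_shapes)

# The exported program then accepts other (equal) batch sizes at run time.
out = exported.module()(torch.randn(3, 3, 4))
print(out.shape)  # torch.Size([3, 3, 5])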

File tree: 1 file changed, +127 −71 lines


backends/xnnpack/test/ops/linear.py

Lines changed: 127 additions & 71 deletions
@@ -25,23 +25,27 @@
 class TestLinear(unittest.TestCase):
     def test_fp16_linear(self):
         for use_bias in (True, False):
-            self._test_linear(
-                lambda in_size, out_size: torch.nn.Linear(
-                    in_size, out_size, bias=use_bias  # noqa
-                ),
-                uses_bias=use_bias,
-                dtype=torch.float16,
-                atol=5e-2,
-            )
+            for num_batch_dims in range(1, 3):
+                self._test_linear(
+                    lambda in_size, out_size: torch.nn.Linear(
+                        in_size, out_size, bias=use_bias  # noqa
+                    ),
+                    num_batch_dims=num_batch_dims,
+                    uses_bias=use_bias,
+                    dtype=torch.float16,
+                    atol=5e-2,
+                )

     def test_fp32_linear(self):
         for use_bias in (True, False):
-            self._test_linear(
-                lambda in_size, out_size: torch.nn.Linear(
-                    in_size, out_size, bias=use_bias  # noqa
-                ),
-                uses_bias=use_bias,
-            )
+            for num_batch_dims in range(1, 3):
+                self._test_linear(
+                    lambda in_size, out_size: torch.nn.Linear(
+                        in_size, out_size, bias=use_bias  # noqa
+                    ),
+                    uses_bias=use_bias,
+                    num_batch_dims=num_batch_dims,
+                )

     def test_fp32_addmm(self):
         """
@@ -62,24 +66,71 @@ def forward(self, x):
             uses_bias=True,
         )

+    def test_fp32_linear_fused_relu(self):
+        class LinearReluModule(torch.nn.Module):
+            def __init__(self, in_size, out_size, use_bias):
+                super().__init__()
+                self.linear = torch.nn.Linear(in_size, out_size, bias=use_bias)
+
+            def forward(self, x):
+                return torch.nn.functional.relu(self.linear(x))
+
+        for use_bias in (True, False):
+            for num_batch_dims in range(1, 3):
+                self._test_linear(
+                    lambda in_size, out_size: LinearReluModule(
+                        in_size,
+                        out_size,
+                        use_bias,  # noqa
+                    ),
+                    uses_bias=use_bias,
+                    num_batch_dims=num_batch_dims,
+                )
+
+    def test_qs8_linear_fused_relu(self):
+        class LinearReluModule(torch.nn.Module):
+            def __init__(self, in_size, out_size, use_bias):
+                super().__init__()
+                self.linear = torch.nn.Linear(in_size, out_size, bias=use_bias)
+
+            def forward(self, x):
+                return torch.nn.functional.relu(self.linear(x))
+
+        for use_bias in (True, False):
+            for num_batch_dims in range(1, 3):
+                self._test_linear(
+                    lambda in_size, out_size: LinearReluModule(
+                        in_size,
+                        out_size,
+                        use_bias,  # noqa
+                    ),
+                    num_batch_dims=num_batch_dims,
+                    uses_bias=use_bias,
+                    quant=True,
+                )
+
     def test_qs8_linear(self):
         for use_bias in (True, False):
-            self._test_linear(
-                lambda in_size, out_size: torch.nn.Linear(
-                    in_size, out_size, bias=use_bias  # noqa
-                ),
-                uses_bias=use_bias,
-            )
+            for num_batch_dims in range(1, 3):
+                self._test_linear(
+                    lambda in_size, out_size: torch.nn.Linear(
+                        in_size, out_size, bias=use_bias  # noqa
+                    ),
+                    uses_bias=use_bias,
+                    num_batch_dims=num_batch_dims,
+                )

     @unittest.skip("XNNPACK currently only supports per-channel dynamic quantization.")
     def test_qd8_per_tensor_linear(self):
         for uses_bias in (False, True):
             inputs = (torch.randn(2, 4),)
             module = torch.nn.Linear(4, 5, bias=uses_bias)
+            dynamic_shapes = ({0: torch.export.Dim("batch", max=100)},)

             self._test_dqlinear(
                 module,
                 inputs,
+                dynamic_shapes=dynamic_shapes,
                 is_per_channel=False,
                 uses_bias=uses_bias,
             )
@@ -92,6 +143,7 @@ def test_qd8_per_channel_linear(self):
             self._test_dqlinear(
                 module,
                 inputs,
+                dynamic_shapes=({0: torch.export.Dim("batch", max=100)},),
                 is_per_channel=True,
                 uses_bias=uses_bias,
             )
@@ -113,7 +165,7 @@ def test_qd8_per_channel_4w_linear(self):
         qconfig = self._get_4b_dqconfig()
         input_channels = [2, 63]
         output_channels = [1, 8, 127]
-        batches = [1, 2]
+        batches = [2, 2]
         use_bias = [False, True]

         for bs, bias, ipc, opc in product(
@@ -128,13 +180,14 @@ def test_qd8_per_channel_4w_linear(self):
             self._test_dqlinear(
                 module,
                 inputs,
+                dynamic_shapes=({0: torch.export.Dim("batch", max=100)},),
                 is_per_channel=True,
                 uses_bias=bias,
                 qconfig=qconfig,
             )

     def test_qd8_per_channel_linear_parallel(self):
-        in_size = 1
+        in_size = 2
         input_size = 4
         output_size = 5

@@ -164,17 +217,39 @@ def forward(self, x, y):
             torch.rand(in_size, input_size, dtype=torch.float),
             torch.rand(in_size, input_size, dtype=torch.float),
         )
+        batch_dim = torch.export.Dim("batch", max=100)
+        dynamic_shapes = ({0: batch_dim}, {0: batch_dim})

         self._test_dqlinear(
             ParallelLinear(),
             inputs,
+            dynamic_shapes=dynamic_shapes,
             linear_count=2,
             is_per_channel=True,
             uses_bias=True,
         )

+    def test_qd8_per_channel_linear_with_two_batch(self):
+        in_size = 2
+        input_size = 4
+        output_size = 5
+
+        linear = torch.nn.Linear(input_size, output_size)
+        inputs = (torch.randn(2, in_size, input_size, dtype=torch.float),)
+        batch_dim = torch.export.Dim("batch", max=100)
+        dynamic_shapes = ({0: batch_dim, 1: batch_dim},)
+
+        self._test_dqlinear(
+            linear,
+            inputs,
+            dynamic_shapes=dynamic_shapes,
+            linear_count=1,
+            is_per_channel=True,
+            uses_bias=True,
+        )
+
     def test_qd8_per_channel_linear_sequential(self):
-        in_size = 1
+        in_size = 2
         input_size = 4
         intermediate_size = 5
         output_size = 3
@@ -202,17 +277,19 @@ def forward(self, x):
                 return b

         inputs = (torch.rand(in_size, input_size, dtype=torch.float),)
+        dynamic_shapes = ({0: torch.export.Dim("batch", max=100)},)

         self._test_dqlinear(
             LinearSequential(),
             inputs,
+            dynamic_shapes=dynamic_shapes,
             linear_count=2,
             is_per_channel=True,
             uses_bias=True,
         )

     def test_qd8_per_channel_linear_parellel_and_sequential(self):
-        in_size = 1
+        in_size = 2
         input_size = 4
         intermediate_size = 5
         output_size = 3
@@ -251,54 +328,26 @@ def forward(self, x, y):
             torch.rand(in_size, input_size, dtype=torch.float),
             torch.rand(in_size, input_size, dtype=torch.float),
         )
+        dynamic_shapes = (
+            {0: torch.export.Dim("batch", max=100)},
+            {0: torch.export.Dim("batch2", max=100)},
+        )

         self._test_dqlinear(
-            LinearModule(), inputs, linear_count=3, is_per_channel=True, uses_bias=True
+            LinearModule(),
+            inputs,
+            dynamic_shapes=dynamic_shapes,
+            linear_count=3,
+            is_per_channel=True,
+            uses_bias=True,
+            atol=1e-1,
         )

-    def test_fp32_linear_fused_relu(self):
-        class LinearReluModule(torch.nn.Module):
-            def __init__(self, in_size, out_size, use_bias):
-                super().__init__()
-                self.linear = torch.nn.Linear(in_size, out_size, bias=use_bias)
-
-            def forward(self, x):
-                return torch.nn.functional.relu(self.linear(x))
-
-        for use_bias in (True, False):
-            self._test_linear(
-                lambda in_size, out_size: LinearReluModule(
-                    in_size,
-                    out_size,
-                    use_bias,  # noqa
-                ),
-                uses_bias=use_bias,
-            )
-
-    def test_qs8_linear_fused_relu(self):
-        class LinearReluModule(torch.nn.Module):
-            def __init__(self, in_size, out_size, use_bias):
-                super().__init__()
-                self.linear = torch.nn.Linear(in_size, out_size, bias=use_bias)
-
-            def forward(self, x):
-                return torch.nn.functional.relu(self.linear(x))
-
-        for use_bias in (True, False):
-            self._test_linear(
-                lambda in_size, out_size: LinearReluModule(
-                    in_size,
-                    out_size,
-                    use_bias,  # noqa
-                ),
-                uses_bias=use_bias,
-                quant=True,
-            )
-
     def _test_linear(
         self,
         make_module,
         uses_bias,
+        num_batch_dims=1,
         quant=False,
         dtype: torch.dtype = torch.float,
         atol=1e-03,
@@ -315,7 +364,7 @@ def _test_linear(
             )
         )

-        in_sizes = [1, 4, 4]
+        in_sizes = [3, 4, 4]
         input_sizes = [4, 37, 17]
         output_sizes = [4, 17, 37]

@@ -327,12 +376,18 @@ def _test_linear(
             in_size = int(in_sizes[i])
             input_size = int(input_sizes[i])
             output_size = int(output_sizes[i])
-            print(f"Testing {in_size} {input_size} {output_size}")
+            input_shape = [in_size] * num_batch_dims + [input_size]
+            print(f"Testing input_shape {input_shape} with {output_size} out_channels")

             module = make_module(input_size, output_size).eval().to(dtype)
-            inputs = (torch.randn(in_size, input_size).to(dtype),)
+            inputs = (torch.randn(input_shape).to(dtype),)
+            dynamic_shape = {}
+            for i in range(num_batch_dims):
+                dynamic_shape[i] = torch.export.Dim(f"batch{i}", max=in_size)
+
+            dynamic_shape = (dynamic_shape,)

-            tester = Tester(module, inputs)
+            tester = Tester(module, inputs, dynamic_shapes=dynamic_shape)

             if quant:
                 tester.quantize()
@@ -360,10 +415,12 @@ def _test_dqlinear(
         self,
         module,
         inputs,
+        dynamic_shapes,
         linear_count=1,
         is_per_channel=False,
         uses_bias=False,
         qconfig: Optional[QuantizationConfig] = None,
+        atol=5e-02,
     ):
         aten_op, edge_op = (
             (
@@ -382,13 +439,12 @@ def _test_dqlinear(
             is_dynamic=True,
         )

-        tester = Tester(module, inputs)
+        tester = Tester(module, inputs, dynamic_shapes=dynamic_shapes)
         tester.quantize(Quantize(quantization_config=quant_config))

         tester.export()
         tester.check_count({aten_op: linear_count})
         tester.check(["torch.ops.quantized_decomposed"])
-        tester.dump_artifact()
         tester.to_edge()
         tester.check_count({edge_op: linear_count})

@@ -400,4 +456,4 @@ def _test_dqlinear(

         tester.to_executorch()
         tester.serialize()
-        tester.run_method_and_compare_outputs(atol=5e-02)
+        tester.run_method_and_compare_outputs(atol=atol)
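As a footnote to the diff above, the per-input dynamic-shape spec that the updated `_test_linear` assembles for `num_batch_dims` leading batch dimensions can be sketched on its own. The `make_dynamic_shapes` helper below is illustrative only and is not part of the test file; it simply generalizes the loop that builds `dynamic_shape` in the hunk around `@@ -327,12 +376,18 @@`.

import torch

def make_dynamic_shapes(num_batch_dims: int, max_batch: int):
    # Map each leading batch-dimension index to its own named Dim,
    # leaving the trailing feature dimension static.
    spec = {
        i: torch.export.Dim(f"batch{i}", max=max_batch)
        for i in range(num_batch_dims)
    }
    # One spec dict per positional input; the tested module takes one tensor.
    return (spec,)

# Example: two dynamic batch dims, as exercised by the new rank-3 tests.
dynamic_shapes = make_dynamic_shapes(2, 100)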
