class TestLinear(unittest.TestCase):
    def test_fp16_linear(self):
        for use_bias in (True, False):
-            self._test_linear(
-                lambda in_size, out_size: torch.nn.Linear(
-                    in_size, out_size, bias=use_bias  # noqa
-                ),
-                uses_bias=use_bias,
-                dtype=torch.float16,
-                atol=5e-2,
-            )
+            for num_batch_dims in range(1, 3):
+                self._test_linear(
+                    lambda in_size, out_size: torch.nn.Linear(
+                        in_size, out_size, bias=use_bias  # noqa
+                    ),
+                    num_batch_dims=num_batch_dims,
+                    uses_bias=use_bias,
+                    dtype=torch.float16,
+                    atol=5e-2,
+                )
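(A quick standalone sketch, not part of this diff and with arbitrarily chosen sizes, of why the fp16 test keeps the looser atol=5e-2: running the same Linear in float16 and comparing against its float32 output shows the absolute error the tolerance has to absorb.)

import torch

torch.manual_seed(0)
linear = torch.nn.Linear(37, 17)       # sizes mirror the larger case in _test_linear
x = torch.randn(4, 37)
ref = linear(x)                        # float32 reference output
out = linear.half()(x.half()).float()  # same weights evaluated in float16
print("max abs diff vs fp32:", (out - ref).abs().max().item())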
    def test_fp32_linear(self):
        for use_bias in (True, False):
-            self._test_linear(
-                lambda in_size, out_size: torch.nn.Linear(
-                    in_size, out_size, bias=use_bias  # noqa
-                ),
-                uses_bias=use_bias,
-            )
+            for num_batch_dims in range(1, 3):
+                self._test_linear(
+                    lambda in_size, out_size: torch.nn.Linear(
+                        in_size, out_size, bias=use_bias  # noqa
+                    ),
+                    uses_bias=use_bias,
+                    num_batch_dims=num_batch_dims,
+                )
    def test_fp32_addmm(self):
        """
@@ -62,24 +66,71 @@ def forward(self, x):
            uses_bias=True,
        )
+    def test_fp32_linear_fused_relu(self):
+        class LinearReluModule(torch.nn.Module):
+            def __init__(self, in_size, out_size, use_bias):
+                super().__init__()
+                self.linear = torch.nn.Linear(in_size, out_size, bias=use_bias)
+
+            def forward(self, x):
+                return torch.nn.functional.relu(self.linear(x))
+
+        for use_bias in (True, False):
+            for num_batch_dims in range(1, 3):
+                self._test_linear(
+                    lambda in_size, out_size: LinearReluModule(
+                        in_size,
+                        out_size,
+                        use_bias,  # noqa
+                    ),
+                    uses_bias=use_bias,
+                    num_batch_dims=num_batch_dims,
+                )
+
+    def test_qs8_linear_fused_relu(self):
+        class LinearReluModule(torch.nn.Module):
+            def __init__(self, in_size, out_size, use_bias):
+                super().__init__()
+                self.linear = torch.nn.Linear(in_size, out_size, bias=use_bias)
+
+            def forward(self, x):
+                return torch.nn.functional.relu(self.linear(x))
+
+        for use_bias in (True, False):
+            for num_batch_dims in range(1, 3):
+                self._test_linear(
+                    lambda in_size, out_size: LinearReluModule(
+                        in_size,
+                        out_size,
+                        use_bias,  # noqa
+                    ),
+                    num_batch_dims=num_batch_dims,
+                    uses_bias=use_bias,
+                    quant=True,
+                )
+
    def test_qs8_linear(self):
        for use_bias in (True, False):
-            self._test_linear(
-                lambda in_size, out_size: torch.nn.Linear(
-                    in_size, out_size, bias=use_bias  # noqa
-                ),
-                uses_bias=use_bias,
-            )
+            for num_batch_dims in range(1, 3):
+                self._test_linear(
+                    lambda in_size, out_size: torch.nn.Linear(
+                        in_size, out_size, bias=use_bias  # noqa
+                    ),
+                    uses_bias=use_bias,
+                    num_batch_dims=num_batch_dims,
+                )
    @unittest.skip("XNNPACK currently only supports per-channel dynamic quantization.")
    def test_qd8_per_tensor_linear(self):
        for uses_bias in (False, True):
            inputs = (torch.randn(2, 4),)
            module = torch.nn.Linear(4, 5, bias=uses_bias)
+            dynamic_shapes = ({0: torch.export.Dim("batch", max=100)},)

            self._test_dqlinear(
                module,
                inputs,
+                dynamic_shapes=dynamic_shapes,
                is_per_channel=False,
                uses_bias=uses_bias,
            )
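(For readers unfamiliar with the dynamic_shapes tuples threaded through these tests, here is a minimal standalone sketch of the torch.export API they describe, assuming the Tester helper forwards the spec to torch.export.export unchanged; the module and sizes are illustrative only.)

import torch

module = torch.nn.Linear(4, 5)
example_inputs = (torch.randn(2, 4),)
batch = torch.export.Dim("batch", max=100)
# One dict per positional input; keys are the indices of the dynamic dimensions.
exported = torch.export.export(module, example_inputs, dynamic_shapes=({0: batch},))
print(exported)  # the input's batch dimension is symbolic, bounded by max=100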
@@ -92,6 +143,7 @@ def test_qd8_per_channel_linear(self):
        self._test_dqlinear(
            module,
            inputs,
+            dynamic_shapes=({0: torch.export.Dim("batch", max=100)},),
            is_per_channel=True,
            uses_bias=uses_bias,
        )
@@ -113,7 +165,7 @@ def test_qd8_per_channel_4w_linear(self):
        qconfig = self._get_4b_dqconfig()
        input_channels = [2, 63]
        output_channels = [1, 8, 127]
-        batches = [1, 2]
+        batches = [2, 2]
        use_bias = [False, True]

        for bs, bias, ipc, opc in product(
@@ -128,13 +180,14 @@ def test_qd8_per_channel_4w_linear(self):
            self._test_dqlinear(
                module,
                inputs,
+                dynamic_shapes=({0: torch.export.Dim("batch", max=100)},),
                is_per_channel=True,
                uses_bias=bias,
                qconfig=qconfig,
            )

    def test_qd8_per_channel_linear_parallel(self):
-        in_size = 1
+        in_size = 2
        input_size = 4
        output_size = 5
@@ -164,17 +217,39 @@ def forward(self, x, y):
            torch.rand(in_size, input_size, dtype=torch.float),
            torch.rand(in_size, input_size, dtype=torch.float),
        )
+        batch_dim = torch.export.Dim("batch", max=100)
+        dynamic_shapes = ({0: batch_dim}, {0: batch_dim})

        self._test_dqlinear(
            ParallelLinear(),
            inputs,
+            dynamic_shapes=dynamic_shapes,
            linear_count=2,
            is_per_channel=True,
            uses_bias=True,
        )

+    def test_qd8_per_channel_linear_with_two_batch(self):
+        in_size = 2
+        input_size = 4
+        output_size = 5
+
+        linear = torch.nn.Linear(input_size, output_size)
+        inputs = (torch.randn(2, in_size, input_size, dtype=torch.float),)
+        batch_dim = torch.export.Dim("batch", max=100)
+        dynamic_shapes = ({0: batch_dim, 1: batch_dim},)
+
+        self._test_dqlinear(
+            linear,
+            inputs,
+            dynamic_shapes=dynamic_shapes,
+            linear_count=1,
+            is_per_channel=True,
+            uses_bias=True,
+        )
+
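(An illustrative aside on the new two-batch test above, not part of the change: reusing one Dim object for both leading axes marks them both dynamic, and torch.export also treats the shared Dim as requiring the two axes to stay equal, which the (2, 2, 4) input here satisfies.)

import torch

linear = torch.nn.Linear(4, 5)
x = torch.randn(2, 2, 4)                   # both leading dims equal, as the shared Dim requires
batch = torch.export.Dim("batch", max=100)
exported = torch.export.export(linear, (x,), dynamic_shapes=({0: batch, 1: batch},))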
    def test_qd8_per_channel_linear_sequential(self):
-        in_size = 1
+        in_size = 2
        input_size = 4
        intermediate_size = 5
        output_size = 3
@@ -202,17 +277,19 @@ def forward(self, x):
                return b

        inputs = (torch.rand(in_size, input_size, dtype=torch.float),)
+        dynamic_shapes = ({0: torch.export.Dim("batch", max=100)},)

        self._test_dqlinear(
            LinearSequential(),
            inputs,
+            dynamic_shapes=dynamic_shapes,
            linear_count=2,
            is_per_channel=True,
            uses_bias=True,
        )

    def test_qd8_per_channel_linear_parellel_and_sequential(self):
-        in_size = 1
+        in_size = 2
        input_size = 4
        intermediate_size = 5
        output_size = 3
@@ -251,54 +328,26 @@ def forward(self, x, y):
            torch.rand(in_size, input_size, dtype=torch.float),
            torch.rand(in_size, input_size, dtype=torch.float),
        )
+        dynamic_shapes = (
+            {0: torch.export.Dim("batch", max=100)},
+            {0: torch.export.Dim("batch2", max=100)},
+        )

        self._test_dqlinear(
-            LinearModule(), inputs, linear_count=3, is_per_channel=True, uses_bias=True
+            LinearModule(),
+            inputs,
+            dynamic_shapes=dynamic_shapes,
+            linear_count=3,
+            is_per_channel=True,
+            uses_bias=True,
+            atol=1e-1,
        )

-    def test_fp32_linear_fused_relu(self):
-        class LinearReluModule(torch.nn.Module):
-            def __init__(self, in_size, out_size, use_bias):
-                super().__init__()
-                self.linear = torch.nn.Linear(in_size, out_size, bias=use_bias)
-
-            def forward(self, x):
-                return torch.nn.functional.relu(self.linear(x))
-
-        for use_bias in (True, False):
-            self._test_linear(
-                lambda in_size, out_size: LinearReluModule(
-                    in_size,
-                    out_size,
-                    use_bias,  # noqa
-                ),
-                uses_bias=use_bias,
-            )
-
-    def test_qs8_linear_fused_relu(self):
-        class LinearReluModule(torch.nn.Module):
-            def __init__(self, in_size, out_size, use_bias):
-                super().__init__()
-                self.linear = torch.nn.Linear(in_size, out_size, bias=use_bias)
-
-            def forward(self, x):
-                return torch.nn.functional.relu(self.linear(x))
-
-        for use_bias in (True, False):
-            self._test_linear(
-                lambda in_size, out_size: LinearReluModule(
-                    in_size,
-                    out_size,
-                    use_bias,  # noqa
-                ),
-                uses_bias=use_bias,
-                quant=True,
-            )
-
    def _test_linear(
        self,
        make_module,
        uses_bias,
+        num_batch_dims=1,
        quant=False,
        dtype: torch.dtype = torch.float,
        atol=1e-03,
@@ -315,7 +364,7 @@ def _test_linear(
            )
        )

-        in_sizes = [1, 4, 4]
+        in_sizes = [3, 4, 4]
        input_sizes = [4, 37, 17]
        output_sizes = [4, 17, 37]
@@ -327,12 +376,18 @@ def _test_linear(
            in_size = int(in_sizes[i])
            input_size = int(input_sizes[i])
            output_size = int(output_sizes[i])
-            print(f"Testing {in_size} {input_size} {output_size}")
+            input_shape = [in_size] * num_batch_dims + [input_size]
+            print(f"Testing input_shape {input_shape} with {output_size} out_channels")

            module = make_module(input_size, output_size).eval().to(dtype)
-            inputs = (torch.randn(in_size, input_size).to(dtype),)
+            inputs = (torch.randn(input_shape).to(dtype),)
+            dynamic_shape = {}
+            for i in range(num_batch_dims):
+                dynamic_shape[i] = torch.export.Dim(f"batch{i}", max=in_size)
+
+            dynamic_shape = (dynamic_shape,)

-            tester = Tester(module, inputs)
+            tester = Tester(module, inputs, dynamic_shapes=dynamic_shape)

            if quant:
                tester.quantize()
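(To make the reworked shape handling concrete, an illustrative snippet, not part of the diff, of what the loop builds for num_batch_dims=2 with in_size=3 and input_size=4: a [3, 3, 4] input and a per-input dict marking both leading dimensions dynamic.)

import torch

num_batch_dims, in_size, input_size = 2, 3, 4
input_shape = [in_size] * num_batch_dims + [input_size]  # [3, 3, 4]
inputs = (torch.randn(input_shape),)
dynamic_shape = (
    {d: torch.export.Dim(f"batch{d}", max=in_size) for d in range(num_batch_dims)},
)  # one spec per positional input
print(input_shape, dynamic_shape)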
@@ -360,10 +415,12 @@ def _test_dqlinear(
        self,
        module,
        inputs,
+        dynamic_shapes,
        linear_count=1,
        is_per_channel=False,
        uses_bias=False,
        qconfig: Optional[QuantizationConfig] = None,
+        atol=5e-02,
    ):
        aten_op, edge_op = (
            (
@@ -382,13 +439,12 @@ def _test_dqlinear(
            is_dynamic=True,
        )

-        tester = Tester(module, inputs)
+        tester = Tester(module, inputs, dynamic_shapes=dynamic_shapes)
        tester.quantize(Quantize(quantization_config=quant_config))

        tester.export()
        tester.check_count({aten_op: linear_count})
        tester.check(["torch.ops.quantized_decomposed"])
-        tester.dump_artifact()
        tester.to_edge()
        tester.check_count({edge_op: linear_count})
@@ -400,4 +456,4 @@ def _test_dqlinear(
        tester.to_executorch()
        tester.serialize()
-        tester.run_method_and_compare_outputs(atol=5e-02)
+        tester.run_method_and_compare_outputs(atol=atol)