class TestLinear(unittest.TestCase):
    def test_fp16_linear(self):
        for use_bias in (True, False):
-            self._test_linear(
-                lambda in_size, out_size: torch.nn.Linear(
-                    in_size, out_size, bias=use_bias  # noqa
-                ),
-                uses_bias=use_bias,
-                dtype=torch.float16,
-                atol=5e-2,
-            )
+            for num_batch_dims in range(1, 3):
+                self._test_linear(
+                    lambda in_size, out_size: torch.nn.Linear(
+                        in_size, out_size, bias=use_bias  # noqa
+                    ),
+                    num_batch_dims=num_batch_dims,
+                    uses_bias=use_bias,
+                    dtype=torch.float16,
+                    atol=5e-2,
+                )

    def test_fp32_linear(self):
        for use_bias in (True, False):
-            self._test_linear(
-                lambda in_size, out_size: torch.nn.Linear(
-                    in_size, out_size, bias=use_bias  # noqa
-                ),
-                uses_bias=use_bias,
-            )
+            for num_batch_dims in range(1, 3):
+                self._test_linear(
+                    lambda in_size, out_size: torch.nn.Linear(
+                        in_size, out_size, bias=use_bias  # noqa
+                    ),
+                    uses_bias=use_bias,
+                    num_batch_dims=num_batch_dims,
+                )

    def test_fp32_addmm(self):
        """
@@ -63,24 +67,71 @@ def forward(self, x):
            uses_bias=True,
        )

+    def test_fp32_linear_fused_relu(self):
+        class LinearReluModule(torch.nn.Module):
+            def __init__(self, in_size, out_size, use_bias):
+                super().__init__()
+                self.linear = torch.nn.Linear(in_size, out_size, bias=use_bias)
+
+            def forward(self, x):
+                return torch.nn.functional.relu(self.linear(x))
+
+        for use_bias in (True, False):
+            for num_batch_dims in range(1, 3):
+                self._test_linear(
+                    lambda in_size, out_size: LinearReluModule(
+                        in_size,
+                        out_size,
+                        use_bias,  # noqa
+                    ),
+                    uses_bias=use_bias,
+                    num_batch_dims=num_batch_dims,
+                )
+
+    def test_qs8_linear_fused_relu(self):
+        class LinearReluModule(torch.nn.Module):
+            def __init__(self, in_size, out_size, use_bias):
+                super().__init__()
+                self.linear = torch.nn.Linear(in_size, out_size, bias=use_bias)
+
+            def forward(self, x):
+                return torch.nn.functional.relu(self.linear(x))
+
+        for use_bias in (True, False):
+            for num_batch_dims in range(1, 3):
+                self._test_linear(
+                    lambda in_size, out_size: LinearReluModule(
+                        in_size,
+                        out_size,
+                        use_bias,  # noqa
+                    ),
+                    num_batch_dims=num_batch_dims,
+                    uses_bias=use_bias,
+                    quant=True,
+                )
+
    def test_qs8_linear(self):
        for use_bias in (True, False):
-            self._test_linear(
-                lambda in_size, out_size: torch.nn.Linear(
-                    in_size, out_size, bias=use_bias  # noqa
-                ),
-                uses_bias=use_bias,
-            )
+            for num_batch_dims in range(1, 3):
+                self._test_linear(
+                    lambda in_size, out_size: torch.nn.Linear(
+                        in_size, out_size, bias=use_bias  # noqa
+                    ),
+                    uses_bias=use_bias,
+                    num_batch_dims=num_batch_dims,
+                )

    @unittest.skip("XNNPACK currently only supports per-channel dynamic quantization.")
    def test_qd8_per_tensor_linear(self):
        for uses_bias in (False, True):
            inputs = (torch.randn(2, 4),)
            module = torch.nn.Linear(4, 5, bias=uses_bias)
+            dynamic_shapes = ({0: torch.export.Dim("batch", max=100)},)

            self._test_dqlinear(
                module,
                inputs,
+                dynamic_shapes=dynamic_shapes,
                is_per_channel=False,
                uses_bias=uses_bias,
            )
@@ -93,6 +144,7 @@ def test_qd8_per_channel_linear(self):
            self._test_dqlinear(
                module,
                inputs,
+                dynamic_shapes=({0: torch.export.Dim("batch", max=100)},),
                is_per_channel=True,
                uses_bias=uses_bias,
            )
@@ -114,7 +166,7 @@ def test_qd8_per_channel_4w_linear(self):
        qconfig = self._get_4b_dqconfig()
        input_channels = [2, 63]
        output_channels = [1, 8, 127]
-        batches = [1, 2]
+        batches = [2, 2]
        use_bias = [False, True]

        for bs, bias, ipc, opc in product(
@@ -129,13 +181,14 @@ def test_qd8_per_channel_4w_linear(self):
            self._test_dqlinear(
                module,
                inputs,
+                dynamic_shapes=({0: torch.export.Dim("batch", max=100)},),
                is_per_channel=True,
                uses_bias=bias,
                qconfig=qconfig,
            )

    def test_qd8_per_channel_linear_parallel(self):
-        in_size = 1
+        in_size = 2
        input_size = 4
        output_size = 5

@@ -165,17 +218,39 @@ def forward(self, x, y):
            torch.rand(in_size, input_size, dtype=torch.float),
            torch.rand(in_size, input_size, dtype=torch.float),
        )
+        batch_dim = torch.export.Dim("batch", max=100)
+        dynamic_shapes = ({0: batch_dim}, {0: batch_dim})

        self._test_dqlinear(
            ParallelLinear(),
            inputs,
+            dynamic_shapes=dynamic_shapes,
            linear_count=2,
            is_per_channel=True,
            uses_bias=True,
        )

+    def test_qd8_per_channel_linear_with_two_batch(self):
+        in_size = 2
+        input_size = 4
+        output_size = 5
+
+        linear = torch.nn.Linear(input_size, output_size)
+        inputs = (torch.randn(2, in_size, input_size, dtype=torch.float),)
+        batch_dim = torch.export.Dim("batch", max=100)
+        dynamic_shapes = ({0: batch_dim, 1: batch_dim},)
+
+        self._test_dqlinear(
+            linear,
+            inputs,
+            dynamic_shapes=dynamic_shapes,
+            linear_count=1,
+            is_per_channel=True,
+            uses_bias=True,
+        )
+
    def test_qd8_per_channel_linear_sequential(self):
-        in_size = 1
+        in_size = 2
        input_size = 4
        intermediate_size = 5
        output_size = 3
@@ -203,17 +278,20 @@ def forward(self, x):
            return b

        inputs = (torch.rand(in_size, input_size, dtype=torch.float),)
+        dynamic_shapes = ({0: torch.export.Dim("batch", max=100)},)

        self._test_dqlinear(
            LinearSequential(),
            inputs,
+            dynamic_shapes=dynamic_shapes,
            linear_count=2,
            is_per_channel=True,
            uses_bias=True,
+            atol=1e-1,
        )

    def test_qd8_per_channel_linear_parellel_and_sequential(self):
-        in_size = 1
+        in_size = 2
        input_size = 4
        intermediate_size = 5
        output_size = 3
@@ -252,50 +330,21 @@ def forward(self, x, y):
            torch.rand(in_size, input_size, dtype=torch.float),
            torch.rand(in_size, input_size, dtype=torch.float),
        )
+        dynamic_shapes = (
+            {0: torch.export.Dim("batch", max=100)},
+            {0: torch.export.Dim("batch2", max=100)},
+        )

        self._test_dqlinear(
-            LinearModule(), inputs, linear_count=3, is_per_channel=True, uses_bias=True
+            LinearModule(),
+            inputs,
+            dynamic_shapes=dynamic_shapes,
+            linear_count=3,
+            is_per_channel=True,
+            uses_bias=True,
+            atol=1e-1,
        )

-    def test_fp32_linear_fused_relu(self):
-        class LinearReluModule(torch.nn.Module):
-            def __init__(self, in_size, out_size, use_bias):
-                super().__init__()
-                self.linear = torch.nn.Linear(in_size, out_size, bias=use_bias)
-
-            def forward(self, x):
-                return torch.nn.functional.relu(self.linear(x))
-
-        for use_bias in (True, False):
-            self._test_linear(
-                lambda in_size, out_size: LinearReluModule(
-                    in_size,
-                    out_size,
-                    use_bias,  # noqa
-                ),
-                uses_bias=use_bias,
-            )
-
-    def test_qs8_linear_fused_relu(self):
-        class LinearReluModule(torch.nn.Module):
-            def __init__(self, in_size, out_size, use_bias):
-                super().__init__()
-                self.linear = torch.nn.Linear(in_size, out_size, bias=use_bias)
-
-            def forward(self, x):
-                return torch.nn.functional.relu(self.linear(x))
-
-        for use_bias in (True, False):
-            self._test_linear(
-                lambda in_size, out_size: LinearReluModule(
-                    in_size,
-                    out_size,
-                    use_bias,  # noqa
-                ),
-                uses_bias=use_bias,
-                quant=True,
-            )
-
class ManualDQLinear(torch.nn.Module):
    def __init__(
        self,
@@ -676,6 +725,7 @@ def _test_linear(
        self,
        make_module,
        uses_bias,
+        num_batch_dims=1,
        quant=False,
        dtype: torch.dtype = torch.float,
        atol=1e-03,
@@ -692,7 +742,7 @@ def _test_linear(
            )
        )

-        in_sizes = [1, 4, 4]
+        in_sizes = [3, 4, 4]
        input_sizes = [4, 37, 17]
        output_sizes = [4, 17, 37]

@@ -704,11 +754,19 @@ def _test_linear(
            in_size = int(in_sizes[i])
            input_size = int(input_sizes[i])
            output_size = int(output_sizes[i])
+            input_shape = [in_size] * num_batch_dims + [input_size]
+            print(f"Testing input_shape {input_shape} with {output_size} out_channels")

            module = make_module(input_size, output_size).eval().to(dtype)
-            inputs = (torch.randn(in_size, input_size).to(dtype),)
+            inputs = (torch.randn(input_shape).to(dtype),)
+            dynamic_shape = {}
+            for i in range(num_batch_dims):
+                dynamic_shape[i] = torch.export.Dim(f"batch{i}", min=2, max=in_size)
+
+            dynamic_shape = (dynamic_shape,)
+            print(dynamic_shape)

-            tester = Tester(module, inputs)
+            tester = Tester(module, inputs, dynamic_shapes=dynamic_shape)

            if quant:
                tester.quantize()
@@ -736,10 +794,12 @@ def _test_dqlinear(
        self,
        module,
        inputs,
+        dynamic_shapes,
        linear_count=1,
        is_per_channel=False,
        uses_bias=False,
        qconfig: Optional[QuantizationConfig] = None,
+        atol=5e-02,
    ):
        aten_op, edge_op = (
            (
@@ -758,13 +818,12 @@ def _test_dqlinear(
            is_dynamic=True,
        )

-        tester = Tester(module, inputs)
+        tester = Tester(module, inputs, dynamic_shapes=dynamic_shapes)
        tester.quantize(Quantize(quantization_config=quant_config))

        tester.export()
        tester.check_count({aten_op: linear_count})
        tester.check(["torch.ops.quantized_decomposed"])
-        tester.dump_artifact()
        tester.to_edge()
        tester.check_count({edge_op: linear_count})
@@ -776,4 +835,4 @@ def _test_dqlinear(

        tester.to_executorch()
        tester.serialize()
-        tester.run_method_and_compare_outputs(atol=5e-02)
+        tester.run_method_and_compare_outputs(atol=atol)
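
For reference, the `dynamic_shapes` argument threaded through these tests follows the standard `torch.export` convention: one `torch.export.Dim` per dynamic axis, keyed by that axis's position in the corresponding input tensor, with one dict per positional input. A minimal standalone sketch using plain `torch.export` (the module, sizes, and names below are illustrative, not taken from the `Tester` harness):

```python
import torch

# Standalone sketch of the dynamic_shapes convention used in the diff; the
# module, sizes, and Dim names here are illustrative, not from the test file.
linear = torch.nn.Linear(4, 5).eval()
example_inputs = (torch.randn(3, 4),)

# One Dim per dynamic axis, keyed by its index in the corresponding input;
# one dict per positional input.
batch = torch.export.Dim("batch", min=2, max=100)
dynamic_shapes = ({0: batch},)

exported = torch.export.export(linear, example_inputs, dynamic_shapes=dynamic_shapes)

# The exported program now accepts any batch size within [2, 100].
print(exported.module()(torch.randn(7, 4)).shape)  # torch.Size([7, 5])
```

Keeping the example batch at 2 or more (and `min=2` on the `Dim`) avoids `torch.export` specializing size-0/1 dimensions, which is likely why the diff bumps `in_size` from 1 to 2 in the dynamic-quantization tests.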