 class TestLinear(unittest.TestCase):
     def test_fp16_linear(self):
         for use_bias in (True, False):
-            self._test_linear(
-                lambda in_size, out_size: torch.nn.Linear(
-                    in_size, out_size, bias=use_bias  # noqa
-                ),
-                uses_bias=use_bias,
-                dtype=torch.float16,
-                atol=5e-2,
-            )
+            for num_batch_dims in range(1, 3):
+                self._test_linear(
+                    lambda in_size, out_size: torch.nn.Linear(
+                        in_size, out_size, bias=use_bias  # noqa
+                    ),
+                    num_batch_dims=num_batch_dims,
+                    uses_bias=use_bias,
+                    dtype=torch.float16,
+                    atol=5e-2,
+                )
 
     def test_fp32_linear(self):
         for use_bias in (True, False):
-            self._test_linear(
-                lambda in_size, out_size: torch.nn.Linear(
-                    in_size, out_size, bias=use_bias  # noqa
-                ),
-                uses_bias=use_bias,
-            )
+            for num_batch_dims in range(1, 3):
+                self._test_linear(
+                    lambda in_size, out_size: torch.nn.Linear(
+                        in_size, out_size, bias=use_bias  # noqa
+                    ),
+                    uses_bias=use_bias,
+                    num_batch_dims=num_batch_dims,
+                )
 
     def test_fp32_addmm(self):
         """
@@ -63,24 +67,71 @@ def forward(self, x):
             uses_bias=True,
         )
 
+    def test_fp32_linear_fused_relu(self):
+        class LinearReluModule(torch.nn.Module):
+            def __init__(self, in_size, out_size, use_bias):
+                super().__init__()
+                self.linear = torch.nn.Linear(in_size, out_size, bias=use_bias)
+
+            def forward(self, x):
+                return torch.nn.functional.relu(self.linear(x))
+
+        for use_bias in (True, False):
+            for num_batch_dims in range(1, 3):
+                self._test_linear(
+                    lambda in_size, out_size: LinearReluModule(
+                        in_size,
+                        out_size,
+                        use_bias,  # noqa
+                    ),
+                    uses_bias=use_bias,
+                    num_batch_dims=num_batch_dims,
+                )
+
+    def test_qs8_linear_fused_relu(self):
+        class LinearReluModule(torch.nn.Module):
+            def __init__(self, in_size, out_size, use_bias):
+                super().__init__()
+                self.linear = torch.nn.Linear(in_size, out_size, bias=use_bias)
+
+            def forward(self, x):
+                return torch.nn.functional.relu(self.linear(x))
+
+        for use_bias in (True, False):
+            for num_batch_dims in range(1, 3):
+                self._test_linear(
+                    lambda in_size, out_size: LinearReluModule(
+                        in_size,
+                        out_size,
+                        use_bias,  # noqa
+                    ),
+                    num_batch_dims=num_batch_dims,
+                    uses_bias=use_bias,
+                    quant=True,
+                )
+
     def test_qs8_linear(self):
         for use_bias in (True, False):
-            self._test_linear(
-                lambda in_size, out_size: torch.nn.Linear(
-                    in_size, out_size, bias=use_bias  # noqa
-                ),
-                uses_bias=use_bias,
-            )
+            for num_batch_dims in range(1, 3):
+                self._test_linear(
+                    lambda in_size, out_size: torch.nn.Linear(
+                        in_size, out_size, bias=use_bias  # noqa
+                    ),
+                    uses_bias=use_bias,
+                    num_batch_dims=num_batch_dims,
+                )
 
     @unittest.skip("XNNPACK currently only supports per-channel dynamic quantization.")
     def test_qd8_per_tensor_linear(self):
         for uses_bias in (False, True):
             inputs = (torch.randn(2, 4),)
             module = torch.nn.Linear(4, 5, bias=uses_bias)
+            dynamic_shapes = ({0: torch.export.Dim("batch", max=100)},)
 
             self._test_dqlinear(
                 module,
                 inputs,
+                dynamic_shapes=dynamic_shapes,
                 is_per_channel=False,
                 uses_bias=uses_bias,
             )
@@ -93,6 +144,7 @@ def test_qd8_per_channel_linear(self):
             self._test_dqlinear(
                 module,
                 inputs,
+                dynamic_shapes=({0: torch.export.Dim("batch", max=100)},),
                 is_per_channel=True,
                 uses_bias=uses_bias,
             )
@@ -114,7 +166,7 @@ def test_qd8_per_channel_4w_linear(self):
         qconfig = self._get_4b_dqconfig()
         input_channels = [2, 63]
         output_channels = [1, 8, 127]
-        batches = [1, 2]
+        batches = [2, 2]
         use_bias = [False, True]
 
         for bs, bias, ipc, opc in product(
@@ -129,13 +181,14 @@ def test_qd8_per_channel_4w_linear(self):
             self._test_dqlinear(
                 module,
                 inputs,
+                dynamic_shapes=({0: torch.export.Dim("batch", max=100)},),
                 is_per_channel=True,
                 uses_bias=bias,
                 qconfig=qconfig,
             )
 
     def test_qd8_per_channel_linear_parallel(self):
-        in_size = 1
+        in_size = 2
         input_size = 4
         output_size = 5
 
@@ -165,17 +218,39 @@ def forward(self, x, y):
             torch.rand(in_size, input_size, dtype=torch.float),
             torch.rand(in_size, input_size, dtype=torch.float),
         )
+        batch_dim = torch.export.Dim("batch", max=100)
+        dynamic_shapes = ({0: batch_dim}, {0: batch_dim})
 
         self._test_dqlinear(
             ParallelLinear(),
             inputs,
+            dynamic_shapes=dynamic_shapes,
             linear_count=2,
             is_per_channel=True,
             uses_bias=True,
         )
 
+    def test_qd8_per_channel_linear_with_two_batch(self):
+        in_size = 2
+        input_size = 4
+        output_size = 5
+
+        linear = torch.nn.Linear(input_size, output_size)
+        inputs = (torch.randn(2, in_size, input_size, dtype=torch.float),)
+        batch_dim = torch.export.Dim("batch", max=100)
+        dynamic_shapes = ({0: batch_dim, 1: batch_dim},)
+
+        self._test_dqlinear(
+            linear,
+            inputs,
+            dynamic_shapes=dynamic_shapes,
+            linear_count=1,
+            is_per_channel=True,
+            uses_bias=True,
+        )
+
     def test_qd8_per_channel_linear_sequential(self):
-        in_size = 1
+        in_size = 2
         input_size = 4
         intermediate_size = 5
         output_size = 3
@@ -203,17 +278,19 @@ def forward(self, x):
                 return b
 
         inputs = (torch.rand(in_size, input_size, dtype=torch.float),)
+        dynamic_shapes = ({0: torch.export.Dim("batch", max=100)},)
 
         self._test_dqlinear(
             LinearSequential(),
             inputs,
+            dynamic_shapes=dynamic_shapes,
             linear_count=2,
             is_per_channel=True,
             uses_bias=True,
         )
 
     def test_qd8_per_channel_linear_parellel_and_sequential(self):
-        in_size = 1
+        in_size = 2
         input_size = 4
         intermediate_size = 5
         output_size = 3
@@ -252,9 +329,19 @@ def forward(self, x, y):
             torch.rand(in_size, input_size, dtype=torch.float),
             torch.rand(in_size, input_size, dtype=torch.float),
         )
+        dynamic_shapes = (
+            {0: torch.export.Dim("batch", max=100)},
+            {0: torch.export.Dim("batch2", max=100)},
+        )
 
         self._test_dqlinear(
-            LinearModule(), inputs, linear_count=3, is_per_channel=True, uses_bias=True
+            LinearModule(),
+            inputs,
+            dynamic_shapes=dynamic_shapes,
+            linear_count=3,
+            is_per_channel=True,
+            uses_bias=True,
+            atol=1e-1,
         )
 
     def test_fp32_linear_fused_relu(self):
@@ -595,8 +682,7 @@ def _test_manual_dq_linear(
             )
             .to_executorch()
             .serialize()
-            .run_method()
-            .compare_outputs(atol=atol, rtol=rtol)
+            .run_method_and_compare_outputs(atol=atol, rtol=rtol)
         )
 
     def _run_manual_dqlinear_tests(self, weight_n_bit: int, op_dtype: torch.dtype):
@@ -677,6 +763,7 @@ def _test_linear(
         self,
         make_module,
         uses_bias,
+        num_batch_dims=1,
         quant=False,
         dtype: torch.dtype = torch.float,
         atol=1e-03,
@@ -693,7 +780,7 @@ def _test_linear(
             )
         )
 
-        in_sizes = [1, 4, 4]
+        in_sizes = [3, 4, 4]
         input_sizes = [4, 37, 17]
         output_sizes = [4, 17, 37]
 
@@ -705,11 +792,18 @@ def _test_linear(
             in_size = int(in_sizes[i])
             input_size = int(input_sizes[i])
             output_size = int(output_sizes[i])
+            input_shape = [in_size] * num_batch_dims + [input_size]
+            print(f"Testing input_shape {input_shape} with {output_size} out_channels")
 
             module = make_module(input_size, output_size).eval().to(dtype)
-            inputs = (torch.randn(in_size, input_size).to(dtype),)
+            inputs = (torch.randn(input_shape).to(dtype),)
+            dynamic_shape = {}
+            for i in range(num_batch_dims):
+                dynamic_shape[i] = torch.export.Dim(f"batch{i}", max=in_size)
+
+            dynamic_shape = (dynamic_shape,)
 
-            tester = Tester(module, inputs)
+            tester = Tester(module, inputs, dynamic_shapes=dynamic_shape)
 
             if quant:
                 tester.quantize()
@@ -737,10 +831,12 @@ def _test_dqlinear(
         self,
         module,
         inputs,
+        dynamic_shapes,
         linear_count=1,
         is_per_channel=False,
         uses_bias=False,
         qconfig: Optional[QuantizationConfig] = None,
+        atol=5e-02,
     ):
         aten_op, edge_op = (
             (
@@ -759,13 +855,12 @@ def _test_dqlinear(
             is_dynamic=True,
         )
 
-        tester = Tester(module, inputs)
+        tester = Tester(module, inputs, dynamic_shapes=dynamic_shapes)
         tester.quantize(Quantize(quantization_config=quant_config))
 
         tester.export()
         tester.check_count({aten_op: linear_count})
         tester.check(["torch.ops.quantized_decomposed"])
-        tester.dump_artifact()
         tester.to_edge()
         tester.check_count({edge_op: linear_count})
 
@@ -777,4 +872,4 @@ def _test_dqlinear(
 
         tester.to_executorch()
         tester.serialize()
-        tester.run_method_and_compare_outputs(atol=5e-02)
+        tester.run_method_and_compare_outputs(atol=atol)
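Note on the dynamic_shapes plumbing added above: the sketch below is an illustrative, standalone example (not code from this diff) of how the torch.export.Dim mapping built in _test_linear marks each leading batch dimension as dynamic at export time. Sizes and variable names mirror the test but are otherwise arbitrary, and it assumes a PyTorch build that ships the torch.export API.

# Illustrative sketch only; mirrors how the updated _test_linear builds
# dynamic_shapes: one torch.export.Dim per leading batch dimension,
# keyed by dimension index and passed alongside the example inputs.
import torch

num_batch_dims = 2          # matches the range(1, 3) sweep in the tests
in_size, input_size, output_size = 3, 4, 4

module = torch.nn.Linear(input_size, output_size).eval()
input_shape = [in_size] * num_batch_dims + [input_size]
inputs = (torch.randn(*input_shape),)

dynamic_shape = {}
for i in range(num_batch_dims):
    # Each leading dim gets its own symbolic size, bounded as in the test.
    dynamic_shape[i] = torch.export.Dim(f"batch{i}", max=in_size)

exported = torch.export.export(module, inputs, dynamic_shapes=(dynamic_shape,))
print(exported)  # the exported graph carries symbolic batch0/batch1 sizes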