# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

+# pyre-strict
+
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
-from typing import Any, Callable, List, Optional, Tuple, Type, Union
+from typing import List, Optional, Tuple, Union

import torch
from executorch.backends.cadence.aot.quantizer.utils import get_bias_qparams

from torch import fx
+from torch._ops import OpOverload
from torch.ao.quantization.quantizer import (
    DerivedQuantizationSpec,
    SharedQuantizationSpec,
@@ -44,18 +47,20 @@ class PartitionAnchors:

class QuantizationPattern(ABC):
    @abstractmethod
-    def partition_types(self):
+    def partition_types(self) -> list[OpOverload]:
        """
-        List of types to be passed to find_sequential_partitions.
+        List of types to be passed to find_sequential_partitions_aten.
        """
        pass

    @abstractmethod
-    def get_anchors(self, gm, fused_partition) -> Optional[PartitionAnchors]:
+    def get_anchors(
+        self, gm: torch.fx.GraphModule, fused_partition: List[fx.GraphModule]
+    ) -> Optional[PartitionAnchors]:
        pass

    @abstractmethod
-    def replacement_op(self) -> Callable[..., Any]:
+    def replacement_op(self) -> OpOverload:
        """
        Operator (most likely a custom one) that this partition should be fused into in
        the backend. Refer to the QuantFusion pass for examples.
@@ -64,8 +69,8 @@ def replacement_op(self) -> Callable[..., Any]:


class AddmmPattern(QuantizationPattern):
-    def partition_types(self) -> List[Type[torch.nn.Module]]:
-        return [torch.addmm]
+    def partition_types(self) -> List[OpOverload]:
+        return [torch.ops.aten.addmm.default]

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
@@ -91,13 +96,13 @@ def get_anchors(
            output=[(addmm_node,)],
        )

-    def replacement_op(self):
+    def replacement_op(self) -> OpOverload:
        return torch.ops.cadence.quantized_linear


class BmmPattern(QuantizationPattern):
-    def partition_types(self) -> List[Callable[..., torch.Tensor]]:
-        return [torch.bmm]
+    def partition_types(self) -> List[OpOverload]:
+        return [torch.ops.aten.bmm.default]

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
@@ -111,13 +116,13 @@ def get_anchors(
            output=[(bmm_node,)],
        )

-    def replacement_op(self):
+    def replacement_op(self) -> OpOverload:
        return torch.ops.cadence.quantized_matmul.default


class Conv1dPattern(QuantizationPattern):
-    def partition_types(self) -> List[Type[torch.nn.Module]]:
-        return [torch.nn.Conv1d]
+    def partition_types(self) -> List[OpOverload]:
+        return [torch.ops.aten.conv1d.default]

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
@@ -149,13 +154,13 @@ def get_anchors(
            output=[(conv1d_node,)],
        )

-    def replacement_op(self):
+    def replacement_op(self) -> OpOverload:
        return torch.ops.cadence.quantized_conv.default


class Conv2dPattern(QuantizationPattern):
-    def partition_types(self) -> List[Type[torch.nn.Module]]:
-        return [torch.nn.Conv2d]
+    def partition_types(self) -> List[OpOverload]:
+        return [torch.ops.aten.conv2d.default]

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
@@ -187,37 +192,17 @@ def get_anchors(
            output=[(conv2d_node,)],
        )

-    def replacement_op(self):
+    def replacement_op(self) -> OpOverload:
        return torch.ops.cadence.quantized_conv.default


class LayerNormPattern(QuantizationPattern):
-    def partition_types(self):
-        return [torch.nn.LayerNorm]
-
-    def get_anchors(self, gm, fused_partition) -> PartitionAnchors:
-        layer_norm_node = fused_partition[0].nodes[-1]
-
-        # Weights and biases are used as fp32 by our kernel, so they are
-        # passed in as others here along with the normalized shape.
-        return PartitionAnchors(
-            inputs=[(layer_norm_node, 0)],
-            weights=[],
-            biases=[],
-            # Ordering: normalized_shape, weights, bias
-            others=[(layer_norm_node, 1), (layer_norm_node, 2), (layer_norm_node, 3)],
-            output=[(layer_norm_node,)],
-        )
+    def partition_types(self) -> List[OpOverload]:
+        return [torch.ops.aten.layer_norm.default]

-    def replacement_op(self):
-        return torch.ops.cadence.quantized_layer_norm.default
-
-
-class LayerNormFunctionalPattern(QuantizationPattern):
-    def partition_types(self):
-        return [torch.nn.functional.layer_norm]
-
-    def get_anchors(self, gm, fused_partition) -> PartitionAnchors:
+    def get_anchors(
+        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
+    ) -> PartitionAnchors:
        layer_norm_node = fused_partition[0].nodes[-1]

        others = [(layer_norm_node, 1)]
@@ -241,13 +226,13 @@ def get_anchors(self, gm, fused_partition) -> PartitionAnchors:
            output=[(layer_norm_node,)],
        )

-    def replacement_op(self):
+    def replacement_op(self) -> OpOverload:
        return torch.ops.cadence.quantized_layer_norm.default


class LinearPattern(QuantizationPattern):
-    def partition_types(self) -> List[Type[torch.nn.Module]]:
-        return [torch.nn.Linear]
+    def partition_types(self) -> List[OpOverload]:
+        return [torch.ops.aten.linear.default]

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
@@ -279,51 +264,13 @@ def get_anchors(
            output=[(linear_node,)],
        )

-    def replacement_op(self):
-        return torch.ops.cadence.quantized_linear.default
-
-
-class LinearFunctionalPattern(QuantizationPattern):
-    def partition_types(self):
-        return [torch.nn.functional.linear]
-
-    def get_anchors(
-        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
-    ) -> PartitionAnchors:
-        linear_node = fused_partition[0].nodes[-1]
-
-        bias_qspec = DerivedQuantizationSpec(
-            derived_from=[
-                (linear_node.args[0], linear_node),
-                (linear_node.args[1], linear_node),
-            ],
-            derive_qparams_fn=get_bias_qparams,
-            dtype=torch.int32,
-            quant_min=-(2**31),
-            quant_max=2**31 - 1,
-            qscheme=torch.per_tensor_affine,
-        )
-
-        # Keep bias empty if not supplied
-        bias = []
-        if len(linear_node.args) > 2 and linear_node.args[2] is not None:
-            bias = [(linear_node, 2, bias_qspec)]
-
-        return PartitionAnchors(
-            inputs=[(linear_node, 0)],
-            weights=[(linear_node, 1)],
-            # pyre-fixme[6]: Incompatible parameter type
-            biases=bias,
-            output=[(linear_node,)],
-        )
-
-    def replacement_op(self):
+    def replacement_op(self) -> OpOverload:
        return torch.ops.cadence.quantized_linear.default


class MatmulPattern(QuantizationPattern):
-    def partition_types(self):
-        return [torch.matmul]
+    def partition_types(self) -> List[OpOverload]:
+        return [torch.ops.aten.matmul.default]

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
@@ -337,13 +284,13 @@ def get_anchors(
            output=[(matmul_node,)],
        )

-    def replacement_op(self):
+    def replacement_op(self) -> OpOverload:
        return torch.ops.cadence.quantized_matmul.default


class ReluPattern(QuantizationPattern):
-    def partition_types(self) -> List[Type[torch.nn.Module]]:
-        return [torch.nn.ReLU]
+    def partition_types(self) -> List[OpOverload]:
+        return [torch.ops.aten.relu.default]

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
@@ -359,5 +306,5 @@ def get_anchors(
            ],
        )

-    def replacement_op(self):
+    def replacement_op(self) -> OpOverload:
        return torch.ops.cadence.quantized_relu.default
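
For reference, a minimal sketch of what a pattern written against the new OpOverload-based interface looks like, following the structure of the classes above. The matched op (aten.hardtanh) and the returned replacement op are illustrative placeholders, not part of this diff; QuantizationPattern and PartitionAnchors are the classes defined in this file.

# Hypothetical example only — not part of this diff. Assumes
# QuantizationPattern and PartitionAnchors from the file above are in scope.
from typing import List

import torch
from torch import fx
from torch._ops import OpOverload


class HardTanhPattern(QuantizationPattern):
    def partition_types(self) -> List[OpOverload]:
        # Patterns now match ATen op overloads in the exported graph
        # (torch.ops.aten.*.default), not nn.Module or functional types.
        return [torch.ops.aten.hardtanh.default]

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
    ) -> PartitionAnchors:
        # The matched op is the last node of the single fused partition.
        hardtanh_node = fused_partition[0].nodes[-1]
        return PartitionAnchors(
            inputs=[(hardtanh_node, 0)],
            weights=[],
            biases=[],
            output=[(hardtanh_node,)],
        )

    def replacement_op(self) -> OpOverload:
        # Placeholder: reuses an op known to exist in this file; a real
        # pattern would return its own torch.ops.cadence.* fused op.
        return torch.ops.cadence.quantized_relu.default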