Commit 4542695

mcremon-meta authored and facebook-github-bot committed
Migrate the quantizer to use aten ops directly (#4195)
Summary: Pull Request resolved: #4195

This major change allows a lot more flexibility in the quantizer and reduces the dependency on decomposition/graph-tracing tools. The motivation is that some of those tools do not preserve or propagate `source_fn_stack` information, resulting in quantization misses. SDPA is one example: its underlying `bmm` ops cannot be quantized from `source_fn_stack` information alone. MHA is another: it can hide its SDPA component and sometimes even its `linear` ops, depending on the model (see ViT for an example).

Summary of the changes:
- change the quantizer to match aten ops directly, through `node.target`
- propagate the required changes to the `QuantFusion` pass
- update/remove existing patterns

Differential Revision: D59552606
1 parent 238850b commit 4542695

File tree

5 files changed (+166, -121 lines)

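The commit message above explains that patterns now match aten ops directly through `node.target` rather than relying on `source_fn_stack` metadata. The snippet below is a minimal sketch of that idea, not code from this commit; `is_aten_linear` is a hypothetical helper name used only for illustration.

# Minimal sketch of matching on node.target: after export to an aten-level graph,
# every call_function node carries the OpOverload it dispatches to, regardless of
# whether it originated from nn.Linear, F.linear, or an op buried inside MHA/SDPA.
import torch
from torch.fx import Node


def is_aten_linear(node: Node) -> bool:
    # Hypothetical helper, for illustration only.
    return node.op == "call_function" and node.target == torch.ops.aten.linear.default

Because the check keys on the dispatched OpOverload, a `linear` produced by `nn.Linear`, `F.linear`, or one hidden inside a larger module is matched the same way, which is exactly the robustness argument made in the commit message.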

backends/cadence/aot/compiler.py

Lines changed: 2 additions & 2 deletions

@@ -19,7 +19,7 @@
 )
 from executorch.backends.cadence.aot.quantizer.fusion_pass import QuantFusion
 from executorch.backends.cadence.aot.quantizer.quantizer import (
-    CadenceGenericQuantizer,
+    CadenceAtenQuantizer,
     CadenceQuantizer,
 )
 from executorch.backends.cadence.aot.utils import model_is_quantized
@@ -58,7 +58,7 @@ def quantize_pt2(
 
     # Get patterns and apply fusion of dq -> op -> q to qop
     patterns = [
-        assert_is_instance(q, CadenceGenericQuantizer).pattern
+        assert_is_instance(q, CadenceAtenQuantizer).pattern
        for q in quantizer.quantizers
    ]
    QuantFusion(patterns)(converted_model)

backends/cadence/aot/quantizer/fusion_pass.py

Lines changed: 4 additions & 10 deletions

@@ -14,21 +14,19 @@
     BmmPattern,
     Conv1dPattern,
     Conv2dPattern,
-    LayerNormFunctionalPattern,
     LayerNormPattern,
-    LinearFunctionalPattern,
     LinearPattern,
     MatmulPattern,
     ReluPattern,
 )
 from executorch.backends.cadence.aot.quantizer.utils import (
     create_zero_bias_int32,
+    find_sequential_partitions_aten,
     get_conv_args,
     quantize_tensor_multiplier,
 )
 from executorch.exir.pass_base import ExportPass
 from torch import fx
-from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions
 from torch.fx import GraphModule
 from torch.fx.passes.infra.pass_base import PassResult
 from torch.fx.passes.utils.fuser_utils import legalize_graph
@@ -310,7 +308,7 @@ def __init__(self, patterns) -> None:
 
     def call(self, graph_module: fx.GraphModule) -> PassResult:  # noqa: C901
         for pattern in self.patterns:
-            fused_partitions = find_sequential_partitions(
+            fused_partitions = find_sequential_partitions_aten(
                 graph_module,
                 pattern.partition_types(),
             )
@@ -375,9 +373,7 @@ def call(self, graph_module: fx.GraphModule) -> PassResult:  # noqa: C901
                     quant_node,
                     op_node,
                 )
-            elif isinstance(pattern, LinearPattern) or isinstance(
-                pattern, LinearFunctionalPattern
-            ):
+            elif isinstance(pattern, LinearPattern):
                 args, kwargs = get_args_and_kwargs_linear(
                     graph_module,
                     inputs_inputs,
@@ -387,9 +383,7 @@ def call(self, graph_module: fx.GraphModule) -> PassResult:  # noqa: C901
                     bias_inputs,
                     quant_node,
                 )
-            elif isinstance(pattern, LayerNormPattern) or isinstance(
-                pattern, LayerNormFunctionalPattern
-            ):
+            elif isinstance(pattern, LayerNormPattern):
                 args, kwargs = get_args_and_kwargs_layer_norm(
                     graph_module,
                     inputs_inputs,
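`find_sequential_partitions_aten` replaces `find_sequential_partitions` above; its implementation lives in backends/cadence/aot/quantizer/utils.py, which this commit also changes but which is not shown on this page. As a hedged sketch only, not the real implementation, matching a short sequence of aten overloads directly on `node.target` could look roughly like this; `match_aten_op_chain` is a hypothetical name.

# Hypothetical sketch only: the real find_sequential_partitions_aten is defined in
# backends/cadence/aot/quantizer/utils.py. This toy version conveys the core idea:
# match a chain of aten OpOverloads by inspecting node.target and following
# single-user edges from one node to the next.
from typing import List, Optional

import torch
from torch.fx import GraphModule, Node


def match_aten_op_chain(
    gm: GraphModule, ops: List[torch._ops.OpOverload]
) -> List[List[Node]]:
    matches: List[List[Node]] = []
    for start in gm.graph.nodes:
        chain: List[Node] = []
        current: Optional[Node] = start
        for op in ops:
            if (
                current is None
                or current.op != "call_function"
                or current.target != op
            ):
                chain = []
                break
            chain.append(current)
            # Advance to the single user of this node, if there is exactly one.
            users = list(current.users)
            current = users[0] if len(users) == 1 else None
        if chain:
            matches.append(chain)
    return matches

For example, `match_aten_op_chain(graph_module, [torch.ops.aten.conv2d.default, torch.ops.aten.relu.default])` would return conv followed by relu chains, mirroring how `QuantFusion.call` passes `pattern.partition_types()` to the real helper.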

backends/cadence/aot/quantizer/patterns.py

Lines changed: 37 additions & 90 deletions

@@ -4,14 +4,17 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+# pyre-strict
+
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
-from typing import Any, Callable, List, Optional, Tuple, Type, Union
+from typing import List, Optional, Tuple, Union
 
 import torch
 from executorch.backends.cadence.aot.quantizer.utils import get_bias_qparams
 
 from torch import fx
+from torch._ops import OpOverload
 from torch.ao.quantization.quantizer import (
     DerivedQuantizationSpec,
     SharedQuantizationSpec,
@@ -44,18 +47,20 @@ class PartitionAnchors:
 
 class QuantizationPattern(ABC):
     @abstractmethod
-    def partition_types(self):
+    def partition_types(self) -> list[OpOverload]:
         """
-        List of types to be passed to find_sequential_partitions.
+        List of types to be passed to find_sequential_partitions_aten.
         """
         pass
 
     @abstractmethod
-    def get_anchors(self, gm, fused_partition) -> Optional[PartitionAnchors]:
+    def get_anchors(
+        self, gm: torch.fx.GraphModule, fused_partition: List[fx.GraphModule]
+    ) -> Optional[PartitionAnchors]:
         pass
 
     @abstractmethod
-    def replacement_op(self) -> Callable[..., Any]:
+    def replacement_op(self) -> OpOverload:
         """
         Operator (most likely a custom one) that this partition should be fused into in
         the backend. Refer to the QuantFusion pass for examples.
@@ -64,8 +69,8 @@ def replacement_op(self) -> Callable[..., Any]:
 
 
 class AddmmPattern(QuantizationPattern):
-    def partition_types(self) -> List[Type[torch.nn.Module]]:
-        return [torch.addmm]
+    def partition_types(self) -> List[OpOverload]:
+        return [torch.ops.aten.addmm.default]
 
     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
@@ -91,13 +96,13 @@ def get_anchors(
             output=[(addmm_node,)],
         )
 
-    def replacement_op(self):
+    def replacement_op(self) -> OpOverload:
         return torch.ops.cadence.quantized_linear
 
 
 class BmmPattern(QuantizationPattern):
-    def partition_types(self) -> List[Callable[..., torch.Tensor]]:
-        return [torch.bmm]
+    def partition_types(self) -> List[OpOverload]:
+        return [torch.ops.aten.bmm.default]
 
     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
@@ -111,13 +116,13 @@ def get_anchors(
             output=[(bmm_node,)],
         )
 
-    def replacement_op(self):
+    def replacement_op(self) -> OpOverload:
         return torch.ops.cadence.quantized_matmul.default
 
 
 class Conv1dPattern(QuantizationPattern):
-    def partition_types(self) -> List[Type[torch.nn.Module]]:
-        return [torch.nn.Conv1d]
+    def partition_types(self) -> List[OpOverload]:
+        return [torch.ops.aten.conv1d.default]
 
     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
@@ -149,13 +154,13 @@ def get_anchors(
             output=[(conv1d_node,)],
         )
 
-    def replacement_op(self):
+    def replacement_op(self) -> OpOverload:
         return torch.ops.cadence.quantized_conv.default
 
 
 class Conv2dPattern(QuantizationPattern):
-    def partition_types(self) -> List[Type[torch.nn.Module]]:
-        return [torch.nn.Conv2d]
+    def partition_types(self) -> List[OpOverload]:
+        return [torch.ops.aten.conv2d.default]
 
     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
@@ -187,37 +192,17 @@ def get_anchors(
             output=[(conv2d_node,)],
         )
 
-    def replacement_op(self):
+    def replacement_op(self) -> OpOverload:
         return torch.ops.cadence.quantized_conv.default
 
 
 class LayerNormPattern(QuantizationPattern):
-    def partition_types(self):
-        return [torch.nn.LayerNorm]
-
-    def get_anchors(self, gm, fused_partition) -> PartitionAnchors:
-        layer_norm_node = fused_partition[0].nodes[-1]
-
-        # Weights and biases are used as fp32 by our kernel, so they are
-        # passed in as others here along with the normalized shape.
-        return PartitionAnchors(
-            inputs=[(layer_norm_node, 0)],
-            weights=[],
-            biases=[],
-            # Ordering: normalized_shape, weights, bias
-            others=[(layer_norm_node, 1), (layer_norm_node, 2), (layer_norm_node, 3)],
-            output=[(layer_norm_node,)],
-        )
+    def partition_types(self) -> List[OpOverload]:
+        return [torch.ops.aten.layer_norm.default]
 
-    def replacement_op(self):
-        return torch.ops.cadence.quantized_layer_norm.default
-
-
-class LayerNormFunctionalPattern(QuantizationPattern):
-    def partition_types(self):
-        return [torch.nn.functional.layer_norm]
-
-    def get_anchors(self, gm, fused_partition) -> PartitionAnchors:
+    def get_anchors(
+        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
+    ) -> PartitionAnchors:
         layer_norm_node = fused_partition[0].nodes[-1]
 
         others = [(layer_norm_node, 1)]
@@ -241,13 +226,13 @@ def get_anchors(self, gm, fused_partition) -> PartitionAnchors:
             output=[(layer_norm_node,)],
         )
 
-    def replacement_op(self):
+    def replacement_op(self) -> OpOverload:
         return torch.ops.cadence.quantized_layer_norm.default
 
 
 class LinearPattern(QuantizationPattern):
-    def partition_types(self) -> List[Type[torch.nn.Module]]:
-        return [torch.nn.Linear]
+    def partition_types(self) -> List[OpOverload]:
+        return [torch.ops.aten.linear.default]
 
     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
@@ -279,51 +264,13 @@ def get_anchors(
             output=[(linear_node,)],
         )
 
-    def replacement_op(self):
-        return torch.ops.cadence.quantized_linear.default
-
-
-class LinearFunctionalPattern(QuantizationPattern):
-    def partition_types(self):
-        return [torch.nn.functional.linear]
-
-    def get_anchors(
-        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
-    ) -> PartitionAnchors:
-        linear_node = fused_partition[0].nodes[-1]
-
-        bias_qspec = DerivedQuantizationSpec(
-            derived_from=[
-                (linear_node.args[0], linear_node),
-                (linear_node.args[1], linear_node),
-            ],
-            derive_qparams_fn=get_bias_qparams,
-            dtype=torch.int32,
-            quant_min=-(2**31),
-            quant_max=2**31 - 1,
-            qscheme=torch.per_tensor_affine,
-        )
-
-        # Keep bias empty if not supplied
-        bias = []
-        if len(linear_node.args) > 2 and linear_node.args[2] is not None:
-            bias = [(linear_node, 2, bias_qspec)]
-
-        return PartitionAnchors(
-            inputs=[(linear_node, 0)],
-            weights=[(linear_node, 1)],
-            # pyre-fixme[6]: Incompatible parameter type
-            biases=bias,
-            output=[(linear_node,)],
-        )
-
-    def replacement_op(self):
+    def replacement_op(self) -> OpOverload:
         return torch.ops.cadence.quantized_linear.default
 
 
 class MatmulPattern(QuantizationPattern):
-    def partition_types(self):
-        return [torch.matmul]
+    def partition_types(self) -> List[OpOverload]:
+        return [torch.ops.aten.matmul.default]
 
     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
@@ -337,13 +284,13 @@ def get_anchors(
             output=[(matmul_node,)],
         )
 
-    def replacement_op(self):
+    def replacement_op(self) -> OpOverload:
         return torch.ops.cadence.quantized_matmul.default
 
 
 class ReluPattern(QuantizationPattern):
-    def partition_types(self) -> List[Type[torch.nn.Module]]:
-        return [torch.nn.ReLU]
+    def partition_types(self) -> List[OpOverload]:
+        return [torch.ops.aten.relu.default]
 
     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
@@ -359,5 +306,5 @@ def get_anchors(
             ],
         )
 
-    def replacement_op(self):
+    def replacement_op(self) -> OpOverload:
         return torch.ops.cadence.quantized_relu.default

backends/cadence/aot/quantizer/quantizer.py

Lines changed: 17 additions & 18 deletions

@@ -4,30 +4,29 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import List
+from typing import List, Optional
 
 import torch
 from executorch.backends.cadence.aot.quantizer.patterns import (
     AddmmPattern,
     BmmPattern,
     Conv1dPattern,
     Conv2dPattern,
-    LayerNormFunctionalPattern,
     LayerNormPattern,
-    LinearFunctionalPattern,
     LinearPattern,
     MatmulPattern,
+    QuantizationPattern,
     ReluPattern,
 )
 from executorch.backends.cadence.aot.quantizer.utils import (
+    find_sequential_partitions_aten,
     is_annotated,
     no_outside_users,
 )
 
 from torch import fx
 
 from torch.ao.quantization.observer import HistogramObserver, MinMaxObserver
-from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions
 from torch.ao.quantization.quantizer import Quantizer
 from torch.ao.quantization.quantizer.composable_quantizer import ComposableQuantizer
 from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import (
@@ -56,17 +55,19 @@
     observer_or_fake_quant_ctr=MinMaxObserver,
 )
 
-bias_qspec = None
+bias_qspec: Optional[QuantizationSpec] = None
 
 
-class CadenceGenericQuantizer(Quantizer):
-    def __init__(self, pattern, quantization_config):
+class CadenceAtenQuantizer(Quantizer):
+    def __init__(
+        self, pattern: QuantizationPattern, quantization_config: QuantizationConfig
+    ):
         super().__init__()
         self.pattern = pattern
         self.quantization_config = quantization_config
 
     def annotate(self, model):
-        fused_partitions = find_sequential_partitions(
+        fused_partitions = find_sequential_partitions_aten(
             model,
             self.pattern.partition_types(),
         )
@@ -133,15 +134,13 @@ def __init__(self):
         )
         super().__init__(
             [
-                CadenceGenericQuantizer(AddmmPattern(), static_qconfig),
-                CadenceGenericQuantizer(BmmPattern(), static_qconfig),
-                CadenceGenericQuantizer(Conv1dPattern(), static_qconfig),
-                CadenceGenericQuantizer(Conv2dPattern(), static_qconfig),
-                CadenceGenericQuantizer(LayerNormPattern(), static_qconfig),
-                CadenceGenericQuantizer(LayerNormFunctionalPattern(), static_qconfig),
-                CadenceGenericQuantizer(LinearPattern(), static_qconfig),
-                CadenceGenericQuantizer(LinearFunctionalPattern(), static_qconfig),
-                CadenceGenericQuantizer(MatmulPattern(), static_qconfig),
-                CadenceGenericQuantizer(ReluPattern(), static_qconfig),
+                CadenceAtenQuantizer(AddmmPattern(), static_qconfig),
+                CadenceAtenQuantizer(BmmPattern(), static_qconfig),
+                CadenceAtenQuantizer(Conv1dPattern(), static_qconfig),
+                CadenceAtenQuantizer(Conv2dPattern(), static_qconfig),
+                CadenceAtenQuantizer(LayerNormPattern(), static_qconfig),
+                CadenceAtenQuantizer(LinearPattern(), static_qconfig),
+                CadenceAtenQuantizer(MatmulPattern(), static_qconfig),
+                CadenceAtenQuantizer(ReluPattern(), static_qconfig),
             ]
         )
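For context, here is a minimal usage sketch of the renamed quantizer, assuming the standard pt2e prepare/convert flow; the toy model, example inputs, and the capture call are placeholders, and the exact entry points may differ across PyTorch and ExecuTorch versions.

# Hedged usage sketch (not taken from this commit): apply CadenceQuantizer, which
# composes one CadenceAtenQuantizer per pattern, through the pt2e flow.
import torch
from executorch.backends.cadence.aot.quantizer.quantizer import CadenceQuantizer
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU()).eval()
example_inputs = (torch.randn(1, 16),)

# Capture an aten-level graph; the recommended capture API varies across versions.
exported = torch.export.export_for_training(model, example_inputs).module()

quantizer = CadenceQuantizer()
prepared = prepare_pt2e(exported, quantizer)
prepared(*example_inputs)  # calibration with representative inputs
converted = convert_pt2e(prepared)

After conversion, the `QuantFusion` pass shown in compiler.py above fuses the resulting dq -> op -> q chains into the quantized Cadence ops.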
