
Commit f680897

mcremon-meta authored and facebook-github-bot committed
Migrate the quantizer to use aten ops directly (#4195)
Summary: This major change allows much more flexibility in the quantizer and reduces its dependency on the decomposition/graph-tracing tools. The motivation is that some of those tools do not preserve or propagate `source_fn_stack` information, resulting in quantization misses. SDPA is one example, where the underlying `bmm` ops cannot be quantized with `source_fn_stack` information alone; MHA is another, since it can hide its SDPA component and sometimes even its `linear` ops depending on the model (see ViT for an example). Also note that in most cases we match single nodes anyway, with a 1-1 mapping between the op (either nn.Module or nn.functional) and the aten op, so using the aten op directly is simply easier.

Summary of the changes:
- change the quantizer to match aten ops directly, through `node.target`
- propagate the required changes to the `QuantFusion` pass
- update/remove existing patterns

Reviewed By: dulinriley

Differential Revision: D59552606
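As a minimal illustration of what matching "aten ops directly, through `node.target`" means, here is a hedged sketch; the `aten.linear` op and the `matches` helper are examples chosen for this note, not code from the diff:

import torch
from torch.fx import Node

# Assumed example: a pattern now names the aten overloads it covers.
WANTED_ATEN_OPS = [torch.ops.aten.linear.default]

def matches(node: Node) -> bool:
    # In an exported aten-level graph, call_function nodes carry the aten
    # OpOverload in node.target, so no source_fn_stack metadata is needed.
    return node.op == "call_function" and node.target in WANTED_ATEN_OPS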
1 parent aa50879 commit f680897

File tree

5 files changed: +127 −29 lines changed


backends/cadence/aot/compiler.py

Lines changed: 2 additions & 2 deletions
@@ -19,7 +19,7 @@
 )
 from executorch.backends.cadence.aot.quantizer.fusion_pass import QuantFusion
 from executorch.backends.cadence.aot.quantizer.quantizer import (
-    CadenceGenericQuantizer,
+    CadenceAtenQuantizer,
     CadenceQuantizer,
 )
 from executorch.backends.cadence.aot.utils import model_is_quantized
@@ -58,7 +58,7 @@ def quantize_pt2(

     # Get patterns and apply fusion of dq -> op -> q to qop
     patterns = [
-        assert_is_instance(q, CadenceGenericQuantizer).pattern
+        assert_is_instance(q, CadenceAtenQuantizer).pattern
         for q in quantizer.quantizers
     ]
     QuantFusion(patterns)(converted_model)

backends/cadence/aot/quantizer/fusion_pass.py

Lines changed: 5 additions & 10 deletions
@@ -14,21 +14,19 @@
     BmmPattern,
     Conv1dPattern,
     Conv2dPattern,
-    LayerNormFunctionalPattern,
     LayerNormPattern,
-    LinearFunctionalPattern,
     LinearPattern,
     MatmulPattern,
     ReluPattern,
 )
 from executorch.backends.cadence.aot.quantizer.utils import (
     create_zero_bias_int32,
+    find_sequential_partitions_aten,
     get_conv_args,
     quantize_tensor_multiplier,
 )
 from executorch.exir.pass_base import ExportPass
 from torch import fx
-from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions
 from torch.fx import GraphModule
 from torch.fx.passes.infra.pass_base import PassResult
 from torch.fx.passes.utils.fuser_utils import legalize_graph
@@ -310,14 +308,15 @@ def __init__(self, patterns) -> None:

     def call(self, graph_module: fx.GraphModule) -> PassResult:  # noqa: C901
         for pattern in self.patterns:
-            fused_partitions = find_sequential_partitions(
+            fused_partitions = find_sequential_partitions_aten(
                 graph_module,
                 pattern.partition_types(),
             )
             for fused_partition in fused_partitions:
                 anchors = pattern.get_anchors(graph_module, fused_partition)
                 if not anchors:
                     continue
+                # pyre-ignore[16]: Undefined attribute
                 if any(self.is_fused(p.nodes) for p in fused_partition):
                     continue

@@ -373,9 +372,7 @@ def call(self, graph_module: fx.GraphModule) -> PassResult:  # noqa: C901
                         quant_node,
                         op_node,
                     )
-            elif isinstance(pattern, LinearPattern) or isinstance(
-                pattern, LinearFunctionalPattern
-            ):
+            elif isinstance(pattern, LinearPattern):
                 args, kwargs = get_args_and_kwargs_linear(
                     graph_module,
                     inputs_inputs,
@@ -385,9 +382,7 @@ def call(self, graph_module: fx.GraphModule) -> PassResult:  # noqa: C901
                     bias_inputs,
                     quant_node,
                 )
-            elif isinstance(pattern, LayerNormPattern) or isinstance(
-                pattern, LayerNormFunctionalPattern
-            ):
+            elif isinstance(pattern, LayerNormPattern):
                 args, kwargs = get_args_and_kwargs_layer_norm(
                     graph_module,
                     inputs_inputs,

backends/cadence/aot/quantizer/quantizer.py

Lines changed: 11 additions & 15 deletions
@@ -14,15 +14,14 @@
     BmmPattern,
     Conv1dPattern,
     Conv2dPattern,
-    LayerNormFunctionalPattern,
     LayerNormPattern,
-    LinearFunctionalPattern,
     LinearPattern,
     MatmulPattern,
     QuantizationPattern,
     ReluPattern,
 )
 from executorch.backends.cadence.aot.quantizer.utils import (
+    find_sequential_partitions_aten,
     is_annotated,
     no_outside_users,
 )
@@ -31,7 +30,6 @@
 from torch import fx

 from torch.ao.quantization.observer import HistogramObserver, MinMaxObserver
-from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions
 from torch.ao.quantization.quantizer import DerivedQuantizationSpec, Quantizer
 from torch.ao.quantization.quantizer.composable_quantizer import ComposableQuantizer
 from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import (
@@ -63,7 +61,7 @@
     bias_qspec: Optional[QuantizationSpec] = None


-class CadenceGenericQuantizer(Quantizer):
+class CadenceAtenQuantizer(Quantizer):
     def __init__(
         self, pattern: QuantizationPattern, quantization_config: QuantizationConfig
     ) -> None:
@@ -72,7 +70,7 @@ def __init__(
         self.quantization_config = quantization_config

     def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
-        fused_partitions = find_sequential_partitions(
+        fused_partitions = find_sequential_partitions_aten(
             model,
             self.pattern.partition_types(),
         )
@@ -154,15 +152,13 @@ def __init__(self) -> None:
         )
         super().__init__(
             [
-                CadenceGenericQuantizer(AddmmPattern(), static_qconfig),
-                CadenceGenericQuantizer(BmmPattern(), static_qconfig),
-                CadenceGenericQuantizer(Conv1dPattern(), static_qconfig),
-                CadenceGenericQuantizer(Conv2dPattern(), static_qconfig),
-                CadenceGenericQuantizer(LayerNormPattern(), static_qconfig),
-                CadenceGenericQuantizer(LayerNormFunctionalPattern(), static_qconfig),
-                CadenceGenericQuantizer(LinearPattern(), static_qconfig),
-                CadenceGenericQuantizer(LinearFunctionalPattern(), static_qconfig),
-                CadenceGenericQuantizer(MatmulPattern(), static_qconfig),
-                CadenceGenericQuantizer(ReluPattern(), static_qconfig),
+                CadenceAtenQuantizer(AddmmPattern(), static_qconfig),
+                CadenceAtenQuantizer(BmmPattern(), static_qconfig),
+                CadenceAtenQuantizer(Conv1dPattern(), static_qconfig),
+                CadenceAtenQuantizer(Conv2dPattern(), static_qconfig),
+                CadenceAtenQuantizer(LayerNormPattern(), static_qconfig),
+                CadenceAtenQuantizer(LinearPattern(), static_qconfig),
+                CadenceAtenQuantizer(MatmulPattern(), static_qconfig),
+                CadenceAtenQuantizer(ReluPattern(), static_qconfig),
             ]
         )
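For context, a hedged usage sketch of the renamed quantizer in the standard pt2e flow; the toy model, the calibration input, and the use of torch.export for graph capture are assumptions for illustration, not part of this diff:

import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

from executorch.backends.cadence.aot.quantizer.quantizer import CadenceQuantizer

# Placeholder model and input, for illustration only.
model = torch.nn.Linear(16, 16).eval()
inputs = (torch.randn(1, 16),)

# Capture an aten-level graph so node.target holds aten overloads
# (graph-capture API assumed here; the Cadence flow may use its own entry point).
exported = torch.export.export(model, inputs).module()

quantizer = CadenceQuantizer()
prepared = prepare_pt2e(exported, quantizer)
prepared(*inputs)  # calibrate observers on representative data
converted = convert_pt2e(prepared)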

backends/cadence/aot/quantizer/utils.py

Lines changed: 106 additions & 1 deletion
@@ -4,14 +4,21 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+import itertools
+from collections import OrderedDict
 from math import frexp, isclose, trunc
-from typing import List, Tuple
+from typing import Any, Dict, List, Tuple, Type

 import torch
 from torch import fx
+from torch._ops import OpOverload
 from torch.ao.quantization import ObserverOrFakeQuantize

 from torch.fx import GraphModule
+from torch.fx.passes.utils.source_matcher_utils import (
+    check_subgraphs_connected,
+    SourcePartition,
+)


 def quantize_tensor_multiplier(
@@ -127,3 +134,101 @@ def get_bias_qparams(

 def get_conv_args(arg, first_val: int) -> List[fx.Node]:
     return arg if len(arg) == 2 else [first_val, arg[0]]
+
+
+def get_aten_node_target_partitions(
+    graph: torch.fx.Graph,
+    wanted_original_aten_op: List[OpOverload],
+) -> Dict[Any, List[SourcePartition]]:
+    """
+    Args:
+        graph: The graph we want to partition
+        wanted_original_aten_op: List of original aten ops (OpOverload)
+
+    Returns:
+        Dictionary mapping aten ops that were given to a list of SourcePartitions
+        that correspond to the list of nodes that were decomposed from the given
+        aten ops.
+    """
+    modules: Dict[Type, Dict[str, List[torch.fx.Node]]] = {}
+
+    for node in graph.nodes:
+        # The metadata source_fn should contain a tuple of a unique name for the
+        # source, and the source function if the node is decomposed from a
+        # function, or the type of module if the node is decomposed from a leaf
+        # module
+        # TODO(matthiascremon): look into ways to avoid using source_fn_stack
+        if (source_fn_st := node.meta.get("source_fn_stack")) is None:
+            continue
+
+        source_fn = source_fn_st[-1]
+        if node.target not in wanted_original_aten_op:
+            continue
+
+        diff_modules = modules.setdefault(source_fn[1], {})
+        partition = diff_modules.setdefault(node.name, [])
+        partition.append(node)
+
+    def make_partition(
+        nodes: List[torch.fx.Node], module_type: Type
+    ) -> SourcePartition:
+        input_nodes = set()
+        output_nodes = set()
+        params = set()
+        for node in nodes:
+            for arg in node.args:
+                if isinstance(arg, torch.fx.Node) and arg not in nodes:
+                    input_nodes.add(arg)
+
+            if node.op == "get_attr":
+                params.add(node)
+
+            for user in node.users.keys():
+                if user not in nodes:
+                    output_nodes.add(node)
+
+        return SourcePartition(
+            nodes,
+            module_type,
+            list(input_nodes),
+            list(output_nodes),
+            list(params),  # type: ignore[arg-type]
+        )
+
+    ret: Dict[Type[Any], List[SourcePartition]] = {}
+
+    for k, v in modules.items():
+        ret[k] = [make_partition(partition, k) for partition in v.values()]
+
+    return ret
+
+
+def _partitions_sequential(partitions: Tuple[SourcePartition]) -> bool:
+    prev_partition = None
+    for partition in partitions:
+        if prev_partition is not None and not check_subgraphs_connected(
+            prev_partition, partition
+        ):
+            return False
+        prev_partition = partition
+    return True
+
+
+def find_sequential_partitions_aten(
+    gm: torch.fx.GraphModule,
+    partition_types: List[Any],
+) -> List[SourcePartition]:
+    typed_partitions: OrderedDict[Any, List[SourcePartition]] = OrderedDict()
+    for partition_type in partition_types:
+        partitions = get_aten_node_target_partitions(gm.graph, [partition_type])
+        typed_partitions[partition_type] = list(
+            itertools.chain.from_iterable(partitions.values())
+        )
+
+    typed_partitions_list = list(typed_partitions.values())
+    fusion_candidates = itertools.product(*typed_partitions_list)
+    fused_partitions = []
+    for candidate in fusion_candidates:
+        if _partitions_sequential(candidate):
+            fused_partitions.append(candidate)
+    return fused_partitions
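A short usage sketch for the new helper; the conv/relu op pair is an assumed example (in this diff the call sites pass pattern.partition_types() instead):

import torch
from executorch.backends.cadence.aot.quantizer.utils import (
    find_sequential_partitions_aten,
)

def find_conv_relu_chains(gm: torch.fx.GraphModule):
    # Returns a list of tuples of SourcePartitions, one partition per
    # requested aten op, where consecutive partitions are connected.
    # Note: nodes still need source_fn_stack metadata to be grouped
    # (see the TODO in get_aten_node_target_partitions above).
    return find_sequential_partitions_aten(
        gm,
        [torch.ops.aten.convolution.default, torch.ops.aten.relu.default],
    )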

backends/cadence/aot/utils.py

Lines changed: 3 additions & 1 deletion
@@ -4,6 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+# pyre-strict
+
 import logging
 import operator
 from typing import Dict, List, Tuple
@@ -116,7 +118,7 @@ def get_ops_count(graph_module: torch.fx.GraphModule) -> Dict[str, int]:
 def print_ops_info(
     to_edge_gm: torch.fx.GraphModule,
     jarvis_gm: torch.fx.GraphModule,
-):
+) -> None:
     to_edge_ops_count = get_ops_count(to_edge_gm)
     jarvis_ops_count = get_ops_count(jarvis_gm)
