
Commit 81c7522

mcremon-meta authored and facebook-github-bot committed
Run decompositions before the quantizer (#7111)
Summary: In the current flow, decompositions run in `to_edge()`, long after the quantization process is done. This creates a lot of issues: we cannot quantize operations that are hidden inside the large composite operators the graph tracer emits (e.g. aten.scaled_dot_product_attention, aten.rnn_<tanh, relu>.input, and a few others), so any model using them ends up with many fp32 operators in the final graph. Running the decompositions earlier solves the problem, but we need to retain a couple of operators that the quantizer relies on, such as `aten.linear`, `aten.conv1d` and `aten.conv2d`.

Reviewed By: zonglinpeng

Differential Revision: D66461406
1 parent 2d499b3 commit 81c7522
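
For context, a condensed sketch of the ordering this commit introduces, adapted from the diff below. The `quantizer` argument and the one-batch calibration call are illustrative placeholders and not part of this change; the export and pt2e calls are the same ones used in compiler.py.

    import torch
    from torch._inductor.decomposition import remove_decompositions
    from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

    def quantize_with_early_decomps(model, inputs, quantizer):
        # Decompose large composite ops (e.g. SDPA, RNN variants) *before*
        # quantization, but keep the ops the quantizer pattern-matches on.
        decomp_table = torch.export.default_decompositions()
        ops_to_keep = [
            torch.ops.aten.conv1d.default,
            torch.ops.aten.conv2d.default,
            torch.ops.aten.layer_norm.default,
            torch.ops.aten.linear.default,
            torch.ops.aten.matmul.default,
        ]
        remove_decompositions(decomp_table, ops_to_keep)

        model_gm = (
            torch.export.export_for_training(model, inputs)
            .run_decompositions(decomp_table)
            .module()
        )

        # Standard pt2e flow: the quantizer now sees the decomposed graph, so
        # ops that used to hide inside large composites can be annotated too.
        prepared = prepare_pt2e(model_gm, quantizer)
        prepared(*inputs)  # calibration (a single batch here, for illustration)
        return convert_pt2e(prepared)
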


backends/cadence/aot/compiler.py

Lines changed: 19 additions & 4 deletions
@@ -28,6 +28,7 @@
     to_edge,
 )
 from executorch.exir.pass_base import PassResult
+from torch._inductor.decomposition import remove_decompositions
 from torch.ao.quantization.pt2e.export_utils import model_is_exported
 from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
 
@@ -59,15 +60,29 @@ def convert_pt2(
     """
 
     # Export with dynamo
-    model_gm = torch.export.export_for_training(model, inputs).module()
+    decomp_table = torch.export.default_decompositions()
+    ops_to_keep = [
+        torch.ops.aten.conv1d.default,
+        torch.ops.aten.conv2d.default,
+        torch.ops.aten.layer_norm.default,
+        torch.ops.aten.linear.default,
+        torch.ops.aten.matmul.default,
+    ]
+    # pyre-fixme[6]: For 1st argument expected `Dict[typing.Callable[..., typing.Any
+    remove_decompositions(decomp_table, ops_to_keep)
+    model_gm = (
+        torch.export.export_for_training(model, inputs)
+        .run_decompositions(decomp_table)
+        .module()
+    )
 
-    if model_gm_has_SDPA(model_gm):  # pyre-fixme[6]
+    if model_gm_has_SDPA(model_gm):
         # Decompose SDPA
-        DecomposeScaledDotProductAttention(False)(model_gm)  # pyre-fixme[6]
+        DecomposeScaledDotProductAttention(False)(model_gm)
 
     # Swap _safe_softmax with _softmax (see https://github.com/pytorch/pytorch/pull/133882
     # for details).
-    result = ReplaceSafeSoftmaxWithSoftmax()(model_gm)  # pyre-fixme[6]
+    result = ReplaceSafeSoftmaxWithSoftmax()(model_gm)
     assert result is not None
     model_gm = result.graph_module