Commit a11afab

mcremon-meta authored and facebook-github-bot committed
Run decompositions before the quantizer (#7111)
Summary: In the current flow, decompositions run in `to_edge()`, long after the quantization process is done. This creates a lot of issues, since we cannot quantize operations that are still hidden inside the large composite operators the graph tracer can emit (e.g. aten.scaled_dot_product_attention, aten.rnn_<tanh, relu>.input, and a few others). Any model using those will end up with many fp32 operators in the final graph. Running the decompositions earlier solves the problem, but we need to retain a few operators that the quantizer relies on, such as `aten.linear`, `aten.conv1d` and `aten.conv2d`.

Reviewed By: zonglinpeng

Differential Revision: D66461406
1 parent 2d499b3 commit a11afab
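
As a quick illustration of the pattern this commit applies (a minimal sketch, not the ExecuTorch source itself; the helper name `export_with_selected_decomps` is made up for the example): pull the ops the quantizer needs to see out of the default decomposition table, then run the remaining decompositions immediately after export instead of waiting for `to_edge()`.

    import torch
    from torch._inductor.decomposition import remove_decompositions

    def export_with_selected_decomps(model: torch.nn.Module, inputs: tuple):
        # Start from the full default decomposition table.
        decomp_table = torch.export.default_decompositions()
        # Ops the quantizer needs to see as single nodes (subset shown here).
        ops_to_keep = [
            torch.ops.aten.conv1d.default,
            torch.ops.aten.conv2d.default,
            torch.ops.aten.linear.default,
        ]
        # Drop their entries so run_decompositions() leaves them intact.
        remove_decompositions(decomp_table, ops_to_keep)
        # Export, decompose everything else, and hand the module to the quantizer.
        return (
            torch.export.export_for_training(model, inputs)
            .run_decompositions(decomp_table)
            .module()
        )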

File tree

1 file changed: +22 additions, −4 deletions


backends/cadence/aot/compiler.py

Lines changed: 22 additions & 4 deletions
@@ -28,6 +28,7 @@
     to_edge,
 )
 from executorch.exir.pass_base import PassResult
+from torch._inductor.decomposition import remove_decompositions
 from torch.ao.quantization.pt2e.export_utils import model_is_exported
 from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

@@ -58,16 +59,33 @@ def convert_pt2(
     Returns a GraphModule with the converted model.
     """

+    # Get default decompositions
+    decomp_table = torch.export.default_decompositions()
+    # Select ops to keep
+    ops_to_keep = [
+        torch.ops.aten.conv1d.default,
+        torch.ops.aten.conv2d.default,
+        torch.ops.aten.layer_norm.default,
+        torch.ops.aten.linear.default,
+        torch.ops.aten.matmul.default,
+    ]
+    # Remove decompositions for the ops we want to keep
+    # pyre-fixme[6]: For 1st argument expected `Dict[typing.Callable[..., typing.Any
+    remove_decompositions(decomp_table, ops_to_keep)
     # Export with dynamo
-    model_gm = torch.export.export_for_training(model, inputs).module()
+    model_gm = (
+        torch.export.export_for_training(model, inputs)
+        .run_decompositions(decomp_table)
+        .module()
+    )

-    if model_gm_has_SDPA(model_gm):  # pyre-fixme[6]
+    if model_gm_has_SDPA(model_gm):
         # Decompose SDPA
-        DecomposeScaledDotProductAttention(False)(model_gm)  # pyre-fixme[6]
+        DecomposeScaledDotProductAttention(False)(model_gm)

         # Swap _safe_softmax with _softmax (see https://github.com/pytorch/pytorch/pull/133882
         # for details).
-        result = ReplaceSafeSoftmaxWithSoftmax()(model_gm)  # pyre-fixme[6]
+        result = ReplaceSafeSoftmaxWithSoftmax()(model_gm)
         assert result is not None
         model_gm = result.graph_module

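For context, this is roughly how the decomposed module then flows through the pt2e quantization entry points imported in this file (a hedged sketch; the quantizer object and calibration loop are placeholders, not the Cadence-specific setup):

    from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

    def quantize_decomposed_module(model_gm, quantizer, calibration_inputs):
        # Insert observers/fake-quant according to the quantizer's annotations.
        prepared = prepare_pt2e(model_gm, quantizer)
        # Calibrate on representative inputs so the observers record ranges.
        for sample in calibration_inputs:
            prepared(*sample)
        # Fold observers into quantize/dequantize nodes.
        return convert_pt2e(prepared)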