
Commit d7c069f

Fix SDPA decomp problem
Differential Revision: D61639074
Pull Request resolved: #4851

Parent: bf64819

3 files changed (+47, -5 lines)

backends/cadence/aot/compiler.py

Lines changed: 13 additions & 5 deletions
@@ -18,12 +18,13 @@
     ReplaceLogicalNotBooleanWhereWithWherePass,
     ReplacePT2DequantWithCadenceDequantPass,
     ReplacePT2QuantWithCadenceQuantPass,
+    ReplaceSafeSoftmaxWithSoftmax,
     ReplaceScalarTensorWithFullPass,
     ReplaceSqueezeAndUnsqueezeWithViewPass,
 )
 from executorch.backends.cadence.aot.quantizer.fusion_pass import QuantFusion
 from executorch.backends.cadence.aot.quantizer.quantizer import CadenceQuantizer
-from executorch.backends.cadence.aot.utils import model_is_quantized
+from executorch.backends.cadence.aot.utils import model_gm_has_SDPA, model_is_quantized
 from executorch.backends.transforms.decompose_sdpa import (
     DecomposeScaledDotProductAttention,
 )
@@ -57,13 +58,20 @@ def convert_pt2(
     """
 
     # Export with dynamo
-    model_exp = capture_pre_autograd_graph(model, inputs)
+    model_gm = capture_pre_autograd_graph(model, inputs)
 
-    # Decompose SDPA
-    DecomposeScaledDotProductAttention(False)(model_exp)
+    if model_gm_has_SDPA(model_gm):
+        # Decompose SDPA
+        DecomposeScaledDotProductAttention(False)(model_gm)
+
+        # Swap _safe_softmax with _softmax (see https://github.com/pytorch/pytorch/pull/133882
+        # for details).
+        result = ReplaceSafeSoftmaxWithSoftmax()(model_gm)
+        assert result is not None
+        model_gm = result.graph_module
 
     # Prepare
-    prepared_model = prepare_pt2e(model_exp, quantizer)
+    prepared_model = prepare_pt2e(model_gm, quantizer)
 
     # Calibrate
     prepared_model(*inputs)
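
For context, a hedged usage sketch of the reworked convert_pt2 path. TinyAttention and the input shapes are illustrative, not from this commit, and CadenceQuantizer's no-argument constructor is an assumption:

    import torch
    import torch.nn.functional as F

    from executorch.backends.cadence.aot.compiler import convert_pt2
    from executorch.backends.cadence.aot.quantizer.quantizer import CadenceQuantizer

    # Hypothetical module whose captured graph contains an SDPA node.
    class TinyAttention(torch.nn.Module):
        def forward(self, q, k, v):
            return F.scaled_dot_product_attention(q, k, v)

    q = k = v = torch.randn(1, 4, 8, 16)
    # With this fix, convert_pt2 decomposes SDPA (and swaps _safe_softmax
    # for _softmax) only when the captured graph actually contains an
    # aten.scaled_dot_product_attention call.
    converted_gm = convert_pt2(TinyAttention(), (q, k, v), CadenceQuantizer())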

backends/cadence/aot/passes.py

Lines changed: 26 additions & 0 deletions
@@ -266,3 +266,29 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
         result = SpecPropPass()(graph_module)
         assert result is not None
         return result
+
+
+class ReplaceSafeSoftmaxWithSoftmax(ExportPass):
+    """
+    Replace _safe_softmax with _softmax
+    """
+
+    def call_operator(
+        self,
+        op,  # pyre-ignore
+        args: tuple[Argument, ...],
+        kwargs: dict[str, Argument],
+        meta: NodeMetadata,
+    ) -> ProxyValue:
+        if op != torch.ops.aten._safe_softmax.default:
+            return super().call_operator(op, args, kwargs, meta)
+
+        # Add False for the half_to_float argument of softmax
+        softmax_args = list(args) + [False]
+
+        return super().call_operator(
+            torch.ops.aten._softmax.default,
+            tuple(softmax_args),
+            kwargs,
+            meta,
+        )
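
The swap is numerically sound for ordinary inputs: aten._safe_softmax differs from aten._softmax only on fully masked rows (all -inf), where it returns zeros instead of NaNs (see the PyTorch PR linked above). A quick equivalence check, a sketch assuming a PyTorch build recent enough to expose aten._safe_softmax:

    import torch

    x = torch.randn(2, 5)
    safe = torch.ops.aten._safe_softmax.default(x, -1)
    # The pass appends False for _softmax's half_to_float argument.
    plain = torch.ops.aten._softmax.default(x, -1, False)
    assert torch.allclose(safe, plain)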

backends/cadence/aot/utils.py

Lines changed: 8 additions & 0 deletions
@@ -177,3 +177,11 @@ def print_ops_info(
             tablefmt="outline",
         )
     )
+
+
+def model_gm_has_SDPA(model_gm: torch.fx.GraphModule) -> bool:
+    for node in model_gm.graph.nodes:
+        if node.op == "call_function":
+            if node.target == torch.ops.aten.scaled_dot_product_attention.default:
+                return True
+    return False
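
A sketch of exercising the new helper. torch.export.export is used here as a stand-in for the capture step in convert_pt2, on the assumption that SDPA is still a single aten.scaled_dot_product_attention node in the pre-decomposition graph; the module is illustrative:

    import torch
    import torch.nn.functional as F

    from executorch.backends.cadence.aot.utils import model_gm_has_SDPA

    class WithSDPA(torch.nn.Module):
        def forward(self, q, k, v):
            return F.scaled_dot_product_attention(q, k, v)

    q = k = v = torch.randn(1, 2, 4, 8)
    gm = torch.export.export(WithSDPA(), (q, k, v)).module()
    print(model_gm_has_SDPA(gm))  # expected: True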
