Skip to content

Commit ae4d8d5

Browse files
mcremon-metafacebook-github-bot
authored andcommitted
Move SDPA decomp pass from Qualcomm's directory to be shareable and call it for Cadence backends (#4258)
Summary: Pull Request resolved: #4258 Moving the pass to backends/transforms so that other backends can call it, and call it from the Cadence side so that we can quantize the bmm ops in SDPA. Reviewed By: cccclai Differential Revision: D59600486 fbshipit-source-id: c5e3209be1d3af89903ab31c3c13d2b0e9a23925
1 parent 6903715 commit ae4d8d5

File tree

5 files changed

+38
-7
lines changed

5 files changed

+38
-7
lines changed

backends/cadence/aot/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ python_library(
3434
"//caffe2:torch",
3535
"//executorch/backends/cadence/aot/quantizer:fusion_pass",
3636
"//executorch/backends/cadence/aot/quantizer:quantizer",
37+
"//executorch/backends/transforms:decompose_sdpa",
3738
"//executorch/exir:lib",
3839
],
3940
)

backends/cadence/aot/compiler.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@
2323
CadenceQuantizer,
2424
)
2525
from executorch.backends.cadence.aot.utils import model_is_quantized
26+
from executorch.backends.transforms.decompose_sdpa import (
27+
DecomposeScaledDotProductAttention,
28+
)
2629
from executorch.exir import EdgeCompileConfig, EdgeProgramManager, to_edge
2730
from pyre_extensions import assert_is_instance
2831
from torch._export import capture_pre_autograd_graph
@@ -47,6 +50,9 @@ def quantize_pt2(
4750
# Export with dynamo
4851
model_exp = capture_pre_autograd_graph(model, inputs)
4952

53+
# Decompose SDPA
54+
DecomposeScaledDotProductAttention(False)(model_exp)
55+
5056
# Prepare
5157
prepared_model = prepare_pt2e(model_exp, quantizer)
5258

backends/qualcomm/quantizer/quantizer.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,16 @@
77
from typing import Callable, Dict, Optional, Sequence, Set
88

99
import torch
10-
from executorch.backends.qualcomm.passes.decompose_scaled_dot_product_attention import (
11-
DecomposeScaledDotProductAttention,
12-
)
1310
from executorch.backends.qualcomm.passes.decompose_silu import DecomposeSilu
1411
from executorch.backends.qualcomm.passes.recompose_pixel_unshuffle import (
1512
RecomposePixelUnshuffle,
1613
)
1714
from executorch.backends.qualcomm.passes.reduce_dynamic_range import ReduceDynamicRange
1815
from executorch.backends.qualcomm.passes.remove_redundancy import RemoveRedundancy
1916
from executorch.backends.qualcomm.passes.replace_inf_buffer import ReplaceInfBuffer
17+
from executorch.backends.transforms.decompose_sdpa import (
18+
DecomposeScaledDotProductAttention,
19+
)
2020

2121
from torch._ops import OpOverload
2222
from torch.ao.quantization.quantizer import Quantizer

backends/transforms/TARGETS

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,19 @@ runtime.python_library(
2929
],
3030
)
3131

32+
runtime.python_library(
33+
name = "decompose_sdpa",
34+
srcs = ["decompose_sdpa.py"],
35+
visibility = [
36+
"//executorch/backends/...",
37+
"@EXECUTORCH_CLIENTS",
38+
],
39+
deps = [
40+
"//caffe2:torch",
41+
"//executorch/exir:pass_base",
42+
],
43+
)
44+
3245
runtime.python_library(
3346
name = "fuse_batch_norm_with_conv",
3447
srcs = ["fuse_batch_norm_with_conv.py"],

backends/qualcomm/passes/decompose_scaled_dot_product_attention.py renamed to backends/transforms/decompose_sdpa.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
1-
# Copyright (c) Qualcomm Innovation Center, Inc.
2-
# All rights reserved
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
33
#
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-strict
8+
69
import torch
710
from executorch.exir.pass_base import ExportPass, PassResult
811
from torch._decomp import get_decompositions
@@ -14,7 +17,15 @@ class DecomposeScaledDotProductAttention(ExportPass):
1417
Decompose from scaled_dot_product_attention to multiple nodes.
1518
"""
1619

17-
def call(self, graph_module: torch.fx.GraphModule):
20+
def __init__(self, allow_non_fake_inputs: bool = True) -> None:
21+
super().__init__()
22+
# With allow_non_fake_inputs=False, we don't get _unsafe_view ops
23+
# in the graph, we allow disabling it here.
24+
self._allow_non_fake_inputs = allow_non_fake_inputs
25+
26+
def call(
27+
self, graph_module: torch.fx.GraphModule, allow_non_fake_inputs: bool = True
28+
) -> PassResult:
1829
graph = graph_module.graph
1930
for node in graph.nodes:
2031
if node.target == torch.ops.aten.scaled_dot_product_attention.default:
@@ -29,7 +40,7 @@ def call(self, graph_module: torch.fx.GraphModule):
2940
]
3041
),
3142
tracing_mode="fake",
32-
_allow_non_fake_inputs=True,
43+
_allow_non_fake_inputs=allow_non_fake_inputs,
3344
)(*input_tensors)
3445
with graph.inserting_before(node):
3546
name_to_input_tensor_map = {}

0 commit comments

Comments
 (0)