Commit 7054b1f

LLM export pass to swap in custom SDPA

Differential Revision: D73444078
Pull Request resolved: #10355

1 parent c723212, commit 7054b1f
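
For context, a minimal usage sketch of the new pass on an exported graph, mirroring the unit test added in this commit; the toy module, shapes, and the final print are illustrative and not part of the change:

import torch
from torch.export import export_for_training

from executorch.extension.llm.export.export_passes import ReplaceSDPAWithCustomSDPAPass


class ToyAttention(torch.nn.Module):
    # Stand-in for an LLM attention block whose inputs are already
    # (batch, heads, seq_len, head_dim); the pass only handles 4D QKV.
    def forward(self, x):
        return torch.nn.functional.scaled_dot_product_attention(
            x, x, x, attn_mask=None, is_causal=True
        )


gm = export_for_training(
    ToyAttention(), (torch.rand(1, 4, 32, 64),), strict=True
).module()

# Swap aten.scaled_dot_product_attention.default for llama.custom_sdpa.default.
gm = ReplaceSDPAWithCustomSDPAPass()(gm).graph_module
print(gm.code)  # should now reference torch.ops.llama.custom_sdpa.default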

File tree

extension/llm/export/TARGETS
extension/llm/export/export_passes.py
extension/llm/export/test_export_passes.py

3 files changed: 168 additions & 1 deletion


extension/llm/export/TARGETS

Lines changed: 15 additions & 0 deletions
@@ -41,6 +41,21 @@ runtime.python_library(
         "//executorch/exir:lib",
         "//executorch/exir/backend:backend_details",
         "//executorch/extension/export_util:export_util",
+        "//executorch/extension/llm/custom_ops:custom_ops_aot_lib",
+        "//executorch/extension/llm/custom_ops:custom_ops_aot_py",
         "//pytorch/tokenizers/pytorch_tokenizers:tokenizers",
     ],
 )
+
+runtime.python_test(
+    name = "export_passes_test",
+    srcs = [
+        "test_export_passes.py",
+    ],
+    preload_deps = [
+        "//executorch/extension/llm/custom_ops:custom_ops_aot_lib",
+    ],
+    deps = [
+        ":export_lib",
+    ],
+)
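
These deps matter because the new pass emits torch.ops.llama.custom_sdpa.default, which only resolves once the custom-op library has been registered; the in-pass import of executorch.extension.llm.custom_ops suggests the Python import performs that registration, while the Buck test additionally preloads the shared library via preload_deps. A minimal sketch of that assumption:

import torch

# Importing the wrapper module is assumed to load/register the custom ops
# (this is why custom_ops_aot_lib / custom_ops_aot_py were added as deps above).
from executorch.extension.llm.custom_ops import custom_ops  # noqa: F401

assert hasattr(torch.ops.llama, "custom_sdpa")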

extension/llm/export/export_passes.py

Lines changed: 105 additions & 0 deletions
@@ -1,3 +1,5 @@
+import logging
+
 import torch
 
 from executorch.exir.pass_base import ExportPass
@@ -95,3 +97,106 @@ def call(self, graph_module: torch.fx.GraphModule):
         graph_module.recompile()
 
         return PassResult(graph_module, graph_changed)
+
+
+class ReplaceSDPAWithCustomSDPAPass(ExportPass):
+    """
+    This pass replaces aten.scaled_dot_product_attention.default with llama.custom_sdpa.default.
+    If assume_causal_mask is set to True, this pass will ignore any explicit masks and simply set
+    is_causal to True in custom_sdpa.
+    """
+
+    def __init__(self, assume_causal_mask=False):
+        super().__init__()
+        self.assume_causal_mask = assume_causal_mask
+
+    def call_operator(self, op, args, kwargs, meta):
+        from executorch.extension.llm.custom_ops import custom_ops  # noqa
+
+        if op != torch.ops.aten.scaled_dot_product_attention.default:
+            return super().call_operator(op, args, kwargs, meta)
+
+        q, k, v, mask, dropout, is_causal, scale = self._extract_args(args, kwargs)
+
+        qT = self._transpose(q, meta)
+        kT = self._transpose(k, meta)
+        vT = self._transpose(v, meta)
+
+        if not (
+            q.node.meta["val"].dim()
+            == k.node.meta["val"].dim()
+            == v.node.meta["val"].dim()
+            == 4
+        ):
+            logging.info("ReplaceSDPAWithCustomSDPAPass only supports 4D QKV inputs.")
+            return super().call_operator(op, args, kwargs, meta)
+
+        if self.assume_causal_mask:
+            # Ignore the specified mask and simply set the is_causal flag.
+            mask = None
+            is_causal = True
+
+        if mask is not None:
+            mask_fake_tensor = mask.node.meta["val"]
+            if mask_fake_tensor.dim() > 2:
+                if all(d == 1 for d in mask_fake_tensor.size()[:-2]):
+                    mask = super().call_operator(
+                        torch.ops.aten.squeeze.dims,
+                        (mask, tuple(i for i in range(mask_fake_tensor.dim() - 2))),
+                        {},
+                        meta,
+                    )
+                else:
+                    logging.info(
+                        "ReplaceSDPAWithCustomSDPAPass only supports 2D attention mask."
+                    )
+                    return super().call_operator(op, args, kwargs, meta)
+
+            # TODO(kimishpatel): Remove once custom SDPA supports boolean mask.
+            if mask_fake_tensor.dtype == torch.bool:
+                mask = super().call_operator(
+                    torch.ops.aten.where.Scalar,
+                    (mask, 0.0, float("-inf")),
+                    {},
+                    meta,
+                )
+
+        custom_sdpa = super().call_operator(
+            torch.ops.llama.custom_sdpa.default,
+            (qT, kT, vT, 0, mask, dropout, is_causal, scale),
+            {},
+            meta,
+        )
+        return self._transpose(custom_sdpa, meta)
+
+    def _extract_args(self, args, kwargs):
+        q, k, v, *rest = args
+        mask = None
+        dropout = 0.0
+        is_causal = False
+        scale = None
+        if len(rest) > 0:
+            mask = rest[0]
+        if len(rest) > 1:
+            dropout = rest[1]
+        if len(rest) > 2:
+            is_causal = rest[2]
+        if "scale" in kwargs:
+            scale = kwargs["scale"]
+
+        return q, k, v, mask, dropout, is_causal, scale
+
+    def _transpose(self, x, meta):
+        transpose = super().call_operator(
+            torch.ops.aten.transpose.int,
+            (x, 1, 2),
+            {},
+            meta,
+        )
+        contiguous = super().call_operator(
+            torch.ops.aten.contiguous.default,
+            (transpose,),
+            {},
+            meta,
+        )
+        return contiguous
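
To illustrate the data manipulations above in plain eager PyTorch (not part of the commit; the layout names are assumptions based on the usual aten SDPA convention): Q/K/V are transposed from (batch, heads, seq_len, head_dim) to (batch, seq_len, heads, head_dim), a broadcastable mask with singleton leading dims is squeezed down to 2D, and a boolean keep-mask is turned into an additive float mask before being handed to custom_sdpa.

import torch

# Q/K/V layout change done by _transpose: transpose(1, 2) followed by contiguous().
q = torch.rand(1, 4, 32, 64)          # (batch, heads, seq_len, head_dim)
q_t = q.transpose(1, 2).contiguous()  # (batch, seq_len, heads, head_dim)

# Mask handling: squeeze singleton leading dims (aten.squeeze.dims), then map
# a boolean keep-mask to an additive float mask (aten.where.Scalar).
bool_mask = torch.tril(torch.ones(32, 32, dtype=torch.bool)).view(1, 1, 32, 32)
mask_2d = bool_mask.squeeze((0, 1))
additive_mask = torch.where(mask_2d, 0.0, float("-inf"))

print(q_t.shape, mask_2d.shape, additive_mask.dtype)
# expected: torch.Size([1, 32, 4, 64]) torch.Size([32, 32]) torch.float32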

extension/llm/export/test_export_passes.py

Lines changed: 48 additions & 1 deletion
@@ -2,7 +2,10 @@
 
 import torch
 
-from executorch.extension.llm.export.export_passes import RemoveRedundantTransposes
+from executorch.extension.llm.export.export_passes import (
+    RemoveRedundantTransposes,
+    ReplaceSDPAWithCustomSDPAPass,
+)
 
 from torch.export import export_for_training
 from torch.testing import FileCheck
@@ -160,3 +163,47 @@ def forward(self, x):
 
         m = TestModule2()
         self._check(m, (x,), key, 3, 2)
+
+
+class ReplaceSDPAWithCustomSDPAPassTest(unittest.TestCase):
+    class TestModule(torch.nn.Module):
+        def forward(self, x, mask, is_causal):
+            return torch.nn.functional.scaled_dot_product_attention(
+                x, x, x, attn_mask=mask, is_causal=is_causal
+            )
+
+    def setUp(self):
+        torch.manual_seed(0)
+
+    def _test(self, args, assume_causal_mask=False):
+        m = self.TestModule()
+        gm = export_for_training(m, args, strict=True).module()
+
+        sdpa_key = "torch.ops.aten.scaled_dot_product_attention.default"
+        custom_sdpa_key = "torch.ops.llama.custom_sdpa.default"
+        FileCheck().check_count(sdpa_key, 1, exactly=True).run(gm.code)
+        gm = ReplaceSDPAWithCustomSDPAPass(assume_causal_mask)(gm).graph_module
+        FileCheck().check_count(sdpa_key, 0, exactly=True).run(gm.code)
+        FileCheck().check_count(custom_sdpa_key, 1, exactly=True).run(gm.code)
+
+        y1 = m(*args)
+        y2 = gm(*args)
+        self.assertTrue(torch.allclose(y1, y2))
+
+    def test_causal_mask(self):
+        self._test((torch.rand(1, 4, 32, 64), None, True))
+
+    def test_explicit_causal_mask(self):
+        mask = torch.tril(torch.ones(32, 32, dtype=torch.bool))
+        self._test((torch.rand(1, 4, 32, 64), mask, False), assume_causal_mask=True)
+
+    def test_custom_mask(self):
+        m1 = torch.tril(torch.ones(32, 32, dtype=torch.bool))
+        m2 = torch.tril(torch.ones(32, 32, dtype=torch.bool), diagonal=-16)
+        self._test((torch.rand(1, 4, 32, 64), torch.logical_xor(m1, m2), False))
+
+    def test_squeezable_mask(self):
+        m1 = torch.tril(torch.ones(32, 32, dtype=torch.bool))
+        m2 = torch.tril(torch.ones(32, 32, dtype=torch.bool), diagonal=-16)
+        m = torch.logical_xor(m1, m2).view(1, 1, 32, 32)
+        self._test((torch.rand(1, 4, 32, 64), m, False))