
Commit fbd4d36

mergennachin authored and facebook-github-bot committed

Decouple custom ops in llama_transformer.py Part 2/N (#3007)

Summary:
Pull Request resolved: #3007

Keep llama_transformer.py looking like the stock implementation so that it can be reused everywhere; perform the custom-op substitution via a module swap instead.

Differential Revision: D56048640
1 parent d0d3ec9 commit fbd4d36
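
The "module swap" the summary refers to is the standard PyTorch pattern of walking a model's children and replacing instances of one nn.Module with a drop-in substitute. A minimal, self-contained sketch of the pattern is below; the StockBlock/CustomBlock names are illustrative stand-ins, not classes from this commit:

```python
import torch


class StockBlock(torch.nn.Module):
    """Stand-in for a module defined in the stock model (analogous to SDPA)."""

    def forward(self, x):
        return x * 2


class CustomBlock(torch.nn.Module):
    """Drop-in replacement backed by a different implementation (analogous to SDPACustom)."""

    def forward(self, x):
        return x + x  # same result, different code path


def swap_modules(module: torch.nn.Module) -> torch.nn.Module:
    # Recursively walk the module tree and replace every StockBlock with a
    # CustomBlock, mirroring the shape of _replace_sdpa_with_custom_op below.
    for name, child in module.named_children():
        if isinstance(child, StockBlock):
            setattr(module, name, CustomBlock())
        else:
            swap_modules(child)
    return module


model = torch.nn.Sequential(torch.nn.Linear(4, 4), StockBlock())
swap_modules(model)
assert isinstance(model[1], CustomBlock)
```

This keeps the stock model definition untouched: the optimized variant is injected only at export time, on the export side of the boundary.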

3 files changed: +59 −57 lines changed


examples/models/llama2/TARGETS

Lines changed: 1 addition & 1 deletion
@@ -18,7 +18,6 @@ runtime.python_library(
     ],
     deps = [
         "//caffe2:torch",
-        "//executorch/examples/models/llama2/custom_ops:custom_ops_aot_py",
     ],
 )
 
@@ -86,6 +85,7 @@ runtime.python_library(
         "//executorch/backends/vulkan/partitioner:vulkan_partitioner",
         "//executorch/examples/models:model_base",
         "//executorch/examples/models:models",
+        "//executorch/examples/models/llama2/custom_ops:custom_ops_aot_py",
         "//executorch/examples/portable:utils",
         "//executorch/exir:lib",
         "//executorch/sdk/etrecord:etrecord",

examples/models/llama2/export_llama_lib.py

Lines changed: 58 additions & 3 deletions
@@ -23,7 +23,11 @@
     XnnpackDynamicallyQuantizedPartitioner,
 )
 
-from executorch.examples.models.llama2.llama_transformer import Transformer
+from executorch.examples.models.llama2.llama_transformer import (
+    KVCache,
+    SDPA,
+    Transformer,
+)
 from executorch.exir.backend.backend_details import CompileSpec
 
 from executorch.sdk.etrecord import generate_etrecord
@@ -88,6 +92,58 @@ def materialze_broadcast_of_rope_freq_cis(
     return module
 
 
+class SDPACustom(torch.nn.Module):
+    def __init__(
+        self,
+        kv_cache: KVCache,
+        mask,
+        dim: int,
+    ):
+        super().__init__()
+        self.kv_cache = kv_cache
+        self.mask = mask
+        self.dim = dim
+
+    def forward(
+        self,
+        input_pos: torch.Tensor,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        bsz,
+        seqlen,
+    ):
+        output = torch.ops.llama.sdpa_with_kv_cache(
+            q,
+            k,
+            v,
+            self.kv_cache.k_cache,
+            self.kv_cache.v_cache,
+            input_pos[-1].item(),
+            seqlen,
+        )
+        return output.view(bsz, seqlen, self.dim)
+
+
+def _replace_sdpa_with_custom_op(module: torch.nn.Module):
+    for name, child in module.named_children():
+        if isinstance(child, SDPA):
+            setattr(
+                module,
+                name,
+                SDPACustom(child.kv_cache, child.mask, child.dim),
+            )
+        else:
+            _replace_sdpa_with_custom_op(child)
+
+
+def replace_sdpa_with_custom_op(module: torch.nn.Module) -> torch.nn.Module:
+    from executorch.examples.models.llama2.custom_ops import sdpa_with_kv_cache  # noqa
+
+    _replace_sdpa_with_custom_op(module)
+    return module
+
+
 def quantize(
     model: torch.nn.Module,
     qmode: str,
@@ -493,8 +549,7 @@ def _prepare_for_llama_export(modelname: str, args) -> LlamaEdgeManager:
         transforms.append(materialze_broadcast_of_rope_freq_cis)
 
     if args.use_sdpa_with_kv_cache:
-        pass
-        # TODO: Next diff transforms.append()
+        transforms.append(replace_sdpa_with_custom_op)
 
     return (
         load_llama_model(
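
How the `transforms` list is consumed is outside this diff. A plausible sketch of the consuming side, assuming each entry is a callable from module to module as `replace_sdpa_with_custom_op` is; the `apply_transforms` helper is hypothetical and not part of this change:

```python
from typing import Callable, List

import torch


def apply_transforms(
    module: torch.nn.Module,
    transforms: List[Callable[[torch.nn.Module], torch.nn.Module]],
) -> torch.nn.Module:
    # Hypothetical helper: apply each registered transform in order; each one
    # receives the eager module and returns it (possibly with submodules swapped).
    for transform in transforms:
        module = transform(module)
    return module
```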

examples/models/llama2/llama_transformer.py

Lines changed: 0 additions & 53 deletions
@@ -214,14 +214,12 @@ def __init__(
         self,
         kv_cache: KVCache,
         mask,
-        use_sdpa_with_kv_cache_op: bool,
         dim: int,
         n_rep: int,
     ):
         super().__init__()
         self.kv_cache = kv_cache
         self.mask = mask
-        self.use_sdpa_with_kv_cache_op = use_sdpa_with_kv_cache_op
         self.dim = dim
         self.n_rep = n_rep
 
@@ -233,56 +231,6 @@ def forward(
         v: torch.Tensor,
         bsz,
         seqlen,
-    ) -> torch.Tensor:
-        if not self.use_sdpa_with_kv_cache_op:
-            return self._forward_default(
-                input_pos,
-                q,
-                k,
-                v,
-                bsz,
-                seqlen,
-            )
-        else:
-            return self._forward_custom(
-                input_pos,
-                q,
-                k,
-                v,
-                bsz,
-                seqlen,
-            )
-
-    def _forward_custom(
-        self,
-        input_pos: torch.Tensor,
-        q: torch.Tensor,
-        k: torch.Tensor,
-        v: torch.Tensor,
-        bsz,
-        seqlen,
-    ):
-        from .custom_ops import sdpa_with_kv_cache  # noqa
-
-        output = torch.ops.llama.sdpa_with_kv_cache(
-            q,
-            k,
-            v,
-            self.kv_cache.k_cache,
-            self.kv_cache.v_cache,
-            input_pos[-1].item(),
-            seqlen,
-        )
-        return output.view(bsz, seqlen, self.dim)
-
-    def _forward_default(
-        self,
-        input_pos: torch.Tensor,
-        q: torch.Tensor,
-        k: torch.Tensor,
-        v: torch.Tensor,
-        bsz,
-        seqlen,
     ) -> torch.Tensor:
         q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
         k = k.transpose(1, 2)
@@ -341,7 +289,6 @@ def __init__(self, args: ModelArgs, layer_id: int):
         self.SDPA = SDPA(
             self.kv_cache,
             self.mask,
-            args.use_sdpa_with_kv_cache_op,
             self.dim,
             self.n_rep,
         )
