
Commit dd6dab6

mergennachin authored and facebook-github-bot committed
Decouple custom ops in llama_transformer.py Part 1/N (#3005)
Summary: This is a no-op.

Test Plan: CI. Run with
`python -m examples.models.llama2.export_llama -c stories110M.pt -p params.json -kv --use_sdpa_with_kv_cache -X`
and with
`python -m examples.models.llama2.export_llama -c stories110M.pt -p params.json -kv -X`
and make sure both work.

Differential Revision: D56048177

Pulled By: mergennachin
1 parent 6acc86f commit dd6dab6

File tree: 4 files changed, +104 −56 lines

examples/models/llama2/builder.py

Lines changed: 1 addition & 5 deletions

@@ -206,11 +206,7 @@ def source_transform(
    def _get_dynamic_shape(self) -> Any:
        dim = torch.export.Dim("token_dim", max=self.model.params.max_seq_len - 1)
        if self.use_kv_cache:
-            if self.use_sdpa_with_kv_cache:
-                return None
-            else:
-                # return {1: dim}, {0: dim}} TODO update xnnpack to be able to handle dynamic shape kv cache
-                return None
+            return None
        else:
            return ({1: dim},)

examples/models/llama2/export_llama_lib.py

Lines changed: 4 additions & 0 deletions

@@ -492,6 +492,10 @@ def _prepare_for_llama_export(modelname: str, args) -> LlamaEdgeManager:
    if args.expand_rope_table:
        transforms.append(materialze_broadcast_of_rope_freq_cis)

+    if args.use_sdpa_with_kv_cache:
+        pass
+        # TODO: Next diff transforms.append()
+
    return (
        load_llama_model(
            checkpoint=checkpoint_path,
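
The hunk above only reserves the place where the next diff will register a source transform for the SDPA path (hence the `# TODO: Next diff transforms.append()` placeholder). As a rough illustration of why the decoupling in llama_transformer.py (below) enables this, here is a minimal module-swapping transform sketch; the helper name `_replace_sdpa_modules` and the `make_replacement` factory are hypothetical and not part of this commit:

from torch import nn

# Hypothetical sketch: walk the module tree and swap every SDPA submodule for a
# drop-in replacement (e.g. one backed by a custom op). Illustrative only.
def _replace_sdpa_modules(module: nn.Module, make_replacement) -> nn.Module:
    for name, child in module.named_children():
        if child.__class__.__name__ == "SDPA":
            setattr(module, name, make_replacement(child))
        else:
            _replace_sdpa_modules(child, make_replacement)
    return module

A callable such as `lambda m: _replace_sdpa_modules(m, some_factory)` could then be appended to `transforms`, assuming `transforms` holds functions that take and return an `nn.Module`.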

examples/models/llama2/llama_transformer.py

Lines changed: 98 additions & 36 deletions

@@ -209,6 +209,95 @@ def update(
        return k_out, v_out


+class SDPA(nn.Module):
+    def __init__(
+        self,
+        kv_cache: KVCache,
+        mask,
+        use_sdpa_with_kv_cache_op: bool,
+        dim: int,
+        n_rep: int,
+    ):
+        super().__init__()
+        self.kv_cache = kv_cache
+        self.mask = mask
+        self.use_sdpa_with_kv_cache_op = use_sdpa_with_kv_cache_op
+        self.dim = dim
+        self.n_rep = n_rep
+
+    def forward(
+        self,
+        input_pos: torch.Tensor,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        bsz,
+        seqlen,
+    ) -> torch.Tensor:
+        if not self.use_sdpa_with_kv_cache_op:
+            return self._forward_default(
+                input_pos,
+                q,
+                k,
+                v,
+                bsz,
+                seqlen,
+            )
+        else:
+            return self._forward_custom(
+                input_pos,
+                q,
+                k,
+                v,
+                bsz,
+                seqlen,
+            )
+
+    def _forward_custom(
+        self,
+        input_pos: torch.Tensor,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        bsz,
+        seqlen,
+    ):
+        from .custom_ops import sdpa_with_kv_cache # noqa
+
+        output = torch.ops.llama.sdpa_with_kv_cache(
+            q,
+            k,
+            v,
+            self.kv_cache.k_cache,
+            self.kv_cache.v_cache,
+            input_pos[-1].item(),
+            seqlen,
+        )
+        return output.view(bsz, seqlen, self.dim)
+
+    def _forward_default(
+        self,
+        input_pos: torch.Tensor,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        bsz,
+        seqlen,
+    ) -> torch.Tensor:
+        q = q.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
+        k = k.transpose(1, 2)
+        v = v.transpose(1, 2)
+
+        k, v = self.kv_cache.update(input_pos, k, v)
+        mask = self.mask[None, None, input_pos]
+
+        k = k.repeat_interleave(self.n_rep, dim=1)
+        v = v.repeat_interleave(self.n_rep, dim=1)
+        y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
+
+        return y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
+
+
class Attention(nn.Module):
    def __init__(self, args: ModelArgs, layer_id: int):
        super().__init__()

@@ -229,7 +318,6 @@ def __init__(self, args: ModelArgs, layer_id: int):
        self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)

-        self.use_sdpa_with_kv_cache_op = args.use_sdpa_with_kv_cache_op
        self.layer_id = layer_id

        causal_mask = torch.tril(

@@ -250,6 +338,13 @@ def __init__(self, args: ModelArgs, layer_id: int):
                self.head_dim,
                not args.use_sdpa_with_kv_cache_op, # if we are using the custom op dont transpose the cache. Expect untransposed q k v
            )
+            self.SDPA = SDPA(
+                self.kv_cache,
+                self.mask,
+                args.use_sdpa_with_kv_cache_op,
+                self.dim,
+                self.n_rep,
+            )

    def forward(
        self,

@@ -272,41 +367,8 @@ def forward(

        if self.use_kv_cache:
            assert input_pos is not None
-
-            if not self.use_sdpa_with_kv_cache_op:
-
-                q = q.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
-                k = k.transpose(1, 2)
-                v = v.transpose(1, 2)
-
-                k, v = self.kv_cache.update(input_pos, k, v)
-                mask = self.mask[None, None, input_pos]
-
-                k = k.repeat_interleave(self.n_rep, dim=1)
-                v = v.repeat_interleave(self.n_rep, dim=1)
-                y = F.scaled_dot_product_attention(
-                    q, k, v, attn_mask=mask, dropout_p=0.0
-                )
-
-                y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
-
-                y = self.wo(y)
-                return y
-            else:
-                from .custom_ops import sdpa_with_kv_cache # noqa
-
-                output = torch.ops.llama.sdpa_with_kv_cache(
-                    q,
-                    k,
-                    v,
-                    self.kv_cache.k_cache,
-                    self.kv_cache.v_cache,
-                    input_pos[-1].item(),
-                    seqlen,
-                )
-                output = output.view(bsz, seqlen, -1)
-                output = self.wo(output)
-                return output
+            output = self.SDPA(input_pos, q, k, v, bsz, seqlen)
+            return self.wo(output)

        q = q.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
        k = k.transpose(1, 2)
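
For reference, here is a standalone sketch that drives the new SDPA module along its default (non custom-op) path. The stub cache class, the import path, and the toy sizes are assumptions made for illustration; only the `update()`, `k_cache`, and `v_cache` members that SDPA actually touches are modeled, so the example does not depend on KVCache's exact constructor.

import torch

# Assumes the executorch repo root is the working directory, matching the
# `python -m examples.models.llama2...` invocations in the test plan.
from examples.models.llama2.llama_transformer import SDPA

class _StubKVCache:
    # Minimal stand-in for KVCache: transposed layout (bsz, n_heads, max_seq_len, head_dim),
    # which is the layout the default SDPA path expects.
    def __init__(self, bsz, n_heads, max_seq_len, head_dim):
        self.k_cache = torch.zeros(bsz, n_heads, max_seq_len, head_dim)
        self.v_cache = torch.zeros(bsz, n_heads, max_seq_len, head_dim)

    def update(self, input_pos, k, v):
        self.k_cache[:, :, input_pos] = k
        self.v_cache[:, :, input_pos] = v
        return self.k_cache, self.v_cache

bsz, seqlen, n_heads, head_dim, max_seq_len = 1, 1, 8, 64, 128
mask = torch.tril(torch.ones(max_seq_len, max_seq_len, dtype=torch.bool))
sdpa = SDPA(
    _StubKVCache(bsz, n_heads, max_seq_len, head_dim),
    mask,
    use_sdpa_with_kv_cache_op=False,  # take the eager F.scaled_dot_product_attention path
    dim=n_heads * head_dim,
    n_rep=1,  # n_heads == n_kv_heads in this toy setup
)
q = torch.randn(bsz, seqlen, n_heads, head_dim)
k = torch.randn(bsz, seqlen, n_heads, head_dim)
v = torch.randn(bsz, seqlen, n_heads, head_dim)
out = sdpa(torch.tensor([0]), q, k, v, bsz, seqlen)
print(out.shape)  # torch.Size([1, 1, 512])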

examples/models/llama2/model.py

Lines changed: 1 addition & 15 deletions

@@ -209,11 +209,7 @@ def get_eager_model(self):

    def get_example_inputs(self):
        if self.use_kv_cache:
-            if self.use_sdpa_with_kv_cache_op:
-                return self.get_example_inputs_kvcache_sdpa()
-            else:
-                # return self.get_example_inputs_kvcache() TODO xnnpack does not handle forwarding symints, update partitioner to not partition symints
-                return self.get_example_inputs_kvcache_sdpa()
+            return self.get_example_inputs_kvcache_sdpa()
        else:
            return (
                torch.tensor(

@@ -231,13 +227,3 @@ def get_example_inputs_kvcache_sdpa(self):
                [0], dtype=torch.long
            ), # start_pos, what token of output are we on.)
        )
-
-    def get_example_inputs_kvcache(self):
-        return (
-            torch.tensor(
-                [[1, 2, 3]], dtype=torch.long
-            ), # tokens, with kv cache our input token length is always just 1 token.
-            torch.tensor(
-                [0, 1, 2], dtype=torch.long
-            ), # start_pos, what token of output are we on.
-        )
)
