
Commit a4164b5

Decouple custom ops in llama_transformer.py Part 1/N
1 parent d1bc794 · commit a4164b5

File tree: 4 files changed, +101 -56 lines


examples/models/llama2/builder.py

Lines changed: 1 addition & 5 deletions
@@ -202,11 +202,7 @@ def source_transform(
     def _get_dynamic_shape(self) -> Any:
         dim = torch.export.Dim("token_dim", max=self.model.params.max_seq_len - 1)
         if self.use_kv_cache:
-            if self.use_sdpa_with_kv_cache:
-                return None
-            else:
-                # return {1: dim}, {0: dim}} TODO update xnnpack to be able to handle dynamic shape kv cache
-                return None
+            return None
         else:
             return ({1: dim},)
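Note: after this change `_get_dynamic_shape` returns None for every KV-cache configuration, i.e. the export uses fully static shapes, and only the non-KV-cache path still marks the token dimension as dynamic. A minimal sketch of how such a spec is typically fed to torch.export (the toy module and sizes are placeholders, not from this repo):

import torch

class Toy(torch.nn.Module):
    def forward(self, tokens):
        return tokens * 2

dim = torch.export.Dim("token_dim", max=127)
example = (torch.zeros(1, 8, dtype=torch.long),)

# Non-KV-cache path: dim 1 of the first input is dynamic, as in ({1: dim},).
ep_dynamic = torch.export.export(Toy(), example, dynamic_shapes=({1: dim},))

# KV-cache path after this commit: dynamic_shapes=None, fully static shapes.
ep_static = torch.export.export(Toy(), example, dynamic_shapes=None)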

examples/models/llama2/export_llama_lib.py

Lines changed: 4 additions & 0 deletions
@@ -482,6 +482,10 @@ def _prepare_for_llama_export(modelname: str, args) -> LlamaEdgeManager:
     if args.expand_rope_table:
         transforms.append(materialze_broadcast_of_rope_freq_cis)
 
+    if args.use_sdpa_with_kv_cache:
+        pass
+        # TODO: Next diff transforms.append()
+
     return (
         load_llama_model(
             checkpoint=checkpoint_path,
examples/models/llama2/llama_transformer.py

Lines changed: 95 additions & 36 deletions
@@ -193,6 +193,93 @@ def update(
         return k_out, v_out
 
 
+class SDPA(nn.Module):
+    def __init__(
+        self,
+        kv_cache: KVCache,
+        mask,
+        use_sdpa_with_kv_cache_op: bool,
+        dim: int,
+    ):
+        super().__init__()
+        self.kv_cache = kv_cache
+        self.mask = mask
+        self.use_sdpa_with_kv_cache_op = use_sdpa_with_kv_cache_op
+        self.dim = dim
+
+    def forward(
+        self,
+        input_pos: torch.Tensor,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        bsz,
+        seqlen,
+    ) -> torch.Tensor:
+        if not self.use_sdpa_with_kv_cache_op:
+            return self._forward_default(
+                input_pos,
+                q,
+                k,
+                v,
+                bsz,
+                seqlen,
+            )
+        else:
+            return self._forward_custom(
+                input_pos,
+                q,
+                k,
+                v,
+                bsz,
+                seqlen,
+            )
+
+    def _forward_custom(
+        self,
+        input_pos: torch.Tensor,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        bsz,
+        seqlen,
+    ):
+        from .custom_ops import sdpa_with_kv_cache  # noqa
+
+        output = torch.ops.llama.sdpa_with_kv_cache(
+            q,
+            k,
+            v,
+            self.kv_cache.k_cache,
+            self.kv_cache.v_cache,
+            input_pos[-1].item(),
+            seqlen,
+        )
+        return output.view(bsz, seqlen, self.dim)
+
+    def _forward_default(
+        self,
+        input_pos: torch.Tensor,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        bsz,
+        seqlen,
+    ) -> torch.Tensor:
+        q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
+        k = k.transpose(1, 2)
+        v = v.transpose(1, 2)
+
+        k, v = self.kv_cache.update(input_pos, k, v)
+        mask = self.mask[None, None, input_pos]
+
+        k = k.repeat_interleave(self.n_rep, dim=1)
+        v = v.repeat_interleave(self.n_rep, dim=1)
+        y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
+
+        return y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
+
+
 class Attention(nn.Module):
     def __init__(self, args: ModelArgs, layer_id: int):
         super().__init__()
@@ -213,7 +300,6 @@ def __init__(self, args: ModelArgs, layer_id: int):
         self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
         self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
 
-        self.use_sdpa_with_kv_cache_op = args.use_sdpa_with_kv_cache_op
         self.layer_id = layer_id
 
         causal_mask = torch.tril(
@@ -234,6 +320,12 @@ def __init__(self, args: ModelArgs, layer_id: int):
             self.head_dim,
             not args.use_sdpa_with_kv_cache_op,  # if we are using the custom op dont transpose the cache. Expect untransposed q k v
         )
+        self.SDPA = SDPA(
+            self.kv_cache,
+            self.mask,
+            args.use_sdpa_with_kv_cache_op,
+            self.dim,
+        )
 
     def forward(
         self,
@@ -256,41 +348,8 @@ def forward(
 
         if self.use_kv_cache:
             assert input_pos is not None
-
-            if not self.use_sdpa_with_kv_cache_op:
-
-                q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
-                k = k.transpose(1, 2)
-                v = v.transpose(1, 2)
-
-                k, v = self.kv_cache.update(input_pos, k, v)
-                mask = self.mask[None, None, input_pos]
-
-                k = k.repeat_interleave(self.n_rep, dim=1)
-                v = v.repeat_interleave(self.n_rep, dim=1)
-                y = F.scaled_dot_product_attention(
-                    q, k, v, attn_mask=mask, dropout_p=0.0
-                )
-
-                y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
-
-                y = self.wo(y)
-                return y
-            else:
-                from .custom_ops import sdpa_with_kv_cache  # noqa
-
-                output = torch.ops.llama.sdpa_with_kv_cache(
-                    q,
-                    k,
-                    v,
-                    self.kv_cache.k_cache,
-                    self.kv_cache.v_cache,
-                    input_pos[-1].item(),
-                    seqlen,
-                )
-                output = output.view(bsz, seqlen, -1)
-                output = self.wo(output)
-                return output
+            output = self.SDPA(input_pos, q, k, v, bsz, seqlen)
+            return self.wo(output)
 
         q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
         k = k.transpose(1, 2)
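The point of the extraction is that Attention no longer hard-codes either attention path: it simply calls self.SDPA with (input_pos, q, k, v, bsz, seqlen), so any module honoring that contract can be substituted from outside, for example by the source transform planned in export_llama_lib.py. A minimal sketch of such a drop-in, with assumed input shapes and no KV-cache handling, purely for illustration:

import torch
import torch.nn.functional as F

class EagerCausalSDPA(torch.nn.Module):
    """Toy stand-in honoring the (input_pos, q, k, v, bsz, seqlen) contract."""

    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim

    def forward(self, input_pos, q, k, v, bsz, seqlen):
        # q, k, v arrive as (bs, seqlen, n_heads, head_dim); move heads forward.
        q, k, v = (t.transpose(1, 2) for t in (q, k, v))
        y = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=True)
        return y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)

# attention.SDPA = EagerCausalSDPA(dim)  # what an external rewrite could do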

examples/models/llama2/model.py

Lines changed: 1 addition & 15 deletions
@@ -173,11 +173,7 @@ def get_eager_model(self):
 
     def get_example_inputs(self):
         if self.use_kv_cache:
-            if self.use_sdpa_with_kv_cache_op:
-                return self.get_example_inputs_kvcache_sdpa()
-            else:
-                # return self.get_example_inputs_kvcache() TODO xnnpack does not handle forwarding symints, update partitioner to not partition symints
-                return self.get_example_inputs_kvcache_sdpa()
+            return self.get_example_inputs_kvcache_sdpa()
         else:
             return (
                 torch.tensor(
@@ -195,13 +191,3 @@ def get_example_inputs_kvcache_sdpa(self):
                 [0], dtype=torch.long
             ),  # start_pos, what token of output are we on.)
         )
-
-    def get_example_inputs_kvcache(self):
-        return (
-            torch.tensor(
-                [[1, 2, 3]], dtype=torch.long
-            ),  # tokens, with kv cache our input token length is always just 1 token.
-            torch.tensor(
-                [0, 1, 2], dtype=torch.long
-            ),  # start_pos, what token of output are we on.
-        )
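With the dead branch removed, every KV-cache export now uses the same example inputs. Based on the shapes visible in this diff (a single-token prompt and a start_pos of 0; the exact tokens value is assumed), they boil down to something like:

import torch

example_inputs = (
    torch.tensor([[1]], dtype=torch.long),  # tokens: one token per decode step
    torch.tensor([0], dtype=torch.long),    # start_pos: which output token we are on
)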
