[ExecuTorch][Llama] Split custom sdpa op and kv cache #7412


Merged
8 commits merged on Jan 16, 2025
2 changes: 2 additions & 0 deletions examples/models/llama/export_llama_lib.py
@@ -56,6 +56,7 @@
get_quant_weight_transform,
)
from .source_transformation.quantized_kv_cache import (
replace_kv_cache_with_custom_kv_cache,
replace_kv_cache_with_quantized_kv_cache,
)
from .source_transformation.rms_norm import replace_rms_norm_with_native_rms_norm
@@ -1058,6 +1059,7 @@ def _get_source_transforms( # noqa
transforms.append(materialze_broadcast_of_rope_freq_cis)

if args.use_sdpa_with_kv_cache:
transforms.append(replace_kv_cache_with_custom_kv_cache)
transforms.append(replace_sdpa_with_custom_op)

if args.quantize_kv_cache:
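For context, a minimal sketch (not part of this diff) of how the two transforms appended under use_sdpa_with_kv_cache are expected to compose. It assumes, as the rest of export_llama_lib.py suggests, that each entry of the returned transform list is a callable applied to the eager model in order; model here is a placeholder for the eager Llama module:

    transforms = [
        replace_kv_cache_with_custom_kv_cache,  # swap KVCache for CustomKVCache first
        replace_sdpa_with_custom_op,  # then route attention through the custom SDPA op
    ]
    for transform in transforms:
        model = transform(model)  # each transform modifies the module in place and returns it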
examples/models/llama/source_transformation/quantized_kv_cache.py
@@ -6,6 +6,7 @@

import logging
from enum import Enum
from typing import Tuple

import torch
import torch.nn as nn
@@ -44,7 +45,6 @@ def __init__(
QuantizedCacheType.AffineSymmetric,
QuantizedCacheType.AffineAsymmetric,
):

raise ValueError(
f"Only affine symmetric and asymmetric cache types are supported: got {cache_type}"
)
@@ -81,10 +81,11 @@ def __init__(
)

def _quantize(self, value):
-scales, zero_points = (
-    torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(
-        value, self.quantized_cache_dtype
-    )
-)
+(
+    scales,
+    zero_points,
+) = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(
+    value, self.quantized_cache_dtype
+)
quantized_value = torch.ops.quantized_decomposed.quantize_per_token(
value,
@@ -262,3 +263,71 @@ def replace_kv_cache_with_quantized_kv_cache(module):
else:
replace_kv_cache_with_quantized_kv_cache(child)
return module


class CustomKVCache(nn.Module):
def __init__(
self,
max_batch_size: int,
max_seq_length: int,
n_heads: int,
head_dim: int,
dtype=torch.float32,
):
super().__init__()
self.max_seq_length = max_seq_length
cache_shape = (max_batch_size, max_seq_length, n_heads, head_dim)

self.max_batch_size = max_batch_size
self.n_heads = n_heads
self.head_dim = head_dim
self.register_buffer(
"k_cache", torch.zeros(cache_shape, dtype=dtype, device="cpu")
)
self.register_buffer(
"v_cache", torch.zeros(cache_shape, dtype=dtype, device="cpu")
)

def update(
self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
# input_pos: [S], k_val: [B, S, H, D]
start_pos = input_pos[0].item()
_ = torch.ops.llama.update_cache(k_val, self.k_cache, start_pos)
_ = torch.ops.llama.update_cache(v_val, self.v_cache, start_pos)
return self.k_cache, self.v_cache
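
# Illustrative sketch, not part of this diff: the update semantics assumed of
# torch.ops.llama.update_cache above, under the [B, S, H, D] layout described
# in replace_kv_cache_with_custom_kv_cache below (a write along dim 1, the
# sequence dimension). The helper name is hypothetical.
def _reference_update(
    cache: torch.Tensor, val: torch.Tensor, start_pos: int
) -> torch.Tensor:
    seq_len = val.size(1)
    cache[:, start_pos : start_pos + seq_len] = val
    return cache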


def replace_kv_cache_with_custom_kv_cache(module):
r"""
Replace KVCache with CustomKVCache. This modifies the model in place.
At the moment the custom kv cache only supports caches with shape
[B, S, H, D], as opposed to [B, H, S, D].
This is because the custom op treats the second dim as the sequence dim.
Future work: support [B, H, S, D].
"""
logging.warning(
"Replacing KVCache with CustomKVCache. This modifies the model in place."
)
for name, child in module.named_children():
if isinstance(child, KVCache):
cache_shape = child.k_cache.shape
cache_dtype = child.k_cache.dtype
assert (
child.is_transposed is False
), "CustomKVCache does not support transposed cache"
max_batch_size, max_seq_length, n_heads, head_dim = cache_shape
setattr(
module,
name,
CustomKVCache(
max_batch_size,
max_seq_length,
n_heads,
head_dim,
dtype=cache_dtype,
),
)
else:
replace_kv_cache_with_custom_kv_cache(child)
return module
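
A hedged usage sketch of the new transform; model stands for any eager Llama module whose attention layers hold KVCache instances (an assumption about the caller, not something shown in this diff):

    from executorch.examples.models.llama.source_transformation.quantized_kv_cache import (
        CustomKVCache,
        replace_kv_cache_with_custom_kv_cache,
    )

    model = replace_kv_cache_with_custom_kv_cache(model)  # modifies in place and returns the module
    assert any(isinstance(m, CustomKVCache) for m in model.modules())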
39 changes: 10 additions & 29 deletions examples/models/llama/source_transformation/sdpa.py
@@ -56,33 +56,16 @@ def forward(

k_cache = self.kv_cache.k_cache
v_cache = self.kv_cache.v_cache
-if hasattr(self.kv_cache, "quantized_cache_dtype"):
-    # updated quantize cache, scale and zero points
-    # returns dequantized kv cache
-    # Not most optimal. Optimizations to follow next
-    k_cache, v_cache = self.kv_cache.update(input_pos, k, v)
-    output = torch.ops.llama.custom_sdpa(
-        q,
-        k_cache,
-        v_cache,
-        input_pos[0].item(),
-        None, # Attention mask
-        0, # dropout probability. Ignored by the code
-        True, # is_causal
-    )
-else:
-    output = torch.ops.llama.sdpa_with_kv_cache(
-        q,
-        k,
-        v,
-        k_cache,
-        v_cache,
-        input_pos[0].item(),
-        seqlen,
-        None, # Attention mask
-        0, # dropout probability. Ignored by the code
-        True, # is_causal
-    )
+k_cache, v_cache = self.kv_cache.update(input_pos, k, v)
+output = torch.ops.llama.custom_sdpa(
+    q,
+    k_cache,
+    v_cache,
+    input_pos[0].item(),
+    None, # Attention mask
+    0, # dropout probability. Ignored by the code
+    True, # is_causal
+)
return output.view(bsz, seqlen, self.dim).to(dtype=input_dtype)
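
# Design note (illustrative, not part of this diff): after this change SDPACustom
# no longer branches on the cache type; it only relies on the cache exposing
# update(input_pos, k_val, v_val) -> (k_cache, v_cache), which both
# QuantizedKVCache and the new CustomKVCache provide. A structural sketch of
# that assumed interface (the Protocol name is hypothetical):
from typing import Protocol, Tuple


class _CacheLike(Protocol):
    def update(
        self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]: ...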


@@ -106,7 +89,6 @@ def replace_sdpa_with_custom_op(module: torch.nn.Module) -> torch.nn.Module:


class SDPASimple(torch.nn.Module):

def __init__(
self,
kv_cache: KVCache,
@@ -166,7 +148,6 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:


class SDPAFlex(torch.nn.Module):

def __init__(
self,
kv_cache: KVCache,
@@ -11,6 +11,7 @@
from executorch.examples.models.llama.llama_transformer import KVCache

from executorch.examples.models.llama.source_transformation.quantized_kv_cache import (
CustomKVCache,
QuantizedCacheType,
QuantizedKVCache,
)
@@ -19,7 +20,6 @@


class SDPAWithQuantizedKVCacheTest(unittest.TestCase):

def _init_cache(self):
self.kv_cache = KVCache(
self.max_batch_size,
@@ -33,6 +33,19 @@ def _init_cache(self):
self.quantized_kv_cache = QuantizedKVCache.from_float(
self.kv_cache, QuantizedCacheType.AffineAsymmetric
)
# Need this because the first test actually has seq_len > 1,
# and the vanilla kv cache cannot handle seq_len > 1 due to
# how input_pos encoding works in the current stack.
# This needs fixing by making sure the rest of the stack, including
# custom ops and other backends, can work with input_pos
# as a sequence of token positions.
self.custom_kv_cache = CustomKVCache(
self.max_batch_size,
self.max_seq_len,
self.n_kv_heads,
self.head_dim,
dtype=self.dtype,
)

def _init_kv(self):
kv_shape = (1, self.seq_len, self.n_kv_heads, self.head_dim)
@@ -59,7 +72,7 @@ def test_simple(self, is_dynamic_shape=False):
self.seq_len = 3
self._init_cache()
q, k, v = self._init_kv()
self.float_sdpa = SDPACustom(self.kv_cache, self.dim)
self.float_sdpa = SDPACustom(self.custom_kv_cache, self.dim)
self.quantized_sdpa = SDPACustom(self.quantized_kv_cache, self.dim)
float_out = self.float_sdpa(input_pos, q, k, v, 1, self.seq_len, None)
quantized_out = self.quantized_sdpa(input_pos, q, k, v, 1, self.seq_len, None)