
[ExecuTorch][Llama] Decouple input sequence length from KV cache context length #8047


Merged · 1 commit · Jan 30, 2025
1 change: 1 addition & 0 deletions .ci/scripts/test_eval_llama_mmlu.sh
@@ -43,6 +43,7 @@ run_and_verify() {
--tasks mmlu \
-f 5 \
--max_seq_length 2048 \
--max_context_length 2048 \
--limit 5 > result.txt

# Verify result.txt
1 change: 1 addition & 0 deletions .ci/scripts/test_eval_llama_wikitext.sh
@@ -41,6 +41,7 @@ run_and_verify() {
-kv \
-d fp32 \
--max_seq_length 2048 \
--max_context_length 2048 \
--limit 5 > result.txt

# Verify result.txt
29 changes: 25 additions & 4 deletions examples/models/llama/export_llama_lib.py
@@ -335,6 +335,13 @@ def build_args_parser() -> argparse.ArgumentParser:
help="maximum length sequence to evaluate",
)

parser.add_argument(
"--max_context_length",
type=int,
default=128,
help="maximum length of context for model to remember",
)

parser.add_argument("-2", "--fairseq2", action="store_true")
parser.add_argument("-v", "--verbose", action="store_true")
parser.add_argument(
@@ -579,6 +586,7 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
tokenizer_path=args.tokenizer_path,
verbose=args.verbose,
max_seq_len=args.max_seq_length,
max_context_len=args.max_context_length,
input_prune_map_path=args.input_prune_map,
output_prune_map_path=args.output_prune_map,
metadata_str=args.metadata,
@@ -637,6 +645,11 @@ def _validate_args(args):
"""
TODO: Combine all the backends under --backend args
"""

if args.max_context_length < args.max_seq_length:
raise ValueError(
f"max_context_length {args.max_context_length} must be >= max_seq_len {args.max_seq_length}. max_context_length impacts kv cache size that is used to remember history, while max_seq_length refers to user prompt length. Please use --max_context_length to specify context length."
)
if args.enable_dynamic_shape and (args.coreml or args.mps or args.qnn):
raise ValueError(
"Dynamic shape is not supported with coreml, MPS or qnn backends."
@@ -662,6 +675,7 @@ def _validate_args(args):

def _export_llama(args) -> LLMEdgeManager: # noqa: C901
_validate_args(args)

pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args)

# export_to_edge
@@ -760,13 +774,13 @@ def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
atten = builder_exported_to_edge.model.layers[0].attention
if args.use_qnn_sha:
cache_shape = torch.Size(
(atten.max_batch_size, atten.max_seq_len, atten.head_dim)
(atten.max_batch_size, atten.max_context_len, atten.head_dim)
)
else:
cache_shape = torch.Size(
(
atten.max_batch_size,
atten.max_seq_len,
atten.max_context_len,
atten.n_kv_heads,
atten.head_dim,
)
@@ -861,6 +875,7 @@ def _load_llama_model_metadata(
use_sdpa_with_kv_cache: bool,
enable_dynamic_shape: bool,
max_seq_len: int,
max_context_len: int,
n_layers: int,
vocab_size: int,
metadata_str: Optional[str] = None,
@@ -870,6 +885,7 @@
"get_bos_id": 3 if is_fairseq2 else 1,
"get_eos_ids": [3] if is_fairseq2 else [2],
"get_max_seq_len": max_seq_len,
"get_max_context_len": max_context_len,
"get_n_layers": n_layers,
"get_vocab_size": vocab_size,
"use_kv_cache": use_kv_cache,
@@ -904,6 +920,7 @@ def _load_llama_model(
tokenizer_path: Optional[str] = None,
verbose: bool = False,
max_seq_len: int = 128,
max_context_len: int = 128,
input_prune_map_path: Optional[str] = None,
output_prune_map_path: Optional[str] = None,
metadata_str: Optional[str] = None,
@@ -948,6 +965,7 @@ def _load_llama_model(
generate_full_logits=generate_full_logits,
fairseq2=weight_type == WeightType.FAIRSEQ2,
max_seq_len=max_seq_len,
max_context_len=max_context_len,
enable_dynamic_shape=enable_dynamic_shape,
input_prune_map_path=input_prune_map_path,
output_prune_map_path=output_prune_map_path,
@@ -1006,10 +1024,13 @@ def _load_llama_model(
# pyre-fixme[6]: For 5th argument expected `ModelArgs` but got
# `Union[Tensor, Module]`.
model.max_seq_len,
# pyre-fixme[6]: For 6th argument expected `int` but got `Union[Tensor,
# pyre-fixme[6]: For 6th argument expected `ModelArgs` but got
# `Union[Tensor, Module]`.
model.max_context_len,
# pyre-fixme[6]: For 7th argument expected `int` but got `Union[Tensor,
# Module]`.
model.n_layers,
# pyre-fixme[6]: For 7th argument expected `int` but got `Union[Tensor,
# pyre-fixme[6]: For 8th argument expected `int` but got `Union[Tensor,
# Module]`.
model.vocab_size,
metadata_str,
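
To make the intended relationship between the two flags concrete, here is a minimal sketch (not the actual export_llama_lib.py code; the flag names, defaults, and validation rule are taken from the hunks above, while the demo values are illustrative): `--max_seq_length` bounds the prompt chunk fed to the model, `--max_context_length` bounds the KV cache, and export should fail fast when the cache could not hold a full-length input.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    # Mirrors the two flags added in build_args_parser(); all other flags omitted.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--max_seq_length", type=int, default=128,
        help="maximum length sequence to evaluate",
    )
    parser.add_argument(
        "--max_context_length", type=int, default=128,
        help="maximum length of context for model to remember",
    )
    return parser


def validate(args: argparse.Namespace) -> None:
    # Same rule as _validate_args(): the KV cache (context) must be able to
    # hold at least one full-length prompt.
    if args.max_context_length < args.max_seq_length:
        raise ValueError(
            f"max_context_length {args.max_context_length} must be >= "
            f"max_seq_length {args.max_seq_length}"
        )


if __name__ == "__main__":
    args = build_parser().parse_args(
        ["--max_seq_length", "512", "--max_context_length", "2048"]
    )
    validate(args)  # ok: prompt chunks up to 512 tokens, 2048-token KV cache
```
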
32 changes: 17 additions & 15 deletions examples/models/llama/llama_transformer.py
@@ -91,6 +91,7 @@ class ModelArgs:
norm_eps: float = 1e-5
max_batch_size: int = 32
max_seq_len: int = 2048
max_context_len: int = 2048
moe: bool = False # True to enable the MoE (Mixture of Experts)
num_experts: int = 8 # Number of experts
num_activated_experts: int = 2 # Number of experts to activate
@@ -163,9 +164,9 @@ def __init__(self, params: ModelArgs):
freqs_cos, freqs_sin = self.precompute_freqs_cis(
self.params.head_dim,
(
self.params.max_seq_len # Normal llama2.
self.params.max_context_len # Normal llama2.
if self.params.ffn_dim_multiplier is None
else self.params.max_seq_len * 2 # Sharded checkpoint.
else self.params.max_context_len * 2 # Sharded checkpoint.
),
self.params.rope_freq_base,
)
@@ -205,7 +206,7 @@ def get_freqs(self, input_pos: Optional[torch.Tensor], seq_len: int):
# when KV cache is used, seqlen is most likely 1. We want to slice from the start_pos.
input_pos_item = input_pos[-1].item()
torch._check_is_size(input_pos_item)
torch._check(input_pos_item < self.params.max_seq_len)
torch._check(input_pos_item < self.params.max_context_len)
# pyre-ignore: Incompatible parameter type [6]: torch.narrow does expect int or Tensor
freqs_cos = self.freqs_cos.narrow(0, input_pos_item, seq_len)
# pyre-ignore: Incompatible parameter type [6]
@@ -229,15 +230,15 @@ class KVCache(nn.Module):
def __init__(
self,
max_batch_size: int,
max_seq_length: int,
max_context_length: int,
n_heads: int,
head_dim: int,
enable_dynamic_shape: bool,
dtype=torch.float32,
):
super().__init__()
self.max_seq_length = max_seq_length
cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
self.max_context_length = max_context_length
cache_shape = (max_batch_size, n_heads, max_context_length, head_dim)

self.max_batch_size = max_batch_size
self.n_heads = n_heads
@@ -257,7 +258,7 @@ def update(
if self.enable_dynamic_shape:
start_pos = input_pos[0].item()
torch._check_is_size(start_pos)
torch._check(start_pos < self.max_seq_length)
torch._check(start_pos < self.max_context_length)
dim_to_slice = 2
seq_length = k_val.size(dim_to_slice)
# Replace the entry in the cache for this token
@@ -289,14 +290,14 @@ def __init__(
dim: int,
head_dim: int,
n_rep: int,
max_seq_len: int,
max_context_len: int,
enable_dynamic_shape: bool,
):
super().__init__()
self.dim = dim
self.head_dim = head_dim
self.n_rep = n_rep
self.max_seq_len = max_seq_len
self.max_context_len = max_context_len
self.enable_dynamic_shape = enable_dynamic_shape

def forward(
@@ -312,7 +313,7 @@ def forward(
if self.enable_dynamic_shape:
start_pos = input_pos[-1].item()
torch._check_is_size(start_pos)
torch._check(start_pos < self.max_seq_len)
torch._check(start_pos < self.max_context_len)
seq_length = q.size(2)
# pyre-ignore: Incompatible parameter type [6]
attn_mask = mask.narrow(0, start_pos, seq_length)
@@ -341,7 +342,7 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
self.n_rep = self.n_local_heads // self.n_local_kv_heads
self.head_dim = args.head_dim
self.max_batch_size = args.max_batch_size
self.max_seq_len = args.max_seq_len
self.max_context_len = args.max_context_len
self.dim = args.dim
self.wq = nn.Linear(self.dim, self.n_heads * self.head_dim, bias=False)
self.wk = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False)
@@ -354,8 +355,8 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):

causal_mask = torch.tril(
torch.ones(
self.max_seq_len,
self.max_seq_len,
self.max_context_len,
self.max_context_len,
dtype=torch.bool,
device="cpu",
)
@@ -365,7 +366,7 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
if self.use_kv_cache:
self.kv_cache = KVCache(
args.max_batch_size,
args.max_seq_len,
args.max_context_len,
self.n_kv_heads,
self.head_dim,
args.enable_dynamic_shape,
@@ -374,7 +375,7 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
dim=self.n_local_heads * self.head_dim,
head_dim=self.head_dim,
n_rep=self.n_rep,
max_seq_len=self.max_seq_len,
max_context_len=self.max_context_len,
enable_dynamic_shape=args.enable_dynamic_shape,
)

@@ -528,6 +529,7 @@ def __init__(self, params: ModelArgs):
self.use_kv_cache = params.use_kv_cache
self.generate_full_logits = params.generate_full_logits
self.max_seq_len = params.max_seq_len
self.max_context_len = params.max_context_len
self.input_prune_map = params.input_prune_map
self.output_prune_map = params.output_prune_map

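
As a standalone illustration of the sizing change above (a rough sketch, assuming only the shapes from the KVCache hunk; `TinyKVCache` is hypothetical and not part of the codebase): the cache buffers are allocated with `max_context_length` slots along the sequence axis, while any single update only writes the `seq_len` positions of the current chunk.

```python
import torch


class TinyKVCache(torch.nn.Module):
    """Minimal stand-in for KVCache above: capacity is set by context length."""

    def __init__(self, max_batch_size, max_context_length, n_heads, head_dim,
                 dtype=torch.float32):
        super().__init__()
        self.max_context_length = max_context_length
        cache_shape = (max_batch_size, n_heads, max_context_length, head_dim)
        self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
        self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))

    def update(self, input_pos, k_val, v_val):
        start = int(input_pos[0].item())
        seq_len = k_val.size(2)  # length of the current chunk, not the cache
        assert start + seq_len <= self.max_context_length
        self.k_cache[:, :, start:start + seq_len] = k_val
        self.v_cache[:, :, start:start + seq_len] = v_val
        return self.k_cache, self.v_cache


# A prompt chunk limited by max_seq_len (here 128) now feeds a cache that can
# remember a longer history (max_context_len, here 2048).
cache = TinyKVCache(max_batch_size=1, max_context_length=2048, n_heads=8,
                    head_dim=64)
k = v = torch.randn(1, 8, 128, 64)       # one 128-token chunk
cache.update(torch.tensor([0]), k, v)    # written at positions 0..127
cache.update(torch.tensor([128]), k, v)  # next chunk; both chunks stay cached
```
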
8 changes: 7 additions & 1 deletion examples/models/llama/model.py
@@ -52,8 +52,13 @@ def __init__(self, **kwargs):
self.input_prune_map_path = kwargs.get("input_prune_map_path", None)
self.output_prune_map_path = kwargs.get("output_prune_map_path", None)
self.max_seq_len = kwargs.get("max_seq_len", 128)
self.max_context_len = kwargs.get("max_context_len", 128)
self.args = kwargs.get("args", None)

assert (
self.max_context_len >= self.max_seq_len
), f"max_context_len({self.max_context_len}) must be >= max_seq_len({self.max_seq_len})"

# The example is using a dummy small model with random weights for demo purpose only.
# Follow the instruction in https://github.com/facebookresearch/llama to download the model.
device = "cpu"
@@ -136,6 +141,7 @@ def __init__(self, **kwargs):

model_args: ModelArgs = ModelArgs(
max_seq_len=self.max_seq_len,
max_context_len=self.max_context_len,
max_batch_size=1,
use_kv_cache=self.use_kv_cache,
use_sdpa_with_kv_cache_op=self.use_sdpa_with_kv_cache_op,
@@ -219,7 +225,7 @@ def __init__(self, **kwargs):
window_size = int(attention_sink_params[1])
eviction_batch_size = int(attention_sink_params[2])

assert self.args.max_seq_length == sink_size + window_size
assert self.args.max_context_length == sink_size + window_size

self.model_ = enable_attention_sink(
module=self.model_,
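
A small sketch of the kwargs path this constructor now follows (hypothetical `TinyModelArgs` stands in for `ModelArgs`; only the two fields touched above are modeled): both lengths default to 128, and construction fails early when the context cannot cover the longest input.

```python
from dataclasses import dataclass


@dataclass
class TinyModelArgs:
    # Stand-in for ModelArgs with only the two fields this diff touches.
    max_seq_len: int = 2048
    max_context_len: int = 2048


def make_args(**kwargs) -> TinyModelArgs:
    # Same defaulting and sanity check as the model constructor above.
    max_seq_len = kwargs.get("max_seq_len", 128)
    max_context_len = kwargs.get("max_context_len", 128)
    assert max_context_len >= max_seq_len, (
        f"max_context_len({max_context_len}) must be >= "
        f"max_seq_len({max_seq_len})"
    )
    return TinyModelArgs(max_seq_len=max_seq_len, max_context_len=max_context_len)


args = make_args(max_seq_len=128, max_context_len=2048)
print(args)  # TinyModelArgs(max_seq_len=128, max_context_len=2048)
```
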
16 changes: 8 additions & 8 deletions examples/models/llama/source_transformation/attention.py
@@ -32,15 +32,15 @@ class KVCacheSHA(torch.nn.Module):
def __init__(
self,
max_batch_size: int,
max_seq_length: int,
max_context_length: int,
n_heads: int,
head_dim: int,
dtype=torch.float32,
):
super().__init__()

# a buffer per head
cache_shape = (max_batch_size, max_seq_length, head_dim)
cache_shape = (max_batch_size, max_context_length, head_dim)
for i in range(n_heads):
self.register_buffer(
f"past_k_caches_{i}",
@@ -79,7 +79,7 @@ class SDPASHA(torch.nn.Module):
def __init__(
self,
max_batch_size: int,
max_seq_length: int,
max_context_length: int,
n_heads: int,
n_rep: int,
head_dim: int,
@@ -90,7 +90,7 @@ def __init__(
self.n_rep = n_rep
self.dim = dim
self.kv_cache = KVCacheSHA(
max_batch_size, max_seq_length, n_heads // n_rep, head_dim
max_batch_size, max_context_length, n_heads // n_rep, head_dim
)
self.scale_factor = math.sqrt(head_dim)

@@ -134,11 +134,11 @@ def __init__(self, attention_mha: nn.Module):
self.n_rep = self.n_heads // self.n_kv_heads
self.dim = attention_mha.dim
self.max_batch_size = attention_mha.max_batch_size
self.max_seq_len = attention_mha.max_seq_len
self.max_context_len = attention_mha.max_context_len
self.head_dim = attention_mha.dim // self.n_heads
self.SDPA = SDPASHA(
self.max_batch_size,
self.max_seq_len,
self.max_context_len,
self.n_heads,
self.n_rep,
self.head_dim,
@@ -184,8 +184,8 @@ def __init__(self, attention_mha: nn.Module):

causal_mask = torch.tril(
torch.ones(
self.max_seq_len,
self.max_seq_len,
self.max_context_len,
self.max_context_len,
dtype=torch.bool,
device="cpu",
)
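
To make the shape change concrete, a short comparison of the two cache layouts touched above (values are illustrative; only the shapes mirror the diff): the multi-head `KVCache` keeps one buffer covering all KV heads, the split-head `KVCacheSHA` keeps one buffer per head, and both now size the sequence axis by the context length rather than the input length.

```python
import torch

# Illustrative sizes only; the shapes follow the hunks above.
max_batch_size, max_context_len = 1, 2048
n_kv_heads, head_dim = 8, 64

# Multi-head path (KVCache): one buffer holding every KV head.
mha_cache = torch.zeros(max_batch_size, n_kv_heads, max_context_len, head_dim)

# Split-head path (KVCacheSHA): one buffer per KV head.
sha_caches = [
    torch.zeros(max_batch_size, max_context_len, head_dim)
    for _ in range(n_kv_heads)
]

# Either way the cache capacity is max_context_len, independent of how long
# any single input chunk (max_seq_len) is allowed to be.
assert mha_cache.shape[2] == sha_caches[0].shape[1] == max_context_len
```
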
20 changes: 13 additions & 7 deletions examples/models/llama/source_transformation/attention_sink.py
@@ -44,8 +44,8 @@ def __init__(
self.apply_rotary_emb_to_k = hf_apply_rotary_emb_to_k
else:
self.apply_rotary_emb_to_k = apply_rotary_emb_to_k
self.max_seq_length = window_size + sink_size
assert self.max_seq_length == self.params.max_seq_len
self.max_context_length = window_size + sink_size
assert self.max_context_length == self.params.max_context_len
self.eviction_batch_size = eviction_batch_size
self.position_shift = 0

Expand All @@ -54,11 +54,14 @@ def get_freqs(self, input_pos: Optional[torch.Tensor], seq_len: int):

input_pos_item = input_pos.item()
torch._check_is_size(input_pos_item)
if input_pos_item + self.position_shift + seq_len > self.max_seq_length:
if input_pos_item + self.position_shift + seq_len > self.max_context_length:
# There are not enough spaces in the cache to store the new tokens.
# We need to evict some old tokens and shift some recent tokens.
num_to_evict = max(
input_pos_item + self.position_shift - self.max_seq_length + seq_len,
input_pos_item
+ self.position_shift
- self.max_context_length
+ seq_len,
self.eviction_batch_size,
)
self.position_shift -= num_to_evict # pyre-ignore [8]
@@ -121,7 +124,7 @@ def __init__(
):
super().__init__(
max_batch_size=max_batch_size,
max_seq_length=window_size + sink_size,
max_context_length=window_size + sink_size,
n_heads=n_heads,
head_dim=head_dim,
enable_dynamic_shape=enable_dynamic_shape,
@@ -148,11 +151,14 @@ def evict_tokens(self, input_pos: torch.Tensor, seq_len: int) -> int:
"""
input_pos_item = input_pos.item()
torch._check_is_size(input_pos_item)
if input_pos_item + self.position_shift + seq_len > self.max_seq_length:
if input_pos_item + self.position_shift + seq_len > self.max_context_length:
# There are not enough spaces in the cache to store the new tokens.
# We need to evict some old tokens and shift some recent tokens.
num_to_evict = max(
input_pos_item + self.position_shift - self.max_seq_length + seq_len,
input_pos_item
+ self.position_shift
- self.max_context_length
+ seq_len,
self.eviction_batch_size,
)
num_to_keep = (
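
The attention-sink eviction arithmetic above is easy to sanity-check in isolation. A pure-function sketch (`tokens_to_evict` is hypothetical, mirroring `evict_tokens()`/`get_freqs()`; the window/sink split is illustrative) with a small worked example:

```python
def tokens_to_evict(input_pos, position_shift, seq_len,
                    max_context_length, eviction_batch_size):
    """Same arithmetic as evict_tokens()/get_freqs() above, as a pure function."""
    if input_pos + position_shift + seq_len <= max_context_length:
        return 0  # the new tokens still fit in the cache
    overflow = input_pos + position_shift + seq_len - max_context_length
    return max(overflow, eviction_batch_size)


# With sink_size=4 and window_size=2044 the cache holds 2048 positions, which
# is what the asserts above tie to max_context_len.
print(tokens_to_evict(2040, 0, 16, 2048, 1))  # -> 8 tokens must be evicted
print(tokens_to_evict(100, 0, 16, 2048, 1))   # -> 0, plenty of room left
```
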