
Commit b93967a

[Executorch][Llama] Decouple input sequence length from kv cache context length
Pull Request resolved: #7927

Decouple the maximum sequence length, used for shape dynamism in torch.export, from the sequence length used for KV cache sizing.

ghstack-source-id: 263653316
Differential Revision: [D68448334](https://our.internmc.facebook.com/intern/diff/D68448334/)
1 parent bdd3d9c commit b93967a

13 files changed: +105 −66 lines
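To make the decoupling concrete before the per-file diffs, here is a minimal, hypothetical Python sketch (the ToyKVCache name and the concrete numbers are illustrative, not taken from this commit): the KV cache buffer is sized by the context length the model can remember, while the sequence length only bounds how many new tokens one forward pass may consume and therefore how torch.export constrains the input shape.

import torch

class ToyKVCache(torch.nn.Module):
    """Illustrative only: cache capacity follows max_context_len, not max_seq_len."""

    def __init__(self, max_batch_size: int, n_heads: int, max_context_len: int, head_dim: int):
        super().__init__()
        # The history the model can attend to is bounded by the cache, i.e. the context length.
        self.register_buffer(
            "k_cache",
            torch.zeros(max_batch_size, n_heads, max_context_len, head_dim),
        )

max_seq_len = 128        # upper bound on tokens per forward() call (export-time shape dynamism)
max_context_len = 2048   # KV cache capacity; must be >= max_seq_len
assert max_context_len >= max_seq_len
cache = ToyKVCache(max_batch_size=1, n_heads=8, max_context_len=max_context_len, head_dim=64)
print(cache.k_cache.shape)  # torch.Size([1, 8, 2048, 64])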

.ci/scripts/test_eval_llama_mmlu.sh

Lines changed: 1 addition & 0 deletions
@@ -43,6 +43,7 @@ run_and_verify() {
     --tasks mmlu \
     -f 5 \
     --max_seq_length 2048 \
+    --max_context_length 2048 \
     --limit 5 > result.txt

   # Verify result.txt

.ci/scripts/test_eval_llama_wikitext.sh

Lines changed: 1 addition & 0 deletions
@@ -41,6 +41,7 @@ run_and_verify() {
     -kv \
     -d fp32 \
     --max_seq_length 2048 \
+    --max_context_length 2048 \
     --limit 5 > result.txt

   # Verify result.txt

examples/models/llama/export_llama_lib.py

Lines changed: 25 additions & 4 deletions
@@ -335,6 +335,13 @@ def build_args_parser() -> argparse.ArgumentParser:
         help="maximum length sequence to evaluate",
     )

+    parser.add_argument(
+        "--max_context_length",
+        type=int,
+        default=128,
+        help="maximum length of context for model to remember",
+    )
+
     parser.add_argument("-2", "--fairseq2", action="store_true")
     parser.add_argument("-v", "--verbose", action="store_true")
     parser.add_argument(

@@ -579,6 +586,7 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
         tokenizer_path=args.tokenizer_path,
         verbose=args.verbose,
         max_seq_len=args.max_seq_length,
+        max_context_len=args.max_context_length,
         input_prune_map_path=args.input_prune_map,
         output_prune_map_path=args.output_prune_map,
         metadata_str=args.metadata,

@@ -637,6 +645,11 @@ def _validate_args(args):
     """
     TODO: Combine all the backends under --backend args
     """
+
+    if args.max_context_length < args.max_seq_length:
+        raise ValueError(
+            f"max_context_length {args.max_context_length} must be >= max_seq_len {args.max_seq_length}. max_context_length impacts kv cache size that is used to remember history, while max_seq_length refers to user prompt length. Please use --max_context_length to specify context length."
+        )
     if args.enable_dynamic_shape and (args.coreml or args.mps or args.qnn):
         raise ValueError(
             "Dynamic shape is not supported with coreml, MPS or qnn backends."

@@ -662,6 +675,7 @@ def _validate_args(args):

 def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
     _validate_args(args)
+
     pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args)

     # export_to_edge

@@ -760,13 +774,13 @@ def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
         atten = builder_exported_to_edge.model.layers[0].attention
         if args.use_qnn_sha:
             cache_shape = torch.Size(
-                (atten.max_batch_size, atten.max_seq_len, atten.head_dim)
+                (atten.max_batch_size, atten.max_context_len, atten.head_dim)
             )
         else:
             cache_shape = torch.Size(
                 (
                     atten.max_batch_size,
-                    atten.max_seq_len,
+                    atten.max_context_len,
                     atten.n_kv_heads,
                     atten.head_dim,
                 )

@@ -861,6 +875,7 @@ def _load_llama_model_metadata(
     use_sdpa_with_kv_cache: bool,
     enable_dynamic_shape: bool,
     max_seq_len: int,
+    max_context_len: int,
     n_layers: int,
     vocab_size: int,
     metadata_str: Optional[str] = None,

@@ -870,6 +885,7 @@ def _load_llama_model_metadata(
         "get_bos_id": 3 if is_fairseq2 else 1,
         "get_eos_ids": [3] if is_fairseq2 else [2],
         "get_max_seq_len": max_seq_len,
+        "get_max_context_len": max_context_len,
         "get_n_layers": n_layers,
         "get_vocab_size": vocab_size,
         "use_kv_cache": use_kv_cache,

@@ -904,6 +920,7 @@ def _load_llama_model(
     tokenizer_path: Optional[str] = None,
     verbose: bool = False,
     max_seq_len: int = 128,
+    max_context_len: int = 128,
     input_prune_map_path: Optional[str] = None,
     output_prune_map_path: Optional[str] = None,
     metadata_str: Optional[str] = None,

@@ -948,6 +965,7 @@ def _load_llama_model(
         generate_full_logits=generate_full_logits,
         fairseq2=weight_type == WeightType.FAIRSEQ2,
         max_seq_len=max_seq_len,
+        max_context_len=max_context_len,
         enable_dynamic_shape=enable_dynamic_shape,
         input_prune_map_path=input_prune_map_path,
         output_prune_map_path=output_prune_map_path,

@@ -1006,10 +1024,13 @@ def _load_llama_model(
             # pyre-fixme[6]: For 5th argument expected `ModelArgs` but got
             # `Union[Tensor, Module]`.
             model.max_seq_len,
-            # pyre-fixme[6]: For 6th argument expected `int` but got `Union[Tensor,
+            # pyre-fixme[6]: For 6th argument expected `ModelArgs` but got
+            # `Union[Tensor, Module]`.
+            model.max_context_len,
+            # pyre-fixme[6]: For 7th argument expected `int` but got `Union[Tensor,
             # Module]`.
             model.n_layers,
-            # pyre-fixme[6]: For 7th argument expected `int` but got `Union[Tensor,
+            # pyre-fixme[6]: For 8th argument expected `int` but got `Union[Tensor,
             # Module]`.
             model.vocab_size,
             metadata_str,
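A compact sketch of how the new exporter argument flows, assuming only what the diff above shows (the surrounding code is simplified and hypothetical): the CLI value is validated against --max_seq_length and then surfaces in the program metadata under the get_max_context_len key.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--max_seq_length", type=int, default=128)
parser.add_argument("--max_context_length", type=int, default=128)
args = parser.parse_args(["--max_seq_length", "512", "--max_context_length", "2048"])

# Same rule the diff adds to _validate_args(): the cache must cover at least one full prompt.
if args.max_context_length < args.max_seq_length:
    raise ValueError(
        f"max_context_length {args.max_context_length} must be >= max_seq_len {args.max_seq_length}"
    )

# Runner-facing metadata now carries both values (mirroring _load_llama_model_metadata).
metadata = {
    "get_max_seq_len": args.max_seq_length,
    "get_max_context_len": args.max_context_length,
}
print(metadata)  # {'get_max_seq_len': 512, 'get_max_context_len': 2048}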

examples/models/llama/llama_transformer.py

Lines changed: 17 additions & 15 deletions
@@ -91,6 +91,7 @@ class ModelArgs:
     norm_eps: float = 1e-5
     max_batch_size: int = 32
     max_seq_len: int = 2048
+    max_context_len: int = 2048
     moe: bool = False  # True to enable the MoE (Mixture of Experts)
     num_experts: int = 8  # Number of experts
     num_activated_experts: int = 2  # Number of experts to activate

@@ -163,9 +164,9 @@ def __init__(self, params: ModelArgs):
         freqs_cos, freqs_sin = self.precompute_freqs_cis(
             self.params.head_dim,
             (
-                self.params.max_seq_len  # Normal llama2.
+                self.params.max_context_len  # Normal llama2.
                 if self.params.ffn_dim_multiplier is None
-                else self.params.max_seq_len * 2  # Sharded checkpoint.
+                else self.params.max_context_len * 2  # Sharded checkpoint.
             ),
             self.params.rope_freq_base,
         )

@@ -205,7 +206,7 @@ def get_freqs(self, input_pos: Optional[torch.Tensor], seq_len: int):
             # when KV cache is used, seqlen is most likely 1. We want to slice from the start_pos.
             input_pos_item = input_pos[-1].item()
             torch._check_is_size(input_pos_item)
-            torch._check(input_pos_item < self.params.max_seq_len)
+            torch._check(input_pos_item < self.params.max_context_len)
             # pyre-ignore: Incompatible parameter type [6]: torch.narrow does expect int or Tensor
             freqs_cos = self.freqs_cos.narrow(0, input_pos_item, seq_len)
             # pyre-ignore: Incompatible parameter type [6]

@@ -229,15 +230,15 @@ class KVCache(nn.Module):
     def __init__(
         self,
         max_batch_size: int,
-        max_seq_length: int,
+        max_context_length: int,
         n_heads: int,
         head_dim: int,
         enable_dynamic_shape: bool,
         dtype=torch.float32,
     ):
         super().__init__()
-        self.max_seq_length = max_seq_length
-        cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
+        self.max_context_length = max_context_length
+        cache_shape = (max_batch_size, n_heads, max_context_length, head_dim)

         self.max_batch_size = max_batch_size
         self.n_heads = n_heads

@@ -257,7 +258,7 @@ def update(
         if self.enable_dynamic_shape:
             start_pos = input_pos[0].item()
             torch._check_is_size(start_pos)
-            torch._check(start_pos < self.max_seq_length)
+            torch._check(start_pos < self.max_context_length)
             dim_to_slice = 2
             seq_length = k_val.size(dim_to_slice)
             # Replace the entry in the cache for this token

@@ -289,14 +290,14 @@ def __init__(
         dim: int,
         head_dim: int,
         n_rep: int,
-        max_seq_len: int,
+        max_context_len: int,
         enable_dynamic_shape: bool,
     ):
         super().__init__()
         self.dim = dim
         self.head_dim = head_dim
         self.n_rep = n_rep
-        self.max_seq_len = max_seq_len
+        self.max_context_len = max_context_len
         self.enable_dynamic_shape = enable_dynamic_shape

     def forward(

@@ -312,7 +313,7 @@ def forward(
         if self.enable_dynamic_shape:
             start_pos = input_pos[-1].item()
             torch._check_is_size(start_pos)
-            torch._check(start_pos < self.max_seq_len)
+            torch._check(start_pos < self.max_context_len)
             seq_length = q.size(2)
             # pyre-ignore: Incompatible parameter type [6]
             attn_mask = mask.narrow(0, start_pos, seq_length)

@@ -341,7 +342,7 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
         self.n_rep = self.n_local_heads // self.n_local_kv_heads
         self.head_dim = args.head_dim
         self.max_batch_size = args.max_batch_size
-        self.max_seq_len = args.max_seq_len
+        self.max_context_len = args.max_context_len
         self.dim = args.dim
         self.wq = nn.Linear(self.dim, self.n_heads * self.head_dim, bias=False)
         self.wk = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False)

@@ -354,8 +355,8 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):

         causal_mask = torch.tril(
             torch.ones(
-                self.max_seq_len,
-                self.max_seq_len,
+                self.max_context_len,
+                self.max_context_len,
                 dtype=torch.bool,
                 device="cpu",
             )

@@ -365,7 +366,7 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
         if self.use_kv_cache:
             self.kv_cache = KVCache(
                 args.max_batch_size,
-                args.max_seq_len,
+                args.max_context_len,
                 self.n_kv_heads,
                 self.head_dim,
                 args.enable_dynamic_shape,

@@ -374,7 +375,7 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
                 dim=self.n_local_heads * self.head_dim,
                 head_dim=self.head_dim,
                 n_rep=self.n_rep,
-                max_seq_len=self.max_seq_len,
+                max_context_len=self.max_context_len,
                 enable_dynamic_shape=args.enable_dynamic_shape,
             )

@@ -528,6 +529,7 @@ def __init__(self, params: ModelArgs):
         self.use_kv_cache = params.use_kv_cache
         self.generate_full_logits = params.generate_full_logits
         self.max_seq_len = params.max_seq_len
+        self.max_context_len = params.max_context_len
         self.input_prune_map = params.input_prune_map
         self.output_prune_map = params.output_prune_map
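A reduced sketch of the renamed KVCache, keeping only the buffer layout and the torch._check bound shown in the diff (everything else is trimmed or assumed): both the cache dimension and the export-time bound on start_pos now come from the context length.

import torch
from torch import nn

class SketchKVCache(nn.Module):
    def __init__(self, max_batch_size, max_context_length, n_heads, head_dim):
        super().__init__()
        self.max_context_length = max_context_length
        # (bsz, n_heads, context, head_dim): dim 2 is the cache capacity.
        cache_shape = (max_batch_size, n_heads, max_context_length, head_dim)
        self.register_buffer("k_cache", torch.zeros(cache_shape))

    def update(self, input_pos: torch.Tensor, k_val: torch.Tensor) -> torch.Tensor:
        start_pos = input_pos[0].item()
        torch._check_is_size(start_pos)
        # The bound given to the exporter is the context length, not the prompt length.
        torch._check(start_pos < self.max_context_length)
        seq_length = k_val.size(2)
        self.k_cache.narrow(2, start_pos, seq_length).copy_(k_val)
        return self.k_cache

cache = SketchKVCache(1, 2048, 8, 64)
new_k = torch.randn(1, 8, 16, 64)          # 16 new tokens
cache.update(torch.tensor([100]), new_k)   # written at cache positions 100..115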

examples/models/llama/model.py

Lines changed: 7 additions & 1 deletion
@@ -52,8 +52,13 @@ def __init__(self, **kwargs):
         self.input_prune_map_path = kwargs.get("input_prune_map_path", None)
         self.output_prune_map_path = kwargs.get("output_prune_map_path", None)
         self.max_seq_len = kwargs.get("max_seq_len", 128)
+        self.max_context_len = kwargs.get("max_context_len", 128)
         self.args = kwargs.get("args", None)

+        assert (
+            self.max_context_len >= self.max_seq_len
+        ), f"max_context_len({self.max_context_len}) must be >= max_seq_len({self.max_seq_len})"
+
         # The example is using a dummy small model with random weights for demo purpose only.
         # Follow the instruction in https://github.com/facebookresearch/llama to download the model.
         device = "cpu"

@@ -136,6 +141,7 @@ def __init__(self, **kwargs):

         model_args: ModelArgs = ModelArgs(
             max_seq_len=self.max_seq_len,
+            max_context_len=self.max_context_len,
             max_batch_size=1,
             use_kv_cache=self.use_kv_cache,
             use_sdpa_with_kv_cache_op=self.use_sdpa_with_kv_cache_op,

@@ -219,7 +225,7 @@ def __init__(self, **kwargs):
             window_size = int(attention_sink_params[1])
             eviction_batch_size = int(attention_sink_params[2])

-            assert self.args.max_seq_length == sink_size + window_size
+            assert self.args.max_context_length == sink_size + window_size

             self.model_ = enable_attention_sink(
                 module=self.model_,
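When attention sink is enabled, the updated assert ties the context length to the sink and window sizes. A tiny worked example with hypothetical numbers (the comma-separated sink/window/eviction format is inferred from the indexing in the diff):

# Hypothetical parameter string: sink_size, window_size, eviction_batch_size.
attention_sink_params = "4,2044,1024".split(",")
sink_size = int(attention_sink_params[0])
window_size = int(attention_sink_params[1])
eviction_batch_size = int(attention_sink_params[2])

max_context_length = 2048
# Mirrors the updated assert: the KV cache must hold exactly sink + window tokens.
assert max_context_length == sink_size + window_size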

examples/models/llama/source_transformation/attention.py

Lines changed: 8 additions & 8 deletions
@@ -32,15 +32,15 @@ class KVCacheSHA(torch.nn.Module):
     def __init__(
         self,
         max_batch_size: int,
-        max_seq_length: int,
+        max_context_length: int,
         n_heads: int,
         head_dim: int,
         dtype=torch.float32,
     ):
         super().__init__()

         # a buffer per head
-        cache_shape = (max_batch_size, max_seq_length, head_dim)
+        cache_shape = (max_batch_size, max_context_length, head_dim)
         for i in range(n_heads):
             self.register_buffer(
                 f"past_k_caches_{i}",

@@ -79,7 +79,7 @@ class SDPASHA(torch.nn.Module):
     def __init__(
         self,
         max_batch_size: int,
-        max_seq_length: int,
+        max_context_length: int,
         n_heads: int,
         n_rep: int,
         head_dim: int,

@@ -90,7 +90,7 @@ def __init__(
         self.n_rep = n_rep
         self.dim = dim
         self.kv_cache = KVCacheSHA(
-            max_batch_size, max_seq_length, n_heads // n_rep, head_dim
+            max_batch_size, max_context_length, n_heads // n_rep, head_dim
         )
         self.scale_factor = math.sqrt(head_dim)

@@ -134,11 +134,11 @@ def __init__(self, attention_mha: nn.Module):
         self.n_rep = self.n_heads // self.n_kv_heads
         self.dim = attention_mha.dim
         self.max_batch_size = attention_mha.max_batch_size
-        self.max_seq_len = attention_mha.max_seq_len
+        self.max_context_len = attention_mha.max_context_len
         self.head_dim = attention_mha.dim // self.n_heads
         self.SDPA = SDPASHA(
             self.max_batch_size,
-            self.max_seq_len,
+            self.max_context_len,
             self.n_heads,
             self.n_rep,
             self.head_dim,

@@ -184,8 +184,8 @@ def __init__(self, attention_mha: nn.Module):

         causal_mask = torch.tril(
             torch.ones(
-                self.max_seq_len,
-                self.max_seq_len,
+                self.max_context_len,
+                self.max_context_len,
                 dtype=torch.bool,
                 device="cpu",
             )
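In the single-head-attention (SHA) transform the cache is one buffer per KV head, so the rename changes the second dimension of every per-head buffer. A trimmed sketch assuming the register_buffer pattern in the diff (the value-cache side is assumed symmetric and not shown above):

import torch

class SketchKVCacheSHA(torch.nn.Module):
    def __init__(self, max_batch_size, max_context_length, n_kv_heads, head_dim):
        super().__init__()
        # One (bsz, context, head_dim) buffer per KV head.
        cache_shape = (max_batch_size, max_context_length, head_dim)
        for i in range(n_kv_heads):
            self.register_buffer(f"past_k_caches_{i}", torch.zeros(cache_shape))
            self.register_buffer(f"past_v_caches_{i}", torch.zeros(cache_shape))  # assumed symmetric

sha_cache = SketchKVCacheSHA(max_batch_size=1, max_context_length=2048, n_kv_heads=4, head_dim=64)
print(sha_cache.past_k_caches_0.shape)  # torch.Size([1, 2048, 64])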

examples/models/llama/source_transformation/attention_sink.py

Lines changed: 13 additions & 7 deletions
@@ -44,8 +44,8 @@ def __init__(
             self.apply_rotary_emb_to_k = hf_apply_rotary_emb_to_k
         else:
             self.apply_rotary_emb_to_k = apply_rotary_emb_to_k
-        self.max_seq_length = window_size + sink_size
-        assert self.max_seq_length == self.params.max_seq_len
+        self.max_context_length = window_size + sink_size
+        assert self.max_context_length == self.params.max_context_len
         self.eviction_batch_size = eviction_batch_size
         self.position_shift = 0

@@ -54,11 +54,14 @@ def get_freqs(self, input_pos: Optional[torch.Tensor], seq_len: int):

         input_pos_item = input_pos.item()
         torch._check_is_size(input_pos_item)
-        if input_pos_item + self.position_shift + seq_len > self.max_seq_length:
+        if input_pos_item + self.position_shift + seq_len > self.max_context_length:
             # There are not enough spaces in the cache to store the new tokens.
             # We need to evict some old tokens and shift some recent tokens.
             num_to_evict = max(
-                input_pos_item + self.position_shift - self.max_seq_length + seq_len,
+                input_pos_item
+                + self.position_shift
+                - self.max_context_length
+                + seq_len,
                 self.eviction_batch_size,
             )
             self.position_shift -= num_to_evict  # pyre-ignore [8]

@@ -121,7 +124,7 @@ def __init__(
     ):
         super().__init__(
             max_batch_size=max_batch_size,
-            max_seq_length=window_size + sink_size,
+            max_context_length=window_size + sink_size,
             n_heads=n_heads,
             head_dim=head_dim,
             enable_dynamic_shape=enable_dynamic_shape,

@@ -148,11 +151,14 @@ def evict_tokens(self, input_pos: torch.Tensor, seq_len: int) -> int:
         """
         input_pos_item = input_pos.item()
         torch._check_is_size(input_pos_item)
-        if input_pos_item + self.position_shift + seq_len > self.max_seq_length:
+        if input_pos_item + self.position_shift + seq_len > self.max_context_length:
             # There are not enough spaces in the cache to store the new tokens.
             # We need to evict some old tokens and shift some recent tokens.
             num_to_evict = max(
-                input_pos_item + self.position_shift - self.max_seq_length + seq_len,
+                input_pos_item
+                + self.position_shift
+                - self.max_context_length
+                + seq_len,
                 self.eviction_batch_size,
             )
             num_to_keep = (
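The reflowed eviction arithmetic is unchanged in meaning. A small worked example with hypothetical numbers shows how many tokens get evicted when the cache would overflow:

# Hypothetical numbers; mirrors the eviction formula in attention_sink.py.
max_context_length = 2048   # sink_size + window_size
eviction_batch_size = 512
position_shift = 0

input_pos_item = 2040       # next write position in the rolling cache
seq_len = 16                # incoming tokens

if input_pos_item + position_shift + seq_len > max_context_length:
    # Overflow is only 8 tokens, but eviction happens in batches of at least 512.
    num_to_evict = max(
        input_pos_item + position_shift - max_context_length + seq_len,
        eviction_batch_size,
    )
    position_shift -= num_to_evict

print(num_to_evict, position_shift)  # 512 -512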
