
Commit 1070600

Update on "[ExecuTorch][Llama] Decouple input sequence length from kv cache context length"
Decouple the max sequence length, used for shape dynamism in torch.export, from the sequence length used to size the kv cache.

Differential Revision: [D68448334](https://our.internmc.facebook.com/intern/diff/D68448334/)

cc mergennachin cccclai helunwencser dvorjackz

[ghstack-poisoned]
2 parents 80c1d03 + 8a24dfb
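For context, the split this commit lands can be pictured with a small torch.export sketch. Everything below is illustrative: `TinyAttention`, the tensor shapes, and the dim bounds are hypothetical stand-ins, not the actual ExecuTorch Llama export code.

```python
import torch

# Hypothetical bounds mirroring the two CLI flags:
max_seq_length = 128       # bounds a single input prompt (export-time dynamism)
max_context_length = 2048  # sizes the kv cache that remembers history

class TinyAttention(torch.nn.Module):
    def __init__(self, n_heads: int, head_dim: int, context_len: int):
        super().__init__()
        # The kv cache buffers are sized by context_len, independent of
        # how long any single exported input sequence may be.
        self.register_buffer("k_cache", torch.zeros(1, n_heads, context_len, head_dim))
        self.register_buffer("v_cache", torch.zeros(1, n_heads, context_len, head_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x  # attention body elided; only the shapes matter here

model = TinyAttention(n_heads=8, head_dim=64, context_len=max_context_length)
example = torch.randn(1, 16, 512)  # (batch, seq, hidden)

# The dynamic sequence dim of the export is bounded by max_seq_length,
# decoupled from the cache's context_len above.
seq = torch.export.Dim("seq", min=2, max=max_seq_length)
ep = torch.export.export(model, (example,), dynamic_shapes={"x": {1: seq}})
```

Before this change the two lengths were a single knob, so exporting with a large kv cache also forced a large dynamic-shape bound; after it, a model can keep a long history (large `--max_context_length`) while accepting bounded prompts (smaller `--max_seq_length`).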

3 files changed: 7 additions & 0 deletions

.ci/scripts/test_eval_llama_mmlu.sh

Lines changed: 1 addition & 0 deletions
@@ -43,6 +43,7 @@ run_and_verify() {
     --tasks mmlu \
     -f 5 \
     --max_seq_length 2048 \
+    --max_context_length 2048 \
     --limit 5 > result.txt
 
   # Verify result.txt

.ci/scripts/test_eval_llama_wikitext.sh

Lines changed: 1 addition & 0 deletions
@@ -41,6 +41,7 @@ run_and_verify() {
     -kv \
     -d fp32 \
     --max_seq_length 2048 \
+    --max_context_length 2048 \
     --limit 5 > result.txt
 
   # Verify result.txt

examples/models/llama/export_llama_lib.py

Lines changed: 5 additions & 0 deletions
@@ -645,6 +645,11 @@ def _validate_args(args):
     """
     TODO: Combine all the backends under --backend args
     """
+
+    if args.max_context_length < args.max_seq_length:
+        raise ValueError(
+            f"max_context_length {args.max_context_length} must be >= max_seq_length {args.max_seq_length}. max_context_length sizes the kv cache used to remember history, while max_seq_length bounds the user prompt length. Use --max_context_length to set the context length."
+        )
     if args.enable_dynamic_shape and (args.coreml or args.mps or args.qnn):
         raise ValueError(
             "Dynamic shape is not supported with coreml, MPS or qnn backends."
