
Commit 2ae1870

Update on "[ExecuTorch][Llama] Decouple input sequence length from kv cache context length"
Decouple the max sequence length used for shape dynamism in torch.export from the sequence length used for KV cache sizing.

Differential Revision: [D68448334](https://our.internmc.facebook.com/intern/diff/D68448334/)

cc mergennachin cccclai helunwencser dvorjackz

[ghstack-poisoned]
2 parents: 121238b + 685d256
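For intuition, here is a minimal sketch of the decoupling, assuming a toy model: the KV cache buffer is allocated with `max_context_length` slots, while the dynamic-shape bound handed to `torch.export` comes from `max_seq_length`. Only the two length names mirror the flags touched by this diff; `ToyKVModel` and the rest of the wiring are hypothetical scaffolding, not ExecuTorch's actual attention or export code.

```python
import torch
from torch.export import Dim, export


class ToyKVModel(torch.nn.Module):
    """Toy stand-in for a model that pre-allocates a KV cache."""

    def __init__(self, dim: int, max_context_length: int):
        super().__init__()
        # Cache sizing uses the context length: the window the model remembers.
        self.register_buffer("k_cache", torch.zeros(max_context_length, dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        seq_len = x.shape[0]
        # Copy the incoming chunk into the first seq_len cache slots
        # (cloned here to keep the traced graph free of buffer mutation).
        cache = self.k_cache.clone()
        cache[:seq_len] = x
        return cache[:seq_len]


max_seq_length = 32       # bound on the per-call input chunk (shape dynamism)
max_context_length = 128  # KV cache window; matches the flag's new default

model = ToyKVModel(dim=64, max_context_length=max_context_length)
example_input = torch.zeros(8, 64)

# The export-time bound comes from max_seq_length, not from the cache size,
# so the two lengths are free to differ.
seq = Dim("seq", max=max_seq_length)
exported = export(model, (example_input,), dynamic_shapes={"x": {0: seq}})
```

With these bounds, export accepts any input chunk up to 32 tokens per call, while the cache still covers a 128-token context window.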

File tree

1 file changed: +2 −3


examples/models/llama/export_llama_lib.py

Lines changed: 2 additions & 3 deletions
```diff
@@ -338,7 +338,7 @@ def build_args_parser() -> argparse.ArgumentParser:
     parser.add_argument(
         "--max_context_length",
         type=int,
-        default=None,
+        default=128,
         help="maximum length of context for model to remember",
     )
 
@@ -645,8 +645,6 @@ def _validate_args(args):
     """
     TODO: Combine all the backends under --backend args
     """
-    if args.max_context_length is None:
-        args.max_context_length = args.max_seq_length
     if args.enable_dynamic_shape and (args.coreml or args.mps or args.qnn):
         raise ValueError(
             "Dynamic shape is not supported with coreml, MPS or qnn backends."
@@ -672,6 +670,7 @@ def _validate_args(args):
 
 def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
     _validate_args(args)
+
     pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args)
 
     # export_to_edge
```
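Net effect of the two hunks above: `--max_context_length` now carries an explicit default of 128, and `_validate_args` no longer backfills it from `--max_seq_length`. A minimal stand-alone reproduction of the new flag behavior (the parser and the `--max_seq_length` default here are assumptions for illustration, not copied from `build_args_parser`):

```python
import argparse

parser = argparse.ArgumentParser()
# The --max_seq_length default of 128 is assumed here for illustration only.
parser.add_argument("--max_seq_length", type=int, default=128)
parser.add_argument(
    "--max_context_length",
    type=int,
    default=128,  # explicit default after this commit; no more None backfill
    help="maximum length of context for model to remember",
)

args = parser.parse_args(["--max_seq_length", "64"])
# Shrinking max_seq_length no longer shrinks the context window:
assert args.max_seq_length == 64
assert args.max_context_length == 128
```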
