
Commit 386bb05

Fix test_llama
ghstack-source-id: 714a076
Pull Request resolved: #11165
1 parent 3a09118 commit 386bb05

2 files changed: +6, -8 lines


backends/arm/test/models/test_llama.py

Lines changed: 3 additions & 1 deletion

@@ -22,6 +22,7 @@
     TosaPipelineMI,
 )
 
+from executorch.examples.models.llama.config.llm_config_utils import convert_args_to_llm_config
 from executorch.examples.models.llama.export_llama_lib import (
     build_args_parser,
     get_llama_model,
@@ -89,8 +90,9 @@ def prepare_model(self):
         ]
         parser = build_args_parser()
         args = parser.parse_args(args)
+        llm_config = convert_args_to_llm_config(args)
 
-        llama_model, llama_inputs, llama_meta = get_llama_model(args)
+        llama_model, llama_inputs, llama_meta = get_llama_model(llm_config)
 
         return llama_model, llama_inputs, llama_meta

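To make the new flow easier to follow, the sketch below shows the calling convention the test now exercises: the argparse Namespace produced by build_args_parser is converted to an LlmConfig via convert_args_to_llm_config before being handed to get_llama_model. The helper name load_llama_for_test and its cli_flags parameter are illustrative only; they are not part of this commit.

from executorch.examples.models.llama.config.llm_config_utils import (
    convert_args_to_llm_config,
)
from executorch.examples.models.llama.export_llama_lib import (
    build_args_parser,
    get_llama_model,
)

def load_llama_for_test(cli_flags):
    # cli_flags: a list of command-line style strings, assembled by the caller
    # (the Arm test builds its own argument list before parsing).
    parser = build_args_parser()
    args = parser.parse_args(cli_flags)
    # Convert the argparse Namespace into the structured LlmConfig.
    llm_config = convert_args_to_llm_config(args)
    # get_llama_model now expects the LlmConfig, not the raw Namespace.
    return get_llama_model(llm_config)
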
examples/models/llama/export_llama_lib.py

Lines changed: 3 additions & 7 deletions

@@ -805,10 +805,6 @@ def _qmode_type(value):
 
 
 def _validate_args(llm_config):
-    """
-    TODO: Combine all the backends under --backend args
-    """
-
     if llm_config.export.max_context_length < llm_config.export.max_seq_length:
         raise ValueError(
             f"max_context_length {llm_config.export.max_context_length} must be >= max_seq_len {llm_config.export.max_seq_length}. max_context_length impacts kv cache size that is used to remember history, while max_seq_length refers to user prompt length. Please use --max_context_length to specify context length."
@@ -1498,9 +1494,9 @@ def _get_source_transforms( # noqa
     return transforms
 
 
-def get_llama_model(args):
-    _validate_args(args)
-    e_mgr = _prepare_for_llama_export(args)
+def get_llama_model(llm_config: LlmConfig):
+    _validate_args(llm_config)
+    e_mgr = _prepare_for_llama_export(llm_config)
     model = (
         e_mgr.model.eval().to(device="cuda")
         if torch.cuda.is_available()

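Design note: after this change, get_llama_model no longer accepts the raw argparse Namespace. Callers such as the Arm test above are expected to convert the parsed arguments with convert_args_to_llm_config first, so that _validate_args can read the typed config fields shown in the hunk above (llm_config.export.max_context_length and llm_config.export.max_seq_length) and _prepare_for_llama_export receives the same LlmConfig object.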