Skip to content

Commit 96d5798

Browse files
committed
Fix
1 parent 7a7041d commit 96d5798

File tree

3 files changed

+12
-3
lines changed

3 files changed

+12
-3
lines changed

examples/models/llama/runner/eager.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -34,7 +34,7 @@ def __init__(self, args):
         max_batch_size=1,
         use_kv_cache=args.use_kv_cache,
         vocab_size=params["vocab_size"],
-        has_full_logits=args.model in TORCHTUNE_DEFINED_MODELS
+        has_full_logits=args.model in TORCHTUNE_DEFINED_MODELS,
         device="cuda" if torch.cuda.is_available() else "cpu",
     )
     manager: LLMEdgeManager = _prepare_for_llama_export(args)
```

examples/models/llama/runner/generation.py

Lines changed: 0 additions & 1 deletion

```diff
@@ -73,7 +73,6 @@ def __init__(
         has_full_logits: whether the model returns the full logits or only returns the last logit.
         device: device to run the runner on.
         """
-        self.model_name = model
         self.max_seq_len = max_seq_len
         self.max_batch_size = max_batch_size
         self.use_kv_cache = use_kv_cache
```

examples/models/llama/runner/native.py

Lines changed: 11 additions & 1 deletion

```diff
@@ -10,12 +10,14 @@

 import torch

+from executorch.examples.models.llama.export_llama_lib import EXECUTORCH_DEFINED_MODELS, TORCHTUNE_DEFINED_MODELS
+
 from executorch.extension.pybindings.portable_lib import _load_for_executorch

 # Load custom ops and quantized ops.
 from executorch.extension.pybindings import portable_lib  # noqa # usort: skip

-from executorch.examples.models.llama2.runner.generation import LlamaRunner
+from executorch.examples.models.llama.runner.generation import LlamaRunner

 # Note: import this after portable_lib
 # from executorch.extension.llm.custom_ops import sdpa_with_kv_cache # noqa # usort: skip
@@ -36,6 +38,7 @@ def __init__(self, args):
         max_batch_size=1,
         use_kv_cache=args.kv_cache,
         vocab_size=params["vocab_size"],
+        has_full_logits=args.model in TORCHTUNE_DEFINED_MODELS,
     )
     self.model = _load_for_executorch(args.pte)

@@ -58,8 +61,15 @@ def forward(


 def build_args_parser() -> argparse.ArgumentParser:
+    # TODO: merge these with build_args_parser from export_llama_lib.
     parser = argparse.ArgumentParser()

+    parser.add_argument(
+        "--model",
+        default="llama",
+        choices=EXECUTORCH_DEFINED_MODELS + TORCHTUNE_DEFINED_MODELS,
+    )
+
     parser.add_argument(
         "-f",
         "--pte",
```

0 commit comments

Comments
 (0)