File tree Expand file tree Collapse file tree 3 files changed +12
-3
lines changed
examples/models/llama/runner Expand file tree Collapse file tree 3 files changed +12
-3
lines changed Original file line number Diff line number Diff line change @@ -34,7 +34,7 @@ def __init__(self, args):
34
34
max_batch_size = 1 ,
35
35
use_kv_cache = args .use_kv_cache ,
36
36
vocab_size = params ["vocab_size" ],
37
- has_full_logits = args .model in TORCHTUNE_DEFINED_MODELS
37
+ has_full_logits = args .model in TORCHTUNE_DEFINED_MODELS ,
38
38
device = "cuda" if torch .cuda .is_available () else "cpu" ,
39
39
)
40
40
manager : LLMEdgeManager = _prepare_for_llama_export (args )
Original file line number Diff line number Diff line change @@ -73,7 +73,6 @@ def __init__(
73
73
has_full_logits: whether the model returns the full logits or only returns the last logit.
74
74
device: device to run the runner on.
75
75
"""
76
- self .model_name = model
77
76
self .max_seq_len = max_seq_len
78
77
self .max_batch_size = max_batch_size
79
78
self .use_kv_cache = use_kv_cache
Original file line number Diff line number Diff line change 10
10
11
11
import torch
12
12
13
+ from executorch .examples .models .llama .export_llama_lib import EXECUTORCH_DEFINED_MODELS , TORCHTUNE_DEFINED_MODELS
14
+
13
15
from executorch .extension .pybindings .portable_lib import _load_for_executorch
14
16
15
17
# Load custom ops and quantized ops.
16
18
from executorch .extension .pybindings import portable_lib # noqa # usort: skip
17
19
18
- from executorch .examples .models .llama2 .runner .generation import LlamaRunner
20
+ from executorch .examples .models .llama .runner .generation import LlamaRunner
19
21
20
22
# Note: import this after portable_lib
21
23
# from executorch.extension.llm.custom_ops import sdpa_with_kv_cache # noqa # usort: skip
@@ -36,6 +38,7 @@ def __init__(self, args):
36
38
max_batch_size = 1 ,
37
39
use_kv_cache = args .kv_cache ,
38
40
vocab_size = params ["vocab_size" ],
41
+ has_full_logits = args .model in TORCHTUNE_DEFINED_MODELS ,
39
42
)
40
43
self .model = _load_for_executorch (args .pte )
41
44
@@ -58,8 +61,15 @@ def forward(
58
61
59
62
60
63
def build_args_parser () -> argparse .ArgumentParser :
64
+ # TODO: merge these with build_args_parser from export_llama_lib.
61
65
parser = argparse .ArgumentParser ()
62
66
67
+ parser .add_argument (
68
+ "--model" ,
69
+ default = "llama" ,
70
+ choices = EXECUTORCH_DEFINED_MODELS + TORCHTUNE_DEFINED_MODELS ,
71
+ )
72
+
63
73
parser .add_argument (
64
74
"-f" ,
65
75
"--pte" ,
You can’t perform that action at this time.
0 commit comments