Skip to content

Commit 993c69d

Browse files
Jack-Khuu authored and facebook-github-bot committed
Plumb max_seq_len into llama model (#2566)
Summary: Currently the field is hard coded; we want to control this field. Differential Revision: D55204688
1 parent 12b5324 commit 993c69d

File tree

4 files changed

+13
-7
lines changed

4 files changed

+13
-7
lines changed

examples/models/llama2/builder.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ def load_llama_model(
6868
use_sdpa_with_kv_cache: bool = False,
6969
weight_type: WeightType = WeightType.LLAMA,
7070
verbose: bool = False,
71+
max_seq_len: int = 128,
7172
) -> "LlamaEdgeManager":
7273
"""
7374
A helper util that builds a Llama2 model. It returns a LlamaEdgeManager that
@@ -87,6 +88,7 @@ def load_llama_model(
8788
use_kv_cache=use_kv_cache,
8889
use_sdpa_with_kv_cache=use_sdpa_with_kv_cache,
8990
fairseq2=weight_type == WeightType.FAIRSEQ2,
91+
max_seq_len=max_seq_len,
9092
)
9193
state_dict = model.state_dict()
9294
dtype = state_dict[next(iter(state_dict))].dtype

examples/models/llama2/eval_llama_lib.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -140,12 +140,6 @@ def build_args_parser() -> argparse.ArgumentParser:
140140
parser.add_argument(
141141
"--limit", type=int, default=5, help="number of samples to evalulate"
142142
)
143-
parser.add_argument(
144-
"--max_seq_length",
145-
type=int,
146-
default=100,
147-
help="maximum length sequence to evaluate",
148-
)
149143

150144
return parser
151145

examples/models/llama2/export_llama_lib.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -391,6 +391,13 @@ def build_args_parser() -> argparse.ArgumentParser:
391391
help="Override the output filename of the saved pte model file.",
392392
)
393393

394+
parser.add_argument(
395+
"--max_seq_length",
396+
type=int,
397+
default=128,
398+
help="maximum length sequence to evaluate",
399+
)
400+
394401
parser.add_argument("-2", "--fairseq2", action="store_true")
395402
parser.add_argument("-v", "--verbose", action="store_true")
396403
parser.add_argument("-X", "--xnnpack", action="store_true")
@@ -511,6 +518,7 @@ def _prepare_for_llama_export(modelname: str, args) -> LlamaEdgeManager:
511518
use_sdpa_with_kv_cache=args.use_sdpa_with_kv_cache,
512519
weight_type=weight_type,
513520
verbose=args.verbose,
521+
max_seq_len=args.max_seq_length,
514522
)
515523
.set_output_dir(output_dir_path)
516524
.set_metadata(args.metadata)

examples/models/llama2/model.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,8 @@ def __init__(self, **kwargs):
6666
if "use_sdpa_with_kv_cache" in kwargs
6767
else False
6868
)
69+
70+
self.max_seq_len = kwargs["max_seq_len"] if "max_seq_len" in kwargs else 128
6971
# The example is using a dummy small model with random weights for demo purpose only.
7072
# Follow the instruction in https://github.com/facebookresearch/llama to download the model
7173
device = "cpu"
@@ -112,7 +114,7 @@ def __init__(self, **kwargs):
112114
)
113115
with open(params_path, "r") as f:
114116
params = json.loads(f.read())
115-
max_seq_len = 128
117+
max_seq_len = self.max_seq_len
116118
max_batch_size = 1
117119
model_args: ModelArgs = ModelArgs(
118120
max_seq_len=max_seq_len,

0 commit comments

Comments (0)