Commit 1e9e5d0

update generation.py to run in eager mode as well
Differential Revision: D61226855
Pull Request resolved: #4702
1 parent 6982c03 commit 1e9e5d0

3 files changed: +286 -216 lines changed
Lines changed: 115 additions & 0 deletions
@@ -0,0 +1,115 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import json
from typing import Optional

import torch

from examples.models.llama2.llama_transformer import ModelArgs
from executorch.examples.models.model_factory import EagerModelFactory

from .generation import LlamaRunner


class EagerLlamaRunner(LlamaRunner):
    """
    Runs llama in eager mode with provided checkpoint file.
    """

    def __init__(self, args):
        with open(args.params, "r") as f:
            params = json.loads(f.read())
        # Build ModelArgs from the params JSON, with the KV cache enabled.
        model_args: ModelArgs = ModelArgs(
            max_seq_len=args.max_len,
            max_batch_size=1,
            use_kv_cache=True,
            **params,
        )
        super().__init__(tokenizer_path=args.tokenizer, model_args=model_args)
        # Instantiate the eager-mode Llama 2 model (no export or lowering).
        self.model, _, _ = EagerModelFactory.create_model(
            "llama2",
            "Llama2Model",
            checkpoint=args.checkpoint,
            params=args.params,
            use_kv_cache=True,
            fairseq2=False,
            max_seq_len=args.max_len,
            enable_dynamic_shape=True,
        )

    def forward(
        self,
        tokens: Optional[torch.LongTensor] = None,
        input_pos: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        # Delegate directly to the underlying eager model.
        return self.model.forward(tokens=tokens, input_pos=input_pos)


def build_args_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--checkpoint",
        type=str,
        default=None,
        help="path to model checkpoint file",
    )

    parser.add_argument(
        "--params",
        type=str,
        default=None,
        help="model params file",
    )

    parser.add_argument(
        "--max_len",
        type=int,
        default=128,
        help="Maximum length of the generated response sequence.",
    )

    parser.add_argument(
        "--tokenizer",
        type=str,
        default=None,
    )

    parser.add_argument(
        "--prompt",
        type=str,
        default="Hello",
    )

    parser.add_argument(
        "--temperature",
        type=float,
        default=0,
    )

    return parser


def main() -> None:
    parser = build_args_parser()
    args = parser.parse_args()

    runner = EagerLlamaRunner(args)
    result = runner.text_completion(
        prompt=args.prompt,
        temperature=args.temperature,
    )
    print(
        "Response: \n{response}\n Tokens:\n {tokens}".format(
            response=result["generation"], tokens=result["tokens"]
        )
    )


if __name__ == "__main__":
    main()  # pragma: no cover
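
For reference, a minimal usage sketch of the new runner, driving EagerLlamaRunner programmatically the same way main() does. The checkpoint, params, and tokenizer paths below are placeholders, and it assumes EagerLlamaRunner is importable from wherever this file lands alongside generation.py:

# Minimal sketch: run the eager runner without the CLI (placeholder paths).
from argparse import Namespace

args = Namespace(
    checkpoint="/path/to/consolidated.00.pth",  # placeholder checkpoint path
    params="/path/to/params.json",              # placeholder params file
    tokenizer="/path/to/tokenizer.model",       # placeholder tokenizer path
    max_len=128,
    prompt="Hello",
    temperature=0,
)

runner = EagerLlamaRunner(args)  # loads the model in eager mode
result = runner.text_completion(prompt=args.prompt, temperature=args.temperature)
print(result["generation"])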
