
Commit 7b76f0f

Make TorchTune Llama model KV cache compatible in eager (#6643)
1 parent e384c1a commit 7b76f0f

8 files changed, +327 -41 lines changed

examples/models/llama/runner/eager.py

Lines changed: 7 additions & 6 deletions
@@ -6,14 +6,13 @@
 
 import argparse
 import json
-from typing import Optional
+from typing import Optional, Type
 
 import torch
 
 from executorch.examples.models.llama.export_llama_lib import (
     _prepare_for_llama_export,
     build_args_parser as _build_args_parser,
-    TORCHTUNE_DEFINED_MODELS,
 )
 from executorch.examples.models.llama.runner.generation import LlamaRunner
 from executorch.extension.llm.export.builder import LLMEdgeManager
@@ -33,7 +32,6 @@ def __init__(self, args):
             max_batch_size=1,
             use_kv_cache=args.use_kv_cache,
             vocab_size=params["vocab_size"],
-            has_full_logits=args.model in TORCHTUNE_DEFINED_MODELS,
             device="cuda" if torch.cuda.is_available() else "cpu",
         )
         manager: LLMEdgeManager = _prepare_for_llama_export(args)
@@ -79,11 +77,10 @@ def build_args_parser() -> argparse.ArgumentParser:
     return parser
 
 
-def main() -> None:
+def execute_runner(runner_class: Type[LlamaRunner]) -> None:
     parser = build_args_parser()
     args = parser.parse_args()
-
-    runner = EagerLlamaRunner(args)
+    runner = runner_class(args)
     generated_tokens = (
         runner.chat_completion(temperature=args.temperature)
         if args.chat
@@ -97,5 +94,9 @@ def main() -> None:
     print(f"Generated {len(generated_tokens)} tokens: {generated_tokens}")
 
 
+def main() -> None:
+    execute_runner(EagerLlamaRunner)
+
+
 if __name__ == "__main__":
     main()  # pragma: no cover
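
The refactor above turns the module's old main() into a reusable execute_runner() helper, so a backend-specific runner only has to subclass LlamaRunner and be handed to it. A minimal sketch of that extension point (not part of the commit; MyEagerRunner is hypothetical, and the args fields mirror the ones this commit's runners read from the shared parser):

import json
from typing import Optional

import torch

from executorch.examples.models.llama.runner.eager import execute_runner
from executorch.examples.models.llama.runner.generation import LlamaRunner


class MyEagerRunner(LlamaRunner):
    """Hypothetical runner, shown only to illustrate the execute_runner() hook."""

    def __init__(self, args):
        # Same wiring the real eager runners in this commit use.
        with open(args.params, "r") as f:
            params = json.loads(f.read())
        super().__init__(
            tokenizer_path=args.tokenizer_path,
            max_seq_len=args.max_seq_length,
            max_batch_size=1,
            use_kv_cache=args.use_kv_cache,
            vocab_size=params["vocab_size"],
        )

    def forward(
        self,
        tokens: Optional[torch.LongTensor] = None,
        input_pos: Optional[torch.LongTensor] = None,
        mask: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        # A real runner would run its model here and return logits.
        raise NotImplementedError


if __name__ == "__main__":
    execute_runner(MyEagerRunner)  # parses CLI args, builds the runner, generates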

examples/models/llama/runner/generation.py

Lines changed: 3 additions & 12 deletions
@@ -53,7 +53,6 @@ def __init__(
         max_batch_size: int,
         use_kv_cache: bool,
         vocab_size: int,
-        has_full_logits: bool = False,
         device: str = "cpu",
     ):
         """
@@ -65,14 +64,12 @@ def __init__(
             max_batch_size: max batch size.
             use_kv_cache: whether to use a KV cache.
             vocab_size: number of items in the vocab.
-            has_full_logits: whether the model returns the full logits or only returns the last logit.
             device: device to run the runner on.
         """
         self.max_seq_len = max_seq_len
         self.max_batch_size = max_batch_size
         self.use_kv_cache = use_kv_cache
         self.tokenizer = get_tokenizer(tokenizer_path)
-        self.has_full_logits = has_full_logits
         self.device = device
         assert vocab_size == self.tokenizer.n_words
 
@@ -93,7 +90,7 @@ def generate( # noqa: C901
         echo: bool = False,
         pos_base: int = 0,
     ) -> List[int]:
-        # prefill
+        # Prefill
        logits = self.forward(
             tokens=torch.tensor([prompt_tokens], dtype=torch.long, device=self.device),
             input_pos=(
@@ -103,10 +100,7 @@
             ),
         )
 
-        if self.has_full_logits:
-            current_token = next_token(logits[:, -1, :], temperature, top_p)
-        else:
-            current_token = next_token(logits, temperature, top_p)
+        current_token = next_token(logits, temperature, top_p)
         print(f"{self.tokenizer.decode_token(current_token)}", end="", flush=True)
         tokens = prompt_tokens + [current_token]
 
@@ -128,10 +122,7 @@
                 )
 
             # If the logits aren't already clipped to only contain the last logit, clip them.
-            if self.has_full_logits:
-                current_token = next_token(logits[:, -1, :], temperature, top_p)
-            else:
-                current_token = next_token(logits, temperature, top_p)
+            current_token = next_token(logits, temperature, top_p)
             tokens.append(current_token)
 
             if current_token == self.tokenizer.eos_id or (
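
With has_full_logits gone, the base LlamaRunner always hands next_token a logits row for the last position only; models that emit logits for every position (the TorchTune-defined ones) now do the [:, -1, :] slice in their own subclass, shown further down in TorchTuneLlamaRunner. For orientation, here is a rough sketch of a temperature/top-p sampler with the same call shape as next_token(logits, temperature, top_p); this is an assumption about its behavior, not the repository's implementation:

import torch


def sample_next_token(last_logits: torch.Tensor, temperature: float, top_p: float) -> int:
    """Illustrative stand-in for next_token: expects logits of shape [1, vocab_size]."""
    if temperature <= 0:
        # Greedy decoding.
        return int(torch.argmax(last_logits, dim=-1).item())
    probs = torch.softmax(last_logits / temperature, dim=-1)
    sorted_probs, sorted_idx = torch.sort(probs, descending=True, dim=-1)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    # Drop tokens outside the top-p nucleus, then renormalize and sample.
    sorted_probs[cumulative - sorted_probs > top_p] = 0.0
    sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)
    choice = torch.multinomial(sorted_probs, num_samples=1)
    return int(sorted_idx.gather(-1, choice).item())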

examples/models/llama/runner/native.py

Lines changed: 0 additions & 1 deletion
@@ -41,7 +41,6 @@ def __init__(self, args):
             max_batch_size=1,
             use_kv_cache=args.kv_cache,
             vocab_size=params["vocab_size"],
-            has_full_logits=args.model in TORCHTUNE_DEFINED_MODELS,
         )
         self.model = _load_for_executorch(args.pte)
 
Lines changed: 53 additions & 0 deletions
@@ -0,0 +1,53 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import json
+from typing import Optional
+
+import torch
+
+from executorch.examples.models.llama.export_llama_lib import _prepare_for_llama_export
+from executorch.examples.models.llama.runner.eager import execute_runner
+from executorch.examples.models.llama3_2_vision.runner.generation import (
+    TorchTuneLlamaRunner,
+)
+from executorch.extension.llm.export import LLMEdgeManager
+
+
+class EagerLlamaRunner(TorchTuneLlamaRunner):
+    """
+    Runs llama in eager mode with provided checkpoint file.
+    """
+
+    def __init__(self, args):
+        with open(args.params, "r") as f:
+            params = json.loads(f.read())
+        super().__init__(
+            tokenizer_path=args.tokenizer_path,
+            max_seq_len=args.max_seq_length,
+            max_batch_size=1,
+            use_kv_cache=args.use_kv_cache,
+            vocab_size=params["vocab_size"],
+            device="cuda" if torch.cuda.is_available() else "cpu",
+        )
+        manager: LLMEdgeManager = _prepare_for_llama_export(args)
+        self.model = manager.model.eval().to(device=self.device)
+
+    def forward(
+        self,
+        tokens: Optional[torch.LongTensor] = None,
+        input_pos: Optional[torch.LongTensor] = None,
+        mask: Optional[torch.LongTensor] = None,
+    ) -> torch.Tensor:
+        return self.model.forward(tokens=tokens, input_pos=input_pos, mask=mask)
+
+
+def main() -> None:
+    execute_runner(EagerLlamaRunner)
+
+
+if __name__ == "__main__":
+    main()  # pragma: no cover
Lines changed: 101 additions & 0 deletions
@@ -0,0 +1,101 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import List
+
+import torch
+from executorch.examples.models.llama.runner.generation import LlamaRunner, next_token
+
+
+class TorchTuneLlamaRunner(LlamaRunner):
+    def __init__(
+        self,
+        tokenizer_path: str,
+        max_seq_len: int,
+        max_batch_size: int,
+        use_kv_cache: bool,
+        vocab_size: int,
+        device: str = "cpu",
+    ):
+        super().__init__(
+            tokenizer_path,
+            max_seq_len,
+            max_batch_size,
+            use_kv_cache,
+            vocab_size,
+            device,
+        )
+
+        self.causal_mask = torch.tril(
+            torch.ones(
+                size=(max_seq_len, max_seq_len),
+                dtype=torch.bool,
+            )
+        )
+        self.input_pos = torch.arange(max_seq_len)
+
+    def generate(  # noqa: C901
+        self,
+        prompt_tokens: List[int],
+        max_seq_len: int,
+        temperature: float = 0.8,
+        top_p: float = 0.9,
+        echo: bool = False,
+    ) -> List[int]:
+        # Prefill
+        seq_len = len(prompt_tokens)
+        input_pos = self.input_pos[None, :seq_len]
+        mask = self.causal_mask[None, :seq_len]
+        if self.use_kv_cache:
+            logits = self.forward(
+                tokens=torch.tensor(
+                    [prompt_tokens], dtype=torch.long, device=self.device
+                ),
+                input_pos=input_pos,
+                mask=mask,
+            )
+        else:
+            logits = self.forward(
+                tokens=torch.tensor(
+                    [prompt_tokens], dtype=torch.long, device=self.device
+                ),
+            )
+
+        # Only need the last logit.
+        current_token = next_token(logits[:, -1, :], temperature, top_p)
+        print(f"{self.tokenizer.decode_token(current_token)}", end="", flush=True)
+        tokens = prompt_tokens + [current_token]
+
+        while len(tokens) < max_seq_len:
+            mask = self.causal_mask[None, seq_len, None, :]
+            input_pos = self.input_pos[None, seq_len, None]
+            if self.use_kv_cache:
+                logits = self.forward(
+                    tokens=torch.tensor(
+                        [[current_token]], dtype=torch.long, device=self.device
+                    ),
+                    input_pos=input_pos,
+                    mask=mask,
+                )
+            else:
+                logits = self.forward(
+                    tokens=torch.tensor([tokens], dtype=torch.long, device=self.device),
+                )
+
+            # Only need the last logit.
+            current_token = next_token(logits[:, -1, :], temperature, top_p)
+            tokens.append(current_token)
+
+            if current_token == self.tokenizer.eos_id or (
+                hasattr(self.tokenizer, "stop_tokens")
+                and current_token in self.tokenizer.stop_tokens
+            ):
+                break
+
+            print(f"{self.tokenizer.decode_token(current_token)}", end="", flush=True)
+            seq_len += 1
+
+        return tokens if echo else tokens[len(prompt_tokens) :]
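
The KV-cache compatibility this commit adds is concentrated in generate() above: the runner precomputes a boolean causal mask and a position index once, then slices them per step, so prefill sends the whole prompt while each decode step sends a single new token, its position, and one mask row. A quick shape check of that slicing (illustrative only, with max_seq_len shrunk to 8):

import torch

max_seq_len = 8
causal_mask = torch.tril(torch.ones(max_seq_len, max_seq_len, dtype=torch.bool))
input_pos = torch.arange(max_seq_len)

# Prefill over a 3-token prompt: the first seq_len rows of the causal mask.
seq_len = 3
print(causal_mask[None, :seq_len].shape)          # torch.Size([1, 3, 8])
print(input_pos[None, :seq_len].shape)            # torch.Size([1, 3])

# One decode step for the token at position seq_len: a single mask row and position.
print(causal_mask[None, seq_len, None, :].shape)  # torch.Size([1, 1, 8])
print(input_pos[None, seq_len, None].shape)       # torch.Size([1, 1])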
Lines changed: 131 additions & 0 deletions
@@ -0,0 +1,131 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import argparse
+import json
+from typing import Optional
+
+import torch
+
+from executorch.examples.models.llama.export_llama_lib import (
+    EXECUTORCH_DEFINED_MODELS,
+    TORCHTUNE_DEFINED_MODELS,
+)
+from executorch.examples.models.llama3_2_vision.runner.generation import (
+    TorchTuneLlamaRunner,
+)
+
+from executorch.extension.pybindings.portable_lib import _load_for_executorch
+
+# Load custom ops and quantized ops.
+from executorch.extension.pybindings import portable_lib  # noqa # usort: skip
+
+# Note: import this after portable_lib
+from executorch.extension.llm.custom_ops import sdpa_with_kv_cache  # noqa # usort: skip
+from executorch.kernels import quantized  # noqa
+
+
+class NativeLlamaRunner(TorchTuneLlamaRunner):
+    """
+    Runs llama via ExecuTorch with provided pte file.
+    """
+
+    def __init__(self, args):
+        with open(args.params, "r") as f:
+            params = json.loads(f.read())
+        super().__init__(
+            tokenizer_path=args.tokenizer,
+            max_seq_len=args.max_len,
+            max_batch_size=1,
+            use_kv_cache=args.kv_cache,
+            vocab_size=params["vocab_size"],
+        )
+        self.model = _load_for_executorch(args.pte)
+        self.use_kv_cache = args.kv_cache
+
+    def forward(
+        self,
+        tokens: torch.Tensor,
+        input_pos: Optional[torch.Tensor] = None,
+        mask: Optional[torch.LongTensor] = None,
+    ) -> torch.Tensor:
+        return (
+            self.model.forward((tokens, input_pos, mask))
+            if self.use_kv_cache
+            else self.model.forward((tokens,))
+        )[0]
+
+
+def build_args_parser() -> argparse.ArgumentParser:
+    # TODO: merge these with build_args_parser from export_llama_lib.
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument(
+        "--model",
+        default="llama3",
+        choices=EXECUTORCH_DEFINED_MODELS + TORCHTUNE_DEFINED_MODELS,
+    )
+
+    parser.add_argument(
+        "-f",
+        "--pte",
+        type=str,
+        default=None,
+        help="path to exported executorch .pte file",
+    )
+
+    parser.add_argument(
+        "-p", "--params", type=str, default=None, help="model params file"
+    )
+
+    parser.add_argument(
+        "-t",
+        "--tokenizer",
+        type=str,
+        default=None,
+    )
+
+    parser.add_argument(
+        "--prompt",
+        type=str,
+        default="Hello",
+    )
+
+    parser.add_argument(
+        "--temperature",
+        type=float,
+        default=0.6,
+    )
+
+    parser.add_argument(
+        "-kv",
+        "--kv_cache",
+        action="store_true",
+    )
+
+    parser.add_argument(
+        "--max_len",
+        type=int,
+        default=128,
+        help="Maximum length of the generated response sequence.",
+    )
+
+    return parser
+
+
+def main() -> None:
+    parser = build_args_parser()
+    args = parser.parse_args()
+    runner = NativeLlamaRunner(args)
+    generated_tokens = runner.text_completion(
+        prompt=args.prompt,
+        temperature=args.temperature,
+    )
+    print(f"Response: {generated_tokens}")
+
+
+if __name__ == "__main__":
+    main()  # pragma: no cover
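
Compared with the eager runner, the main difference is forward(): the loaded .pte module is called with a positional tuple and returns a list of outputs, while the eager runner passes tokens, input_pos, and mask as keyword arguments. A hedged sketch of driving the runner from Python instead of the CLI; all file paths are placeholders, and NativeLlamaRunner is assumed to be importable from wherever this new module lives:

from argparse import Namespace

# Placeholder values that mirror the CLI flags defined in build_args_parser above.
args = Namespace(
    pte="llama3_2_vision.pte",
    params="params.json",
    tokenizer="tokenizer.model",
    prompt="Hello",
    temperature=0.6,
    kv_cache=True,
    max_len=128,
)

runner = NativeLlamaRunner(args)  # class defined in the new file above
tokens = runner.text_completion(prompt=args.prompt, temperature=args.temperature)
print(f"Response: {tokens}")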
