
Commit 1dd12f0

Merge branch 'jz/native-runner-tt' into jz/tt-llama-3
2 parents: 310b3a3 + 37011d3

5 files changed: +89 / -30 lines


examples/models/llama/runner/eager.py

Lines changed: 5 additions & 7 deletions
@@ -10,10 +10,10 @@
 
 import torch
 
-from examples.models.llama.llama_transformer import ModelArgs
 from executorch.examples.models.llama.export_llama_lib import (
     _prepare_for_llama_export,
     build_args_parser as _build_args_parser,
+    TORCHTUNE_DEFINED_MODELS,
 )
 from executorch.examples.models.llama.runner.generation import LlamaRunner
 from executorch.extension.llm.export import LLMEdgeManager
@@ -27,15 +27,13 @@ class EagerLlamaRunner(LlamaRunner):
     def __init__(self, args):
         with open(args.params, "r") as f:
             params = json.loads(f.read())
-        model_args: ModelArgs = ModelArgs(
+        super().__init__(
+            tokenizer_path=args.tokenizer_path,
             max_seq_len=args.max_seq_length,
             max_batch_size=1,
             use_kv_cache=args.use_kv_cache,
-            **params,
-        )
-        super().__init__(
-            tokenizer_path=args.tokenizer_path,
-            model_args=model_args,
+            vocab_size=params["vocab_size"],
+            has_full_logits=args.model in TORCHTUNE_DEFINED_MODELS,
             device="cuda" if torch.cuda.is_available() else "cpu",
         )
         manager: LLMEdgeManager = _prepare_for_llama_export(args)
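
Not part of the commit, but as a reading aid: the base runner now takes plain scalars instead of a ModelArgs object, so this subclass only reads "vocab_size" out of the params JSON rather than unpacking the whole dict. A minimal sketch of that call pattern, with an illustrative params payload and tokenizer path (neither is taken from this commit):

import json

# Illustrative params payload; the runner now only consumes "vocab_size" from it.
params = json.loads('{"dim": 2048, "n_layers": 16, "vocab_size": 128256}')

# Keyword arguments matching the updated LlamaRunner.__init__ signature above.
runner_kwargs = dict(
    tokenizer_path="tokenizer.model",  # illustrative path
    max_seq_len=128,
    max_batch_size=1,
    use_kv_cache=True,
    vocab_size=params["vocab_size"],
    has_full_logits=False,  # True only for TorchTune-defined models
    device="cpu",
)
print(runner_kwargs)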

examples/models/llama/runner/generation.py

Lines changed: 46 additions & 9 deletions
@@ -9,7 +9,6 @@
 
 import torch
 
-from executorch.examples.models.llama.llama_transformer import ModelArgs
 from executorch.extension.llm.tokenizer.utils import get_tokenizer
 
 
@@ -51,11 +50,35 @@ def next_token(logits: torch.Tensor, temperature: float, top_p: float) -> int:
 
 
 class LlamaRunner(ABC):
-    def __init__(self, tokenizer_path: str, model_args: ModelArgs, device: str = "cpu"):
-        self.params = model_args
+    def __init__(
+        self,
+        tokenizer_path: str,
+        max_seq_len: int,
+        max_batch_size: int,
+        use_kv_cache: bool,
+        vocab_size: int,
+        has_full_logits: bool = False,
+        device: str = "cpu",
+    ):
+        """
+        Constructor.
+
+        Args:
+            tokenizer_path: path to tokenizer.model file.
+            max_seq_len: max length of the output sequence, after which the output will be clipped.
+            max_batch_size: max batch size.
+            use_kv_cache: whether to use a KV cache.
+            vocab_size: number of items in the vocab.
+            has_full_logits: whether the model returns the full logits or only returns the last logit.
+            device: device to run the runner on.
+        """
+        self.max_seq_len = max_seq_len
+        self.max_batch_size = max_batch_size
+        self.use_kv_cache = use_kv_cache
         self.tokenizer = get_tokenizer(tokenizer_path)
-        assert model_args.vocab_size == self.tokenizer.n_words
+        self.has_full_logits = has_full_logits
         self.device = device
+        assert vocab_size == self.tokenizer.n_words
 
     @abstractmethod
     def forward(
@@ -77,16 +100,22 @@ def generate( # noqa: C901
             tokens=torch.tensor([prompt_tokens], dtype=torch.long, device=self.device),
             input_pos=(
                 torch.tensor([0], dtype=torch.long, device=self.device)
-                if self.params.use_kv_cache
+                if self.use_kv_cache
                 else None
             ),
         )
 
-        current_token = next_token(logits, temperature, top_p)
+        current_token = next_token(logits[:, -1, :], temperature, top_p)
+        if self.has_full_logits:
+            current_token = next_token(logits[:, -1, :], temperature, top_p)
+        else:
+            current_token = next_token(logits, temperature, top_p)
         tokens = prompt_tokens + [current_token]
 
-        while len(tokens) < self.params.max_seq_len:
-            if self.params.use_kv_cache:
+        i = 0
+        while len(tokens) < self.max_seq_len:
+            print(f"{i} out of {self.max_seq_len} max tokens generated")
+            if self.use_kv_cache:
                 logits = self.forward(
                     tokens=torch.tensor(
                         [[current_token]], dtype=torch.long, device=self.device
@@ -99,13 +128,21 @@ def generate( # noqa: C901
                 logits = self.forward(
                     tokens=torch.tensor([tokens], dtype=torch.long, device=self.device),
                 )
-            current_token = next_token(logits, temperature, top_p)
+
+            # If the logits aren't already clipped to only contain the last logit, clip them.
+            if self.has_full_logits:
+                current_token = next_token(logits[:, -1, :], temperature, top_p)
+            else:
+                current_token = next_token(logits, temperature, top_p)
+
             if current_token == self.tokenizer.eos_id or (
                 hasattr(self.tokenizer, "stop_tokens")
                 and current_token in self.tokenizer.stop_tokens
             ):
                 break
+
             tokens.append(current_token)
+            i += 1
 
         return tokens if echo else tokens[len(prompt_tokens) :]
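
Context for the has_full_logits branches above: per the new docstring, some models return the full logits for every position while others return only the last logit, and only the former need the [:, -1, :] slice before sampling. A minimal shape sketch (toy sizes, not from this commit):

import torch

batch, seq_len, vocab = 1, 5, 32  # toy sizes for illustration

full_logits = torch.randn(batch, seq_len, vocab)  # model returns logits for every position
last_only = torch.randn(batch, vocab)             # model returns only the last position

# has_full_logits=True: keep only the final position before sampling.
last_from_full = full_logits[:, -1, :]
assert last_from_full.shape == last_only.shape == (batch, vocab)

# Greedy pick for illustration; the runner's next_token() also supports temperature/top-p.
print(torch.argmax(last_from_full, dim=-1).item())
print(torch.argmax(last_only, dim=-1).item())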

examples/models/llama/runner/native.py

Lines changed: 30 additions & 13 deletions
@@ -10,18 +10,22 @@
 
 import torch
 
-from examples.models.llama.llama_transformer import ModelArgs
+from executorch.examples.models.llama.export_llama_lib import (
+    EXECUTORCH_DEFINED_MODELS,
+    TORCHTUNE_DEFINED_MODELS,
+)
+
 from executorch.extension.pybindings.portable_lib import _load_for_executorch
 
 # Load custom ops and quantized ops.
 from executorch.extension.pybindings import portable_lib  # noqa # usort: skip
 
+from executorch.examples.models.llama.runner.generation import LlamaRunner
+
 # Note: import this after portable_lib
-from executorch.extension.llm.custom_ops import sdpa_with_kv_cache  # noqa # usort: skip
+# from executorch.extension.llm.custom_ops import sdpa_with_kv_cache  # noqa # usort: skip
 from executorch.kernels import quantized  # noqa
 
-from .generation import LlamaRunner
-
 
 class NativeLlamaRunner(LlamaRunner):
     """
@@ -31,30 +35,44 @@ class NativeLlamaRunner(LlamaRunner):
     def __init__(self, args):
         with open(args.params, "r") as f:
             params = json.loads(f.read())
-        model_args: ModelArgs = ModelArgs(
+        super().__init__(
+            tokenizer_path=args.tokenizer,
             max_seq_len=args.max_len,
             max_batch_size=1,
             use_kv_cache=args.kv_cache,
-            **params,
+            vocab_size=params["vocab_size"],
+            has_full_logits=args.model in TORCHTUNE_DEFINED_MODELS,
         )
-        super().__init__(tokenizer_path=args.tokenizer, model_args=model_args)
         self.model = _load_for_executorch(args.pte)
 
     def forward(
         self,
         tokens: Optional[torch.LongTensor] = None,
         input_pos: Optional[torch.LongTensor] = None,
     ) -> torch.Tensor:
-        return (
-            self.model.forward((tokens, input_pos))
-            if input_pos is not None
-            else self.model.forward((tokens,))
-        )[0]
+        # TODO: in LlamaRunner there is a generate function that automatically generates
+        # input_pos tensor and inputs it into the model. Atm TorchTune models use
+        # kwargs for the input_pos, so we will need to make some changes. At least
+        # for the time being, we can run the non-kv cache version of the Torchtune
+        # model with just the tokens like below.
+        return (self.model.forward((tokens,)))[0]
+        # return (
+        #     self.model.forward((tokens, input_pos))
+        #     if input_pos is not None
+        #     else self.model.forward((tokens,))
+        # )[0]
 
 
 def build_args_parser() -> argparse.ArgumentParser:
+    # TODO: merge these with build_args_parser from export_llama_lib.
     parser = argparse.ArgumentParser()
 
+    parser.add_argument(
+        "--model",
+        default="llama",
+        choices=EXECUTORCH_DEFINED_MODELS + TORCHTUNE_DEFINED_MODELS,
+    )
+
     parser.add_argument(
         "-f",
         "--pte",
@@ -89,7 +107,6 @@ def build_args_parser() -> argparse.ArgumentParser:
     parser.add_argument(
         "-kv",
         "--kv_cache",
-        default=True,
         action="store_true",
     )
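
A note on the --kv_cache change in the last hunk: removing default=True from a store_true flag means the KV cache is now off unless the flag is passed explicitly, whereas before it was effectively always on. A small self-contained argparse sketch of that behavior (independent of this repo):

import argparse

parser = argparse.ArgumentParser()
# With action="store_true" and no explicit default, the value is False unless the flag is given.
parser.add_argument("-kv", "--kv_cache", action="store_true")

print(parser.parse_args([]).kv_cache)       # False: disabled by default now
print(parser.parse_args(["-kv"]).kv_cache)  # True: enabled only when requested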

examples/models/llama3_2_vision/model.py

Lines changed: 3 additions & 1 deletion
@@ -40,7 +40,9 @@ class Llama3_2Decoder(EagerModelBase):
 
     def __init__(self, **kwargs):
         # Set member vars from kwargs.
-        self.max_seq_len = kwargs.get("max_seq_len", 8192)  # Trained to be a lot larger, but this value is kept small because of static kv cache at the moment.
+        self.max_seq_len = kwargs.get(
+            "max_seq_len", 8192
+        )  # Trained to be a lot larger, but this value is kept small because of static kv cache at the moment.
         self.encoder_max_seq_len = kwargs.get(
             "encoder_max_seq_len", int(4 * (448 / 14) ** 2 + 1)
         )  # Same as above.

extension/llm/export/builder.py

Lines changed: 5 additions & 0 deletions
@@ -194,6 +194,11 @@ def export(self) -> "LLMEdgeManager":
                 strict=True,
             ).module()
         else:
+            print("Exporting with:")
+            print(f"inputs: {self.example_inputs}")
+            print(f"kwargs: {self.example_kwarg_inputs}")
+            print(f"dynamic shapes: {dynamic_shape}")
+
             # pyre-fixme[8]: Attribute has type `Optional[GraphModule]`; used as
             # `Module`.
             self.pre_autograd_graph_module = export_for_training(
