Commit b8ff8e2

Things work

1 parent 7f81e00 commit b8ff8e2

4 files changed: 46 additions & 24 deletions

examples/models/llama2/runner/eager.py

Lines changed: 4 additions & 5 deletions
@@ -10,7 +10,6 @@
 
 import torch
 
-from examples.models.llama2.llama_transformer import ModelArgs
 from executorch.examples.models.model_factory import EagerModelFactory
 
 from .generation import LlamaRunner
@@ -24,13 +23,13 @@ class EagerLlamaRunner(LlamaRunner):
     def __init__(self, args):
         with open(args.params, "r") as f:
             params = json.loads(f.read())
-        model_args: ModelArgs = ModelArgs(
-            max_seq_len=args.max_len,
+        super().__init__(
+            tokenizer_path=args.tokenizer,
+            max_seq_len=args.max_seq_len,
             max_batch_size=1,
             use_kv_cache=True,
-            **params,
+            vocab_size=params["vocab_size"],
         )
-        super().__init__(tokenizer_path=args.tokenizer, model_args=model_args)
         self.model, _, _, _ = EagerModelFactory.create_model(
             "llama2",
             "Llama2Model",

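After this change the runner subclasses stop building a ModelArgs object and instead hand the generation settings straight to the LlamaRunner base class. Below is a minimal sketch of the new contract, assuming a params JSON that contains "vocab_size" and argparse-style args as in the diff; the class name and the forward stub are illustrative only:

import json

import torch

from executorch.examples.models.llama2.runner.generation import LlamaRunner


class SketchRunner(LlamaRunner):
    """Illustrative subclass: only shows how __init__ is wired after the refactor."""

    def __init__(self, args):
        with open(args.params, "r") as f:
            params = json.loads(f.read())  # must contain a "vocab_size" entry
        super().__init__(
            tokenizer_path=args.tokenizer,
            max_seq_len=args.max_seq_len,
            max_batch_size=1,
            use_kv_cache=True,
            vocab_size=params["vocab_size"],  # checked against tokenizer.n_words
        )

    def forward(self, tokens: torch.Tensor, input_pos=None) -> torch.Tensor:
        # Signature assumed from the abstract method in generation.py; a real
        # runner dispatches to an eager nn.Module or a loaded .pte program.
        raise NotImplementedError
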
examples/models/llama2/runner/generation.py

Lines changed: 29 additions & 12 deletions
@@ -51,10 +51,19 @@ def next_token(logits: torch.Tensor, temperature: float, top_p: float) -> int:
 
 
 class LlamaRunner(ABC):
-    def __init__(self, tokenizer_path: str, model_args: ModelArgs):
-        self.params = model_args
+    def __init__(
+        self,
+        tokenizer_path: str,
+        max_seq_len: int,
+        max_batch_size: int,
+        use_kv_cache: bool,
+        vocab_size: int,
+    ):
+        self.max_seq_len = max_seq_len
+        self.max_batch_size = max_batch_size
+        self.use_kv_cache = use_kv_cache
         self.tokenizer = Tokenizer(tokenizer_path)
-        assert model_args.vocab_size == self.tokenizer.n_words
+        assert vocab_size == self.tokenizer.n_words
 
     @abstractmethod
     def forward(
@@ -75,27 +84,35 @@ def generate( # noqa: C901
         logits = self.forward(
             tokens=torch.tensor([prompt_tokens], dtype=torch.long),
             input_pos=(
-                torch.tensor([0], dtype=torch.long)
-                if self.params.use_kv_cache
-                else None
+                torch.tensor([0], dtype=torch.long) if self.use_kv_cache else None
             ),
         )
 
-        current_token = next_token(logits, temperature, top_p)
+        # TODO: accomodate TorchTune model, which doesn't
+        # make an optimization of dropping all logits but the last.
+        current_token = next_token(logits[:, -1, :], temperature, top_p)
         tokens = prompt_tokens + [current_token]
 
-        while len(tokens) < self.params.max_seq_len:
-            if self.params.use_kv_cache:
+        i = 0
+        while len(tokens) < self.max_seq_len:
+            print(f"{i} out of {self.max_seq_len} max tokens generated")
+            if self.use_kv_cache:
                 logits = self.forward(
                     tokens=torch.tensor([[current_token]], dtype=torch.long),
                     input_pos=torch.tensor([len(tokens) - 1], dtype=torch.long),
                 )
             else:
-                logits = self.forward(tokens=torch.tensor([tokens], dtype=torch.long))
-            current_token = next_token(logits, temperature, top_p)
-            if current_token in self.tokenizer.stop_tokens:
+                logits = self.forward(
+                    tokens=torch.tensor([tokens], dtype=torch.long, device=self.device),
+                )
+            current_token = next_token(logits[:, -1, :], temperature, top_p)
+            if current_token == self.tokenizer.eos_id or (
+                hasattr(self.tokenizer, "stop_tokens")
+                and current_token in self.tokenizer.stop_tokens
+            ):
                 break
             tokens.append(current_token)
+            i += 1
 
         return tokens if echo else tokens[len(prompt_tokens) :]
 

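Both call sites now slice the last position out of logits of shape [batch, seq_len, vocab_size], so next_token always receives a [1, vocab_size] tensor. The helper's body is not part of this diff; below is a self-contained sketch of a typical temperature plus top-p sampler over such logits, an assumption rather than the file's actual implementation:

import torch


def sample_next_token(logits: torch.Tensor, temperature: float, top_p: float) -> int:
    # logits: [1, vocab_size], e.g. logits[:, -1, :] from a [1, seq, vocab] output.
    if temperature <= 0:
        return int(torch.argmax(logits, dim=-1).item())  # greedy decoding
    probs = torch.softmax(logits / temperature, dim=-1)
    # Top-p (nucleus) filtering: keep the smallest set of tokens whose
    # cumulative probability exceeds top_p, then renormalize and sample.
    sorted_probs, sorted_idx = torch.sort(probs, dim=-1, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    mask = cumulative - sorted_probs > top_p
    sorted_probs[mask] = 0.0
    sorted_probs /= sorted_probs.sum(dim=-1, keepdim=True)
    choice = torch.multinomial(sorted_probs, num_samples=1)
    return int(torch.gather(sorted_idx, -1, choice).item())
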
examples/models/llama2/runner/native.py

Lines changed: 4 additions & 5 deletions
@@ -10,18 +10,17 @@
 
 import torch
 
-from executorch.examples.models.llama2.llama_transformer import ModelArgs
 from executorch.extension.pybindings.portable_lib import _load_for_executorch
 
 # Load custom ops and quantized ops.
 from executorch.extension.pybindings import portable_lib  # noqa # usort: skip
 
+from executorch.examples.models.llama2.runner.generation import LlamaRunner
+
 # Note: import this after portable_lib
 # from executorch.extension.llm.custom_ops import sdpa_with_kv_cache  # noqa # usort: skip
 from executorch.kernels import quantized  # noqa
 
-from executorch.examples.models.llama2.runner.generation import LlamaRunner
-
 
 class NativeLlamaRunner(LlamaRunner):
     """
@@ -31,13 +30,13 @@ class NativeLlamaRunner(LlamaRunner):
     def __init__(self, args):
         with open(args.params, "r") as f:
             params = json.loads(f.read())
-        model_args: ModelArgs = ModelArgs(
+        super().__init__(
+            tokenizer_path=args.tokenizer,
            max_seq_len=args.max_len,
            max_batch_size=1,
            use_kv_cache=args.kv_cache,
            vocab_size=params["vocab_size"],
        )
-        super().__init__(tokenizer_path=args.tokenizer, model_args=model_args)
         self.model = _load_for_executorch(args.pte)
 
     def forward(

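NativeLlamaRunner keeps the same constructor shape as the eager runner but must route forward through the program loaded by _load_for_executorch. The diff does not show that call; the sketch below assumes the pybindings module accepts a list of input tensors and returns a list of outputs with the logits first, which is an assumption about the API rather than something shown here:

import torch


def pte_forward(module, tokens: torch.Tensor, input_pos=None) -> torch.Tensor:
    # module: the object returned by _load_for_executorch(args.pte).
    # Assumption: with a KV cache the exported program takes (tokens, input_pos),
    # and without one it takes only the token tensor.
    if input_pos is not None:
        outputs = module.forward([tokens, input_pos])
    else:
        outputs = module.forward([tokens])
    return outputs[0]  # logits, expected shape [1, seq_len, vocab_size]
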
extension/llm/export/builder.py

Lines changed: 9 additions & 2 deletions
@@ -193,12 +193,19 @@ def capture_pre_autograd_graph(self) -> "LLMEdgeManager":
                 strict=True,
             ).module()
         else:
-            self.pre_autograd_graph_module = capture_pre_autograd_graph(
+            # pyre-fixme[8]: Attribute has type `Optional[GraphModule]`; used as
+            #  `Module`.
+            print("Exporting with:")
+            print(f"inputs: {self.example_inputs}")
+            print(f"kwargs: {self.example_kwarg_inputs}")
+            print(f"dynamic shapes: {dynamic_shape}")
+
+            self.pre_autograd_graph_module = export_for_training(
                 self.model,
                 self.example_inputs,
                 kwargs=self.example_kwarg_inputs,
                 dynamic_shapes=dynamic_shape,
-            )
+            ).module()
 
         return self
 

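The builder now exports through torch.export.export_for_training and takes .module() from the resulting ExportedProgram, mirroring the strict branch above it. Below is a standalone sketch of that call on a toy module, assuming PyTorch 2.5+ where export_for_training is available; the model and shapes are illustrative:

import torch
from torch.export import Dim, export_for_training


class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(16, 16)

    def forward(self, x):
        return self.linear(x)


example_inputs = (torch.randn(4, 16),)
dynamic_shapes = {"x": {0: Dim("batch", min=2, max=32)}}

# Same call shape as in the diff: export for training, then grab the graph module.
exported = export_for_training(
    ToyModel(),
    example_inputs,
    dynamic_shapes=dynamic_shapes,
)
graph_module = exported.module()  # torch.fx.GraphModule, ready for further lowering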