
Commit 37011d3

Clean up

1 parent f275e2e
4 files changed: +15 -8 lines changed

examples/models/llama/runner/eager.py

Lines changed: 0 additions & 1 deletion

@@ -10,7 +10,6 @@
 
 import torch
 
-from examples.models.llama.llama_transformer import ModelArgs
 from executorch.examples.models.llama.export_llama_lib import (
     _prepare_for_llama_export,
     build_args_parser as _build_args_parser,

examples/models/llama/runner/generation.py

Lines changed: 8 additions & 5 deletions

@@ -9,7 +9,6 @@
 
 import torch
 
-from executorch.examples.models.llama.llama_transformer import ModelArgs
 from executorch.extension.llm.tokenizer.utils import get_tokenizer
 
 
@@ -63,7 +62,7 @@ def __init__(
     ):
         """
         Constructor.
-
+
         Args:
             tokenizer_path: path to tokenizer.model file.
             max_seq_len: max length of the output sequence, after which the output will be clipped.
@@ -100,13 +99,17 @@ def generate( # noqa: C901
         logits = self.forward(
             tokens=torch.tensor([prompt_tokens], dtype=torch.long, device=self.device),
             input_pos=(
-                torch.tensor([0], dtype=torch.long, device=self.device) if self.use_kv_cache else None
+                torch.tensor([0], dtype=torch.long, device=self.device)
+                if self.use_kv_cache
+                else None
             ),
         )
 
-        # TODO: accomodate TorchTune model, which doesn't
-        # make an optimization of dropping all logits but the last.
         current_token = next_token(logits[:, -1, :], temperature, top_p)
+        if self.has_full_logits:
+            current_token = next_token(logits[:, -1, :], temperature, top_p)
+        else:
+            current_token = next_token(logits, temperature, top_p)
         tokens = prompt_tokens + [current_token]
 
         i = 0
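
For context on the new has_full_logits branch: a model that keeps logits for every position returns a [batch, seq_len, vocab] tensor, so sampling has to slice out the last position, while a model that drops all logits but the last (the optimization the removed TODO referred to) already returns a [batch, vocab] tensor. A minimal sketch of the shape difference, with an illustrative vocab size that is not from this commit:

# Illustrative shapes only; the vocab size is a made-up placeholder.
import torch

vocab = 32000

# Model that keeps full logits (the has_full_logits == True case):
full = torch.randn(1, 7, vocab)    # [batch, seq_len, vocab]
last_from_full = full[:, -1, :]    # slice the final position -> [batch, vocab]

# Model that drops all logits but the last position:
last_only = torch.randn(1, vocab)  # already [batch, vocab]; no slicing needed

# next_token receives the same shape either way.
assert last_from_full.shape == last_only.shape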

examples/models/llama/runner/native.py

Lines changed: 4 additions & 1 deletion

@@ -10,7 +10,10 @@
 
 import torch
 
-from executorch.examples.models.llama.export_llama_lib import EXECUTORCH_DEFINED_MODELS, TORCHTUNE_DEFINED_MODELS
+from executorch.examples.models.llama.export_llama_lib import (
+    EXECUTORCH_DEFINED_MODELS,
+    TORCHTUNE_DEFINED_MODELS,
+)
 
 from executorch.extension.pybindings.portable_lib import _load_for_executorch
 
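
Aside from the formatting split, nothing changes in this file. For readers unfamiliar with the two imported names: EXECUTORCH_DEFINED_MODELS and TORCHTUNE_DEFINED_MODELS are lists of model-name strings defined in export_llama_lib, and a runner can branch on membership. A hedged sketch, in which only the two imported lists are real and the helper function is hypothetical:

# Hypothetical helper; only the imported lists come from export_llama_lib.
from executorch.examples.models.llama.export_llama_lib import (
    EXECUTORCH_DEFINED_MODELS,
    TORCHTUNE_DEFINED_MODELS,
)

def model_family(name: str) -> str:
    if name in EXECUTORCH_DEFINED_MODELS:
        return "executorch"
    if name in TORCHTUNE_DEFINED_MODELS:
        return "torchtune"
    raise ValueError(f"unknown model: {name}")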

examples/models/llama3_2_vision/model.py

Lines changed: 3 additions & 1 deletion

@@ -40,7 +40,9 @@ class Llama3_2Decoder(EagerModelBase):
 
     def __init__(self, **kwargs):
         # Set member vars from kwargs.
-        self.max_seq_len = kwargs.get("max_seq_len", 8192)  # Trained to be a lot larger, but this value is kept small because of static kv cache at the moment.
+        self.max_seq_len = kwargs.get(
+            "max_seq_len", 8192
+        )  # Trained to be a lot larger, but this value is kept small because of static kv cache at the moment.
         self.encoder_max_seq_len = kwargs.get(
             "encoder_max_seq_len", int(4 * (448 / 14) ** 2 + 1)
         )  # Same as above.
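
The encoder_max_seq_len default evaluates to 4097. Assuming the constants mean a 448-pixel tile split into 14-pixel ViT patches across 4 tiles plus one class token (a reading the diff itself does not spell out), the arithmetic is:

# Assumed reading: 448-px tile, 14-px patches, 4 tiles, +1 class token.
448 / 14                       # = 32 patches per side
(448 / 14) ** 2                # = 1024 patch tokens per tile
4 * (448 / 14) ** 2            # = 4096 tokens across 4 tiles
int(4 * (448 / 14) ** 2 + 1)   # = 4097, the default used above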
