Skip to content

Commit 6e38763

Browse files
committed
Clip logits if torchtune
1 parent 25ec7ce commit 6e38763

File tree

2 files changed

+9
-6
lines changed

2 files changed

+9
-6
lines changed

examples/models/llama/runner/generation.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -127,15 +127,18 @@ def generate( # noqa: C901
127127
tokens=torch.tensor([tokens], dtype=torch.long, device=self.device),
128128
)
129129

130-
if self.has_full_logits:
131-
current_token = next_token(logits[:, -1, :], temperature, top_p)
132-
else:
133-
current_token = next_token(logits, temperature, top_p)
130+
# If the logits aren't already clipped to only contain the last logit, clip them.
131+
if self.has_full_logits:
132+
current_token = next_token(logits[:, -1, :], temperature, top_p)
133+
else:
134+
current_token = next_token(logits, temperature, top_p)
135+
134136
if current_token == self.tokenizer.eos_id or (
135137
hasattr(self.tokenizer, "stop_tokens")
136138
and current_token in self.tokenizer.stop_tokens
137139
):
138140
break
141+
139142
tokens.append(current_token)
140143
i += 1
141144

extension/llm/export/builder.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -194,13 +194,13 @@ def export(self) -> "LLMEdgeManager":
194194
strict=True,
195195
).module()
196196
else:
197-
# pyre-fixme[8]: Attribute has type `Optional[GraphModule]`; used as
198-
# `Module`.
199197
print("Exporting with:")
200198
print(f"inputs: {self.example_inputs}")
201199
print(f"kwargs: {self.example_kwarg_inputs}")
202200
print(f"dynamic shapes: {dynamic_shape}")
203201

202+
# pyre-fixme[8]: Attribute has type `Optional[GraphModule]`; used as
203+
# `Module`.
204204
self.pre_autograd_graph_module = export_for_training(
205205
self.model,
206206
self.example_inputs,

0 commit comments

Comments
 (0)