Skip to content

Commit f79420d

Browse files
swolchokmalfet
authored andcommitted
Unbreak zero-temperature sampling (#599)
Fixes #581.
1 parent 46a6b9c commit f79420d

File tree

1 file changed

+3
-4
lines changed

1 file changed

+3
-4
lines changed

generate.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -155,10 +155,9 @@ def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = Non
155155
def sample(
156156
logits, need_probs: bool, temperature: float = 1.0, top_k: Optional[int] = None
157157
):
158-
# if temperature == 0 and not need_probs:
159-
# _, idx_next = torch.topk(logits, k=1, dim=-1)
160-
# idx_next = idx_next.squeeze(dim=(0, 1))
161-
# return (idx_next, None)
158+
if temperature == 0 and not need_probs:
159+
_, idx_next = torch.topk(logits[0,-1], k=1, dim=-1)
160+
return (idx_next, None)
162161
probs = logits_to_probs(logits[0, -1], temperature, top_k)
163162
idx_next = multinomial_sample_one_no_sync(probs)
164163
return idx_next, probs

0 commit comments

Comments
 (0)