Skip to content

Commit e2ca877

Browse files
[executorch] Avoid division in Sampler::sample (#4656)
Differential Revision: D61041442
Pull Request resolved: #4646
Co-authored-by: Scott Wolchok <[email protected]>
1 parent f7684ad commit e2ca877

File tree

2 files changed

+5
-4
lines changed

2 files changed

+5
-4
lines changed

extension/llm/sampler/sampler.cpp

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -131,7 +131,7 @@ Sampler::Sampler(
131131
float topp,
132132
unsigned long long rng_seed)
133133
: vocab_size_(vocab_size),
134-
temperature_(temperature),
134+
inv_temperature_(static_cast<bool>(temperature) ? 1.0f / temperature : 0),
135135
topp_(topp),
136136
rng_state_(rng_seed) {}
137137

@@ -172,13 +172,13 @@ template <typename T>
172172
int32_t Sampler::sample(T* logits) {
173173
// sample the token given the logits and some hyperparameters
174174
int next;
175-
if (temperature_ == 0.0f) {
175+
if (inv_temperature_ == 0.0f) {
176176
// greedy argmax sampling: take the token with the highest probability
177177
next = sample_argmax(logits);
178178
} else {
179179
// apply the temperature to the logits
180180
for (int q = 0; q < vocab_size_; q++) {
181-
logits[q] /= temperature_;
181+
logits[q] *= inv_temperature_;
182182
}
183183
// apply softmax to the logits to get the probabilities for next token
184184
softmax(logits, vocab_size_);

extension/llm/sampler/sampler.h

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -51,7 +51,8 @@ class Sampler {
5151

5252
private:
5353
int32_t vocab_size_;
54-
float temperature_;
54+
// reciprocal of temperature, or 0 if temperature == 0.
55+
float inv_temperature_;
5556
float topp_;
5657
unsigned long long rng_state_;
5758
};

0 commit comments

Comments
 (0)