File tree Expand file tree Collapse file tree 2 files changed +5
-4
lines changed Expand file tree Collapse file tree 2 files changed +5
-4
lines changed Original file line number Diff line number Diff line change @@ -131,7 +131,7 @@ Sampler::Sampler(
131
131
float topp,
132
132
unsigned long long rng_seed)
133
133
: vocab_size_(vocab_size),
134
- temperature_ ( temperature),
134
+ inv_temperature_ ( static_cast < bool >( temperature) ? 1.0f / temperature : 0 ),
135
135
topp_(topp),
136
136
rng_state_(rng_seed) {}
137
137
@@ -172,13 +172,13 @@ template <typename T>
172
172
int32_t Sampler::sample (T* logits) {
173
173
// sample the token given the logits and some hyperparameters
174
174
int next;
175
- if (temperature_ == 0 .0f ) {
175
+ if (inv_temperature_ == 0 .0f ) {
176
176
// greedy argmax sampling: take the token with the highest probability
177
177
next = sample_argmax (logits);
178
178
} else {
179
179
// apply the temperature to the logits
180
180
for (int q = 0 ; q < vocab_size_; q++) {
181
- logits[q] /= temperature_ ;
181
+ logits[q] *= inv_temperature_ ;
182
182
}
183
183
// apply softmax to the logits to get the probabilities for next token
184
184
softmax (logits, vocab_size_);
Original file line number Diff line number Diff line change @@ -51,7 +51,8 @@ class Sampler {
51
51
52
52
private:
53
53
int32_t vocab_size_;
54
- float temperature_;
54
+ // reciprocal of temperature, or 0 if temperature == 0.
55
+ float inv_temperature_;
55
56
float topp_;
56
57
unsigned long long rng_state_;
57
58
};
You can’t perform that action at this time.
0 commit comments