Skip to content

Commit 4953e90

Browse files
llama : always sort logits before nucleus sampling (#812)
* Always sort logits before nucleus sampling * remove second normalization - fix windows build - remove normalization since std::discrete_distribution does not require it
1 parent cc9cee8 commit 4953e90

File tree

1 file changed

+3
-14
lines changed

1 file changed

+3
-14
lines changed

llama.cpp

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1236,19 +1236,13 @@ static llama_vocab::id llama_sample_top_p_top_k(
12361236
}
12371237
}
12381238

1239-
if (top_k > 0 && top_k < n_logits) {
1240-
sample_top_k(logits_id, top_k);
1241-
}
1242-
1243-
float maxl = -std::numeric_limits<float>::infinity();
1244-
for (const auto & kv : logits_id) {
1245-
maxl = Max(maxl, kv.first);
1246-
}
1239+
sample_top_k(logits_id, top_k > 0 ? Min(top_k, n_logits) : n_logits);
12471240

12481241
// compute probs for the top k tokens
12491242
std::vector<float> probs;
12501243
probs.reserve(logits_id.size());
12511244

1245+
float maxl = logits_id[0].first;
12521246
double sum = 0.0;
12531247
for (const auto & kv : logits_id) {
12541248
const float p = expf(kv.first - maxl);
@@ -1271,16 +1265,11 @@ static llama_vocab::id llama_sample_top_p_top_k(
12711265
break;
12721266
}
12731267
}
1274-
1275-
cumsum = 1.0/cumsum;
1276-
for (int i = 0; i < (int) probs.size(); i++) {
1277-
probs[i] *= cumsum;
1278-
}
12791268
}
12801269

12811270
//printf("\n");
12821271
//for (int i = 0; i < (int) 10; i++) {
1283-
// printf("%d: '%s' %f\n", i, vocab.id_to_token.at(logits_id[i].second).c_str(), probs[i]);
1272+
// printf("%d: '%s' %f\n", i, lctx.vocab.id_to_token.at(logits_id[i].second).tok.c_str(), probs[i]);
12841273
//}
12851274
//printf("\n\n");
12861275
//exit(0);

0 commit comments

Comments
 (0)