
Commit 03f7e33

Cleanup STL headers + fix embedding examples + minor stuff
1 parent 55ad42a commit 03f7e33

4 files changed: 20 additions & 26 deletions

examples/embedding/embedding.cpp

Lines changed: 5 additions & 10 deletions
@@ -1,15 +1,6 @@
 #include "common.h"
 #include "llama.h"
 
-#include <cassert>
-#include <cinttypes>
-#include <cmath>
-#include <cstdio>
-#include <cstring>
-#include <fstream>
-#include <string>
-#include <vector>
-
 int main(int argc, char ** argv) {
     gpt_params params;
     params.model = "models/llama-7B/ggml-model.bin";
@@ -94,9 +85,13 @@ int main(int argc, char ** argv) {
             }
         }
 
+        const int n_embd = llama_n_embd(ctx);
         const auto embeddings = llama_get_embeddings(ctx);
 
-        // TODO: print / use the embeddings
+        for (int i = 0; i < n_embd; i++) {
+            printf("%f ", embeddings[i]);
+        }
+        printf("\n");
     }
 
     llama_print_timings(ctx);
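
The example now prints the raw embedding vector returned by llama_get_embeddings. A possible downstream use, sketched here purely as an illustration: comparing two such vectors with cosine similarity. The helper below is hypothetical and not part of this commit; it only assumes two float buffers of length llama_n_embd(ctx).

#include <cmath>

// Hypothetical helper (not part of this commit): cosine similarity between
// two embedding vectors of length n_embd, e.g. two buffers previously copied
// out of llama_get_embeddings(ctx).
static float embd_cosine_similarity(const float * a, const float * b, int n_embd) {
    float dot = 0.0f, norm_a = 0.0f, norm_b = 0.0f;
    for (int i = 0; i < n_embd; i++) {
        dot    += a[i] * b[i];
        norm_a += a[i] * a[i];
        norm_b += b[i] * b[i];
    }
    return dot / (std::sqrt(norm_a) * std::sqrt(norm_b) + 1e-6f);
}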

examples/perplexity/perplexity.cpp

Lines changed: 0 additions & 8 deletions
@@ -1,14 +1,6 @@
 #include "common.h"
 #include "llama.h"
 
-#include <cassert>
-#include <cinttypes>
-#include <cmath>
-#include <cstdio>
-#include <cstring>
-#include <string>
-#include <vector>
-
 std::vector<double> softmax(const std::vector<float>& logits) {
     std::vector<double> probs(logits.size());
     float max_logit = logits[0];
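
Only the includes change here; the softmax helper whose first lines appear as context is untouched. For readers unfamiliar with the max-subtraction it starts with, here is a standalone sketch of a numerically stable softmax of this shape (not the file's exact body):

#include <algorithm>
#include <cmath>
#include <vector>

// Sketch of a numerically stable softmax: subtracting the maximum logit
// before exponentiating keeps std::exp from overflowing for large logits.
static std::vector<double> softmax_sketch(const std::vector<float> & logits) {
    std::vector<double> probs(logits.size());
    const float max_logit = *std::max_element(logits.begin(), logits.end());
    double sum_exp = 0.0;
    for (size_t i = 0; i < logits.size(); i++) {
        probs[i] = std::exp(logits[i] - max_logit);
        sum_exp += probs[i];
    }
    for (auto & p : probs) {
        p /= sum_exp;
    }
    return probs;
}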

llama.cpp

Lines changed: 14 additions & 8 deletions
@@ -1261,10 +1261,10 @@ static llama_vocab::id llama_sample_top_p_top_k(
         double repeat_penalty) {
     auto & rng = lctx.rng;
 
-    const auto & vocab = lctx.vocab;
-    const auto & logits = lctx.logits;
+    const int n_logits = lctx.model.hparams.n_vocab;
 
-    int n_logits = vocab.id_to_token.size();
+    const auto & logits = lctx.logits;
+    const auto * plogits = logits.data() + logits.size() - n_logits;
 
     std::vector<std::pair<double, llama_vocab::id>> logits_id;
     logits_id.reserve(n_logits);
@@ -1276,13 +1276,13 @@ static llama_vocab::id llama_sample_top_p_top_k(
         // credit https://github.com/facebookresearch/llama/compare/main...shawwn:llama:main
         if (std::find(last_n_tokens.begin(), last_n_tokens.end(), i) != last_n_tokens.end()) {
             // if score < 0 then repetition penalty has to multiplied to reduce the previous token probability
-            if (logits[i] < 0.0) {
-                logits_id.push_back(std::make_pair(logits[i]*scale*repeat_penalty, i));
+            if (plogits[i] < 0.0) {
+                logits_id.push_back(std::make_pair(plogits[i]*scale*repeat_penalty, i));
             } else {
-                logits_id.push_back(std::make_pair(logits[i]*scale/repeat_penalty, i));
+                logits_id.push_back(std::make_pair(plogits[i]*scale/repeat_penalty, i));
             }
         } else {
-            logits_id.push_back(std::make_pair(logits[i]*scale, i));
+            logits_id.push_back(std::make_pair(plogits[i]*scale, i));
         }
     }
 }
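
The switch from logits[i] to plogits[i] matters when the context was created with logits_all: lctx.logits then holds one row of n_vocab values per evaluated token, so indexing from the front would sample from the first token's logits instead of the last. plogits = logits.data() + logits.size() - n_logits always points at the final row. A small self-contained illustration of that pointer arithmetic (toy sizes, not llama.cpp code):

#include <cstdio>
#include <vector>

int main() {
    // Toy layout mirroring the logits_all case: one row of n_vocab logits per
    // evaluated token, rows concatenated in evaluation order.
    const int n_vocab  = 4;
    const int n_tokens = 3;
    std::vector<float> logits(n_tokens * n_vocab);
    for (size_t i = 0; i < logits.size(); i++) {
        logits[i] = (float) i;
    }

    // Same arithmetic as the patch: step back one row from the end of the
    // buffer to get the logits of the most recently evaluated token.
    const float * plogits = logits.data() + logits.size() - n_vocab;
    for (int i = 0; i < n_vocab; i++) {
        printf("%f ", plogits[i]); // prints 8 9 10 11, the last row
    }
    printf("\n");
    return 0;
}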
@@ -1677,14 +1677,16 @@ struct llama_context * llama_init_from_file(
     }
 
     const auto & hparams = ctx->model.hparams;
+
+    // resized during inference
     if (params.logits_all) {
         ctx->logits.reserve(hparams.n_ctx*hparams.n_vocab);
     } else {
         ctx->logits.reserve(hparams.n_ctx);
     }
 
     if (params.embedding){
-        ctx->embedding.reserve(hparams.n_embd);
+        ctx->embedding.resize(hparams.n_embd);
     }
 
     ctx->buf_compute.resize(MEM_REQ_EVAL.at(ctx->model.type));
@@ -1761,6 +1763,10 @@ int llama_n_ctx(struct llama_context * ctx) {
     return ctx->model.hparams.n_ctx;
 }
 
+int llama_n_embd(struct llama_context * ctx) {
+    return ctx->model.hparams.n_embd;
+}
+
 float * llama_get_logits(struct llama_context * ctx) {
     return ctx->logits.data();
 }
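
The embedding buffer change is the actual bug fix in the earlier hunk: std::vector::reserve only allocates capacity and leaves size() at zero, so a buffer that was merely reserved cannot legally be filled or read back through data(); resize creates the n_embd elements. A minimal demonstration of the distinction (plain standard-library behavior, independent of llama.cpp):

#include <cstdio>
#include <vector>

int main() {
    std::vector<float> reserved;
    reserved.reserve(8);   // capacity >= 8, but size() is still 0
    printf("after reserve: size=%zu capacity=%zu\n", reserved.size(), reserved.capacity());

    std::vector<float> resized;
    resized.resize(8);     // size() == 8, elements value-initialized to 0.0f
    printf("after resize:  size=%zu capacity=%zu\n", resized.size(), resized.capacity());

    // Only elements within size() may be written through data(), which is why
    // the context's embedding buffer has to be resized, not just reserved.
    resized.data()[0] = 1.0f;
    return 0;
}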

llama.h

Lines changed: 1 addition & 0 deletions
@@ -109,6 +109,7 @@ extern "C" {
 
     LLAMA_API int llama_n_vocab(struct llama_context * ctx);
     LLAMA_API int llama_n_ctx  (struct llama_context * ctx);
+    LLAMA_API int llama_n_embd (struct llama_context * ctx);
 
     // Token logits obtained from the last call to llama_eval()
     // The logits for the last token are stored in the last row
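
With the header change, all three model dimensions are reachable through the C API. A short usage sketch (assumes a valid llama_context created elsewhere; not part of the commit):

#include <cstdio>
#include "llama.h"

// Print the model dimensions exposed by the C API.
// Assumes ctx was obtained from llama_init_from_file earlier in the program.
static void print_model_dims(struct llama_context * ctx) {
    printf("n_vocab = %d\n", llama_n_vocab(ctx));
    printf("n_ctx   = %d\n", llama_n_ctx(ctx));
    printf("n_embd  = %d\n", llama_n_embd(ctx));
}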
