
Commit d2d5254

perplexity binary support
1 parent 0f86ae9 commit d2d5254

File tree

7 files changed: +70 -26 lines changed


examples/common.cpp

Lines changed: 4 additions & 0 deletions
@@ -439,6 +439,10 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.logdir = argv[i];
+
+            if (params.logdir.back() != '/') {
+                params.logdir += "/";
+            }
         } else if (arg == "--perplexity") {
             params.perplexity = true;
         } else if (arg == "--hellaswag") {

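Normalizing the trailing slash once at parse time means every binary that writes a logfile can concatenate params.logdir with a filename directly; it also appears to be why the final mkdir in create_directory_with_parents (llama-util.h, below) becomes redundant, since a path ending in '/' has its last component created by the slash-by-slash loop. A minimal standalone sketch of the normalized behavior, using an illustrative directory name and timestamp (the real timestamp comes from get_sortable_timestamp()):

#include <cstdio>
#include <string>

int main() {
    std::string logdir = "logs";   // e.g. the value passed on the command line
    if (logdir.back() != '/') {
        logdir += "/";             // after this commit, done once in gpt_params_parse
    }
    // downstream code (main.cpp, perplexity.cpp) concatenates directly:
    const std::string logfile_path = logdir + "2023-01-01-000000" + ".yml"; // timestamp value is illustrative
    printf("%s\n", logfile_path.c_str());  // -> logs/2023-01-01-000000.yml
    return 0;
}
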
examples/main/main.cpp

Lines changed: 6 additions & 8 deletions
@@ -830,20 +830,18 @@ int main(int argc, char ** argv) {
     if (!params.logdir.empty()) {
         const std::string timestamp = get_sortable_timestamp();

-        std::string logdir = params.logdir;
-        const bool success = create_directory_with_parents(logdir);
+        const bool success = create_directory_with_parents(params.logdir);
         if (success) {
-            if (logdir.back() != '/') {
-                logdir += "/";
-            }

-            FILE * logfile = fopen((logdir + timestamp + ".yml").c_str(), "w");
+            FILE * logfile = fopen((params.logdir + timestamp + ".yml").c_str(), "w");
             fprintf(logfile, "binary: main\n");
             dump_non_result_info_yaml(logfile, params, timestamp, input_tokens);
-            llama_dump_result_info_yaml(logfile, ctx, output_ss.str().c_str(), output_tokens.data(), output_tokens.size());
+            llama_dump_result_info_yaml(
+                logfile, ctx, output_ss.str().c_str(), output_tokens.data(), output_tokens.size(), NULL, 0);
             fclose(logfile);
         } else {
-            fprintf(stderr, "%s: warning: failed to create logdir %s, cannot write logfile\n", __func__, logdir.c_str());
+            fprintf(stderr, "%s: warning: failed to create logdir %s, cannot write logfile\n",
+                __func__, params.logdir.c_str());
         }
     }
     if (ctx_guidance) { llama_free(ctx_guidance); }

examples/perplexity/perplexity.cpp

Lines changed: 33 additions & 4 deletions
@@ -1,10 +1,13 @@
 #include "common.h"
 #include "llama.h"
+#include "llama-util.h"
 #include "build-info.h"

 #include <cmath>
 #include <ctime>
 #include <sstream>
+#include <utility>
+#include <vector>

 #if defined(_MSC_VER)
 #pragma warning(disable: 4244 4267) // possible loss of data
@@ -26,12 +29,13 @@ std::vector<float> softmax(const std::vector<float>& logits) {
     return probs;
 }

-void perplexity(llama_context * ctx, const gpt_params & params) {
+std::pair<std::vector<llama_token>, std::vector<float>> perplexity(llama_context * ctx, const gpt_params & params) {
     // Download: https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip?ref=salesforce-research
     // Run `./perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw`
     // Output: `perplexity: 13.5106 [114/114]`
     // BOS tokens will be added for each chunk before eval
-    auto tokens = ::llama_tokenize(ctx, params.prompt, true);
+    std::vector<llama_token> tokens = ::llama_tokenize(ctx, params.prompt, true);
+    std::vector<float> probs;

     const int n_chunk_max = tokens.size() / params.n_ctx;

@@ -68,7 +72,7 @@ void perplexity(llama_context * ctx, const gpt_params & params) {

         if (llama_eval(ctx, tokens.data() + batch_start, batch_size, j * n_batch, params.n_threads)) {
             fprintf(stderr, "%s : failed to eval\n", __func__);
-            return;
+            return std::make_pair(tokens, probs);
         }

         // restore the original token in case it was set to BOS
@@ -110,6 +114,7 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
                 logits.begin() + (j + 1) * n_vocab);

             const float prob = softmax(tok_logits)[tokens[start + j + 1]];
+            probs.push_back(prob);

             nll += -std::log(prob);
             ++count;
@@ -119,6 +124,8 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
         fflush(stdout);
     }
     printf("\n");
+
+    return std::make_pair(tokens, probs);
 }

 void hellaswag_score(llama_context * ctx, const gpt_params & params) {
@@ -341,13 +348,35 @@ int main(int argc, char ** argv) {
             params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
     }

+    std::vector<llama_token> tokens;
+    std::vector<float> probs;
     if (params.hellaswag) {
         hellaswag_score(ctx, params);
     } else {
-        perplexity(ctx, params);
+        auto ret = perplexity(ctx, params);
+        tokens = ret.first;
+        probs = ret.second;
     }

     llama_print_timings(ctx);
+
+    if (!params.logdir.empty()) {
+        const std::string timestamp = get_sortable_timestamp();
+
+        const bool success = create_directory_with_parents(params.logdir);
+        if (success) {
+
+            FILE * logfile = fopen((params.logdir + timestamp + ".yml").c_str(), "w");
+            fprintf(logfile, "binary: perplexity\n");
+            dump_non_result_info_yaml(logfile, params, timestamp, tokens);
+            llama_dump_result_info_yaml(logfile, ctx, NULL, NULL, 0, probs.data(), probs.size());
+            fclose(logfile);
+        } else {
+            fprintf(stderr, "%s: warning: failed to create logdir %s, cannot write logfile\n",
+                __func__, params.logdir.c_str());
+        }
+    }
+
     llama_free(ctx);
     llama_free_model(model);

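The probs vector returned by perplexity() holds the per-token probabilities that feed the running nll += -std::log(prob) accumulation above; logging them makes the final number reproducible offline. A minimal sketch of that reduction, using a hypothetical helper and stand-in values:

#include <cmath>
#include <cstdio>
#include <vector>

// Perplexity is the exponential of the mean negative log-likelihood over the scored tokens.
static double perplexity_from_probs(const std::vector<float> & probs) {
    double nll = 0.0;
    for (const float p : probs) {
        nll += -std::log(p);
    }
    return std::exp(nll / probs.size());
}

int main() {
    const std::vector<float> probs = {0.25f, 0.50f, 0.125f}; // stand-in values, not real model output
    printf("perplexity: %.4f\n", perplexity_from_probs(probs));
    return 0;
}
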
llama-util.h

Lines changed: 1 addition & 7 deletions
@@ -585,12 +585,6 @@ static bool create_directory_with_parents(const std::string & path) {
         pos_slash += 1;
     }

-    // finally, create the directory for the logs
-    const int ret = mkdir(path.c_str(), 0755);
-    if (ret != 0) {
-        return false;
-    }
-
     return true;
 }

@@ -612,7 +606,7 @@ static void dump_vector_int_yaml(FILE * stream, const char * prop_name, const st

 static void dump_string_yaml_multiline(FILE * stream, const char * prop_name, const char * data,
                                        const bool remove_first) {
-    std::string data_str(data);
+    std::string data_str(data == NULL ? "" : data);

     if (data_str.empty()) {
         fprintf(stream, "%s:\n", prop_name);

llama.cpp

Lines changed: 17 additions & 4 deletions
@@ -4399,8 +4399,10 @@ const char * llama_print_system_info(void) {
     return s.c_str();
 }

-void llama_dump_result_info_yaml(FILE * stream, const llama_context * ctx, const char * output_str,
-                                 const int * output_tokens, const int n_output_tokens) {
+void llama_dump_result_info_yaml(
+    FILE * stream, const llama_context * ctx, const char * output_str, const int * output_tokens,
+    const int n_output_tokens, const float * probs, const int n_probs) {
+
     fprintf(stream, "\n");
     fprintf(stream, "###########\n");
     fprintf(stream, "# Results #\n");
@@ -4422,8 +4424,19 @@ void llama_dump_result_info_yaml(FILE * stream, const llama_context * ctx, const
     fprintf(stream, "n_sample: %d # number of sampled tokens\n", ctx->n_sample);
     dump_string_yaml_multiline(stream, "output", output_str, false);

-    const std::vector<int> output_token_vector(output_tokens, output_tokens + n_output_tokens);
-    dump_vector_int_yaml(stream, "output_tokens", output_token_vector);
+    if (output_tokens == NULL) {
+        fprintf(stream, "output_tokens:\n");
+    } else {
+        const std::vector<int> output_token_vector(output_tokens, output_tokens + n_output_tokens);
+        dump_vector_int_yaml(stream, "output_tokens", output_token_vector);
+    }
+
+    if (probs == NULL) {
+        fprintf(stream, "probs:\n");
+    } else {
+        const std::vector<float> prob_vector(probs, probs + n_probs);
+        dump_vector_float_yaml(stream, "probs", prob_vector);
+    }

     fprintf(stream, "t_eval_us: %ld # total microseconds spent generating tokens\n", ctx->t_eval_us);
     fprintf(stream, "t_load_us: %ld # total microseconds spent loading the model\n", ctx->t_load_us);

llama.h

Lines changed: 3 additions & 2 deletions
@@ -471,8 +471,9 @@ extern "C" {
     // Print system information
     LLAMA_API const char * llama_print_system_info(void);

-    LLAMA_API void llama_dump_result_info_yaml(FILE * stream, const llama_context * ctx, const char * output_str,
-                                               const int * output_tokens, int n_output_tokens);
+    LLAMA_API void llama_dump_result_info_yaml(
+        FILE * stream, const llama_context * ctx, const char * output_str, const int * output_tokens,
+        int n_output_tokens, const float * probs, int n_probs);

 #ifdef __cplusplus
 }

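With the two extra parameters, one dump function serves both binaries: main logs generated output and its token ids but no probabilities, while perplexity logs probabilities but no output, matching the call sites above. A sketch of the two call patterns, wrapped in hypothetical helpers purely for illustration:

#include <cstdio>
#include <vector>

#include "llama.h"

// Hypothetical wrappers; the actual binaries call llama_dump_result_info_yaml directly.
static void log_main_results(FILE * logfile, const llama_context * ctx,
                             const char * output_str, const std::vector<int> & output_tokens) {
    // text generation: output string + token ids, NULL/0 for probabilities
    llama_dump_result_info_yaml(
        logfile, ctx, output_str, output_tokens.data(), output_tokens.size(), NULL, 0);
}

static void log_perplexity_results(FILE * logfile, const llama_context * ctx,
                                   const std::vector<float> & probs) {
    // perplexity: NULL/0 for output, per-token probabilities from the evaluation
    llama_dump_result_info_yaml(
        logfile, ctx, NULL, NULL, 0, probs.data(), probs.size());
}

Whichever pointer is NULL produces an empty YAML key (output_tokens: or probs:) on the llama.cpp side, so both binaries emit logfiles with the same schema.
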
run_with_preset.py

Lines changed: 6 additions & 1 deletion
@@ -1,5 +1,6 @@
 #!/usr/bin/env python3

+import os
 import subprocess
 import sys

@@ -23,7 +24,11 @@

 props = {prop.replace("_", "-"): val for prop, val in props.items()}

-command_list = ["./main"]
+binary = props.pop("binary", "main")
+if os.path.exists(f"./{binary}"):
+    binary = f"./{binary}"
+
+command_list = [binary]

 for cli_arg in CLI_ARGS_MAIN:
     value = props.get(cli_arg, None)
