
Commit 260b4a5

speculative : add initial poc
1 parent b7f2aa9 commit 260b4a5

File tree

16 files changed: +410 -114 lines

common/common.cpp

Lines changed: 11 additions & 4 deletions
@@ -24,9 +24,7 @@
 
 #if defined(_WIN32)
 #define WIN32_LEAN_AND_MEAN
-#ifndef NOMINMAX
-# define NOMINMAX
-#endif
+#define NOMINMAX
 #include <codecvt>
 #include <locale>
 #include <windows.h>
@@ -317,6 +315,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.model = argv[i];
+        } else if (arg == "-md" || arg == "--model-draft") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.model_draft = argv[i];
         } else if (arg == "-a" || arg == "--alias") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -669,6 +673,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     fprintf(stdout, "  --lora-base FNAME     optional model to use as a base for the layers modified by the LoRA adapter\n");
     fprintf(stdout, "  -m FNAME, --model FNAME\n");
     fprintf(stdout, "                        model path (default: %s)\n", params.model.c_str());
+    fprintf(stdout, "  -md FNAME, --model-draft FNAME\n");
+    fprintf(stdout, "                        draft model for speculative sampling (default: %s)\n", params.model.c_str());
     fprintf(stdout, "  -ld LOGDIR, --logdir LOGDIR\n");
     fprintf(stdout, "                        path under which to save YAML logs (no logging if unset)\n");
     fprintf(stdout, "\n");
@@ -1029,7 +1035,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
     dump_string_yaml_multiline(stream, "grammar", params.grammar.c_str());
     fprintf(stream, "grammar-file: # never logged, see grammar instead. Can still be specified for input.\n");
     fprintf(stream, "hellaswag: %s # default: false\n", params.hellaswag ? "true" : "false");
-    fprintf(stream, "hellaswag_tasks: %zu # default: 400\n", params.hellaswag_tasks);
+    fprintf(stream, "hellaswag_tasks: %ld # default: 400\n", params.hellaswag_tasks);
 
     const auto logit_bias_eos = params.logit_bias.find(llama_token_eos(lctx));
     const bool ignore_eos = logit_bias_eos != params.logit_bias.end() && logit_bias_eos->second == -INFINITY;
@@ -1062,6 +1068,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
     fprintf(stream, "mirostat_lr: %f # default: 0.1\n", params.mirostat_eta);
     fprintf(stream, "mlock: %s # default: false\n", params.use_mlock ? "true" : "false");
     fprintf(stream, "model: %s # default: models/7B/ggml-model.bin\n", params.model.c_str());
+    fprintf(stream, "model_draft: %s # default:\n", params.model_draft.c_str());
     fprintf(stream, "mtest: %s # default: false\n", params.mem_test ? "true" : "false");
     fprintf(stream, "multiline_input: %s # default: false\n", params.multiline_input ? "true" : "false");
     fprintf(stream, "n_gpu_layers: %d # default: 0\n", params.n_gpu_layers);

common/common.h

Lines changed: 1 addition & 0 deletions
@@ -63,6 +63,7 @@ struct gpt_params {
     float cfg_scale = 1.f; // How strong is guidance
 
     std::string model = "models/7B/ggml-model-f16.gguf"; // model path
+    std::string model_draft = ""; // draft model for speculative sampling
     std::string model_alias = "unknown"; // model alias
     std::string prompt = "";
     std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state
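
A minimal sketch of how the new field can be consumed, assuming the llama_init_from_gpt_params() helper that common.h exposes at this point in the tree (an illustration of intended usage, not code from this commit):

    #include "common.h"
    #include "llama.h"

    #include <tuple>

    int main(int argc, char ** argv) {
        gpt_params params;
        if (!gpt_params_parse(argc, argv, params)) {
            return 1;
        }

        llama_backend_init(params.numa);

        // load the target model from -m / --model
        llama_model   * model_tgt = nullptr;
        llama_context * ctx_tgt   = nullptr;
        std::tie(model_tgt, ctx_tgt) = llama_init_from_gpt_params(params);

        // reuse the same loader for the draft model from -md / --model-draft
        params.model = params.model_draft;

        llama_model   * model_dft = nullptr;
        llama_context * ctx_dft   = nullptr;
        std::tie(model_dft, ctx_dft) = llama_init_from_gpt_params(params);

        // ... the draft-and-verify loop would go here ...

        llama_free(ctx_tgt);
        llama_free_model(model_tgt);
        llama_free(ctx_dft);
        llama_free_model(model_dft);

        llama_backend_free();
        return 0;
    }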

common/console.cpp

Lines changed: 0 additions & 1 deletion
@@ -235,7 +235,6 @@ namespace console {
 
     int estimateWidth(char32_t codepoint) {
 #if defined(_WIN32)
-        (void)codepoint;
         return 1;
 #else
         return wcwidth(codepoint);

common/log.h

Lines changed: 16 additions & 16 deletions
@@ -154,7 +154,7 @@ inline std::string log_filename_generator_impl(const std::string & log_file_base
 // #include "log.h"
 //
 #ifndef LOG_NO_TIMESTAMPS
-#ifndef _MSC_VER
+#ifndef _WIN32
 #define LOG_TIMESTAMP_FMT "[%" PRIu64 "] "
 #define LOG_TIMESTAMP_VAL , (std::chrono::duration_cast<std::chrono::duration<std::uint64_t>>(std::chrono::system_clock::now().time_since_epoch())).count()
 #else
@@ -167,7 +167,7 @@ inline std::string log_filename_generator_impl(const std::string & log_file_base
 #endif
 
 #ifdef LOG_TEE_TIMESTAMPS
-#ifndef _MSC_VER
+#ifndef _WIN32
 #define LOG_TEE_TIMESTAMP_FMT "[%" PRIu64 "] "
 #define LOG_TEE_TIMESTAMP_VAL , (std::chrono::duration_cast<std::chrono::duration<std::uint64_t>>(std::chrono::system_clock::now().time_since_epoch())).count()
 #else
@@ -187,7 +187,7 @@ inline std::string log_filename_generator_impl(const std::string & log_file_base
 // #include "log.h"
 //
 #ifndef LOG_NO_FILE_LINE_FUNCTION
-#ifndef _MSC_VER
+#ifndef _WIN32
 #define LOG_FLF_FMT "[%24s:%5d][%24s] "
 #define LOG_FLF_VAL , __FILE__, __LINE__, __FUNCTION__
 #else
@@ -200,7 +200,7 @@ inline std::string log_filename_generator_impl(const std::string & log_file_base
 #endif
 
 #ifdef LOG_TEE_FILE_LINE_FUNCTION
-#ifndef _MSC_VER
+#ifndef _WIN32
 #define LOG_TEE_FLF_FMT "[%24s:%5d][%24s] "
 #define LOG_TEE_FLF_VAL , __FILE__, __LINE__, __FUNCTION__
 #else
@@ -224,7 +224,7 @@ enum LogTriState
 // INTERNAL, DO NOT USE
 // USE LOG() INSTEAD
 //
-#ifndef _MSC_VER
+#ifndef _WIN32
 #define LOG_IMPL(str, ...) \
 { \
     if (LOG_TARGET != nullptr) \
@@ -247,7 +247,7 @@ enum LogTriState
 // INTERNAL, DO NOT USE
 // USE LOG_TEE() INSTEAD
 //
-#ifndef _MSC_VER
+#ifndef _WIN32
 #define LOG_TEE_IMPL(str, ...) \
 { \
     if (LOG_TARGET != nullptr) \
@@ -284,7 +284,7 @@ enum LogTriState
 // Main LOG macro.
 // behaves like printf, and supports arguments the exact same way.
 //
-#ifndef _MSC_VER
+#ifndef _WIN32
 #define LOG(...) LOG_IMPL(__VA_ARGS__, "")
 #else
 #define LOG(str, ...) LOG_IMPL("%s" str, "", __VA_ARGS__, "")
@@ -298,14 +298,14 @@ enum LogTriState
 // Secondary target can be changed just like LOG_TARGET
 // by defining LOG_TEE_TARGET
 //
-#ifndef _MSC_VER
+#ifndef _WIN32
 #define LOG_TEE(...) LOG_TEE_IMPL(__VA_ARGS__, "")
 #else
 #define LOG_TEE(str, ...) LOG_TEE_IMPL("%s" str, "", __VA_ARGS__, "")
 #endif
 
 // LOG macro variants with auto endline.
-#ifndef _MSC_VER
+#ifndef _WIN32
 #define LOGLN(...) LOG_IMPL(__VA_ARGS__, "\n")
 #define LOG_TEELN(...) LOG_TEE_IMPL(__VA_ARGS__, "\n")
 #else
@@ -341,14 +341,14 @@ inline FILE *log_handler1_impl(bool change = false, LogTriState disable = LogTri
         }
     }
 
-    if (_disabled)
-    {
-        // Log is disabled
-        return nullptr;
-    }
-
     if (_initialized)
     {
+        if (_disabled)
+        {
+            // Log is disabled
+            return nullptr;
+        }
+
         // with fallback in case something went wrong
         return logfile ? logfile : stderr;
     }
@@ -461,7 +461,7 @@ inline void log_test()
     LOG("13 Hello World this time in yet new file?\n")
     log_set_target(log_filename_generator("llama_autonamed", "log"));
     LOG("14 Hello World in log with generated filename!\n")
-#ifdef _MSC_VER
+#ifdef _WIN32
     LOG_TEE("15 Hello msvc TEE without arguments\n")
     LOG_TEE("16 Hello msvc TEE with (%d)(%s) arguments\n", 1, "test")
     LOG_TEELN("17 Hello msvc TEELN without arguments\n")

examples/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -23,6 +23,7 @@ else()
     add_subdirectory(train-text-from-scratch)
     add_subdirectory(convert-llama2c-to-ggml)
     add_subdirectory(simple)
+    add_subdirectory(speculative)
     add_subdirectory(embd-input)
     add_subdirectory(llama-bench)
     add_subdirectory(beam-search)
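
The new subdirectory implies an examples/speculative/CMakeLists.txt of its own; it is not shown on this page, but it presumably mirrors the boilerplate of the sibling examples, roughly:

    set(TARGET speculative)
    add_executable(${TARGET} speculative.cpp)
    target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})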

examples/baby-llama/baby-llama.cpp

Lines changed: 5 additions & 0 deletions
@@ -1617,10 +1617,15 @@ int main(int argc, char ** argv) {
 
         float error_before_opt = ggml_get_f32_1d(e, 0);
 
+        struct ggml_opt_params opt_params_adam = ggml_opt_default_params(GGML_OPT_ADAM);
         struct ggml_opt_params opt_params_lbfgs = ggml_opt_default_params(GGML_OPT_LBFGS);
+        opt_params_adam.print_forward_graph = false;
+        opt_params_adam.print_backward_graph = false;
         opt_params_lbfgs.print_forward_graph = false;
         opt_params_lbfgs.print_backward_graph = false;
+        opt_params_adam.adam.n_iter = 16;
         opt_params_lbfgs.lbfgs.n_iter = 16;
+        // ggml_opt(ctx0, opt_params_adam, e);
         ggml_opt(ctx0, opt_params_lbfgs, e);
         //
         ggml_build_forward_expand(&gf, e);
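
The hunk prepares an ADAM parameter block next to the existing L-BFGS one, so switching optimizers becomes a one-line change at the ggml_opt() call site. A stripped-down, self-contained illustration of that pattern on a toy objective (an assumption-level sketch, not baby-llama code):

    #include "ggml.h"

    int main() {
        struct ggml_init_params ip = { 16*1024*1024, NULL, false }; // mem_size, mem_buffer, no_alloc
        struct ggml_context * ctx0 = ggml_init(ip);

        // toy objective: e = x^2, minimized over the parameter x
        struct ggml_tensor * x = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
        ggml_set_f32(x, 2.0f);
        ggml_set_param(ctx0, x);
        struct ggml_tensor * e = ggml_sqr(ctx0, x);

        struct ggml_opt_params opt_params_adam  = ggml_opt_default_params(GGML_OPT_ADAM);
        struct ggml_opt_params opt_params_lbfgs = ggml_opt_default_params(GGML_OPT_LBFGS);
        opt_params_adam.adam.n_iter   = 16;
        opt_params_lbfgs.lbfgs.n_iter = 16;

        const bool use_adam = false; // flip to compare the two optimizers
        ggml_opt(ctx0, use_adam ? opt_params_adam : opt_params_lbfgs, e);

        ggml_free(ctx0);
        return 0;
    }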

examples/beam-search/beam-search.cpp

Lines changed: 3 additions & 5 deletions
@@ -22,9 +22,7 @@
 #include <unistd.h>
 #elif defined (_WIN32)
 #define WIN32_LEAN_AND_MEAN
-#ifndef NOMINMAX
-# define NOMINMAX
-#endif
+#define NOMINMAX
 #include <windows.h>
 #include <signal.h>
 #endif
@@ -75,7 +73,7 @@ void beam_search_callback(void * callback_data_ptr, llama_beams_state beams_stat
         assert(0u < beams_state.n_beams);
         const llama_token * tokens = beams_state.beam_views[0].tokens;
         std::copy(tokens, tokens + n, callback_data.response.end() - n);
-        printf("%zu", n);
+        printf("%lu", n);
     }
     fflush(stdout);
 #if 1 // DEBUG: print current beams for this iteration
@@ -147,7 +145,7 @@ int main(int argc, char ** argv)
 
     if (tokens_list.size() > max_tokens_list_size)
     {
-        fprintf( stderr , "%s: error: prompt too long (%zu tokens, max %zu)\n" ,
+        fprintf( stderr , "%s: error: prompt too long (%lu tokens, max %lu)\n" ,
             __func__ , tokens_list.size() , max_tokens_list_size );
         return 1;
    }
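
A side note on the %zu to %lu swaps in this file: %zu is the C99/C++11 conversion for size_t, whereas %lu expects unsigned long, which is narrower than size_t on LLP64 Windows. A small self-contained illustration:

    #include <cstdio>
    #include <vector>

    int main() {
        std::vector<int> tokens(5, 0);
        printf("%zu tokens\n", tokens.size());                 // matches size_t directly
        printf("%lu tokens\n", (unsigned long) tokens.size()); // cast needed for portability
        return 0;
    }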

examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp

Lines changed: 35 additions & 38 deletions
@@ -75,7 +75,7 @@ typedef struct {
     int seq_len; // max sequence length
 } Config;
 
-struct TransformerWeights {
+typedef struct {
     // token embedding table
     float* token_embedding_table; // (vocab_size, dim)
     // weights for rmsnorms
@@ -97,22 +97,7 @@
     // float* freq_cis_imag; // (seq_len, dim/2)
     // (optional) classifier weights for the logits, on the last layer
     float* wcls;
-
-    ~TransformerWeights() {
-        delete[] token_embedding_table;
-        delete[] rms_att_weight;
-        delete[] rms_ffn_weight;
-        delete[] wq;
-        delete[] wk;
-        delete[] wv;
-        delete[] wo;
-        delete[] w1;
-        delete[] w2;
-        delete[] w3;
-        delete[] rms_final_weight;
-        delete[] wcls;
-    }
-};
+} TransformerWeights;
 
 void malloc_weights(TransformerWeights* w, Config* p, bool shared_weights) {
     // we calloc instead of malloc to keep valgrind happy
@@ -188,6 +173,21 @@ int checkpoint_init_weights(TransformerWeights *w, Config* p, FILE* f, bool shar
     return 0;
 }
 
+void free_weights(TransformerWeights* w) {
+    delete w->token_embedding_table;
+    delete w->rms_att_weight;
+    delete w->rms_ffn_weight;
+    delete w->wq;
+    delete w->wk;
+    delete w->wv;
+    delete w->wo;
+    delete w->w1;
+    delete w->w2;
+    delete w->w3;
+    delete w->rms_final_weight;
+    if (w->wcls) delete w->wcls;
+}
+
 void print_sample_weights(TransformerWeights *w){
     printf("----- Quick print of first of the weight vales of all the variables\n");
     printf("%f\n", w->token_embedding_table[0]);
@@ -596,10 +596,6 @@ void load_vocab(const char *filename, Config *config, struct llama_vocab *vocab)
     // assume llama2.c vocabulary
     printf("Assuming llama2.c vocabulary since %s is not a gguf file\n", filename);
     llama_file file(filename, "rb");
-    if (!file.fp) {
-        fprintf(stderr, "error: %s: %s\n", strerror(errno), filename);
-        exit(1);
-    }
     const int n_vocab = config->vocab_size;
     /* uint32_t max_token_length = */ file.read_u32(); // unused
     vocab->id_to_token.resize(n_vocab);
@@ -637,7 +633,7 @@ void load_vocab(const char *filename, Config *config, struct llama_vocab *vocab)
     }
 }
 
-void convert_weights_ak_to_gg(struct ggml_tensor * gg_weights, const float * karpathy_weights) {
+void stuff_karpathy_weights_into_gg(struct ggml_tensor * gg_weights, float * karpathy_weights){
     int ct;
     switch (gg_weights->n_dims){
         case 1:
@@ -674,13 +670,13 @@ void convert_weights_ak_to_gg(struct ggml_tensor * gg_weights, const float * kar
 }
 
 void save_as_llama_model(struct llama_vocab * vocab, struct my_llama_model * model, TransformerWeights* w, const char * filename) {
-    // convert AK weights into GG weights one by one.
+    // stuff AK weights into GG weights one by one.
     // w->token_embedding_table -> model->tok_embeddings
     // float* -> struct ggml_tensor
-    convert_weights_ak_to_gg(model->tok_embeddings, w->token_embedding_table);
-    convert_weights_ak_to_gg(model->output, w->wcls ? w->wcls : w->token_embedding_table);
+    stuff_karpathy_weights_into_gg(model->tok_embeddings, w->token_embedding_table);
+    stuff_karpathy_weights_into_gg(model->output, w->wcls ? w->wcls : w->token_embedding_table);
 
-    convert_weights_ak_to_gg(model->norm, w->rms_final_weight);
+    stuff_karpathy_weights_into_gg(model->norm, w->rms_final_weight);
     //print_row(model->norm, 0);
 
     // for rms-att-weight
@@ -690,18 +686,18 @@ void save_as_llama_model(struct llama_vocab * vocab, struct my_llama_model * mod
     for (uint32_t i = 0; i < model->hparams.n_layer; ++i){
         auto & layer = model->layers[i];
         // 1d
-        convert_weights_ak_to_gg(layer.attention_norm, &w->rms_att_weight[i*row_length]);
-        convert_weights_ak_to_gg(layer.ffn_norm , &w->rms_ffn_weight[i*row_length]);
+        stuff_karpathy_weights_into_gg(layer.attention_norm, &w->rms_att_weight[i*row_length]);
+        stuff_karpathy_weights_into_gg(layer.ffn_norm , &w->rms_ffn_weight[i*row_length]);
 
         // from 3d matrix layer x dim x dim to 2d matrix dim x dim
-        convert_weights_ak_to_gg(layer.wq , &w->wq[i*row_length*row_length]);
-        convert_weights_ak_to_gg(layer.wk , &w->wk[i*row_length*row_length]);
-        convert_weights_ak_to_gg(layer.wv , &w->wv[i*row_length*row_length]);
-        convert_weights_ak_to_gg(layer.wo , &w->wo[i*row_length*row_length]);
-
-        convert_weights_ak_to_gg(layer.w1 , &w->w1[i*row_length*n_ff]);
-        convert_weights_ak_to_gg(layer.w2 , &w->w2[i*n_ff*row_length]);
-        convert_weights_ak_to_gg(layer.w3 , &w->w3[i*row_length*n_ff]);
+        stuff_karpathy_weights_into_gg(layer.wq , &w->wq[i*row_length*row_length]);
+        stuff_karpathy_weights_into_gg(layer.wk , &w->wk[i*row_length*row_length]);
+        stuff_karpathy_weights_into_gg(layer.wv , &w->wv[i*row_length*row_length]);
+        stuff_karpathy_weights_into_gg(layer.wo , &w->wo[i*row_length*row_length]);
+
+        stuff_karpathy_weights_into_gg(layer.w1 , &w->w1[i*row_length*n_ff]);
+        stuff_karpathy_weights_into_gg(layer.w2 , &w->w2[i*n_ff*row_length]);
+        stuff_karpathy_weights_into_gg(layer.w3 , &w->w3[i*row_length*n_ff]);
     }
 
     struct gguf_context * ctx = gguf_init_empty();
@@ -902,7 +898,7 @@ bool params_parse(int argc, char ** argv, struct train_params * params) {
 }
 
 std::string basename(const std::string &path) {
-    size_t pos = path.find_last_of("/\\");
+    size_t pos = path.find_last_of("/");
     if (pos == std::string::npos) {
         return path;
     }
@@ -915,7 +911,7 @@ int main(int argc, char ** argv) {
         return 1;
     }
     Config config;
-    TransformerWeights weights = {};
+    TransformerWeights weights;
     {
         FILE *file = fopen(params.fn_llama2c_model, "rb");
         if (!file) { printf("Unable to open the checkpoint file %s!\n", params.fn_llama2c_model); return 1; }
@@ -957,5 +953,6 @@
     printf("Saving llama.c model file %s in ggml format at %s\n", params.fn_llama2c_model, params.fn_llama2c_output_model);
 
     ggml_free(model.ctx);
+    free_weights(&weights);
     return 0;
 }
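
One behavioral consequence of the basename() change above: find_last_of("/\\") splits on either separator, while find_last_of("/") leaves backslash-separated Windows paths whole. A small illustration (the path is a made-up example):

    #include <iostream>
    #include <string>

    static std::string basename_fwd(const std::string & path) {
        size_t pos = path.find_last_of("/");   // forward slashes only
        return pos == std::string::npos ? path : path.substr(pos + 1);
    }

    static std::string basename_any(const std::string & path) {
        size_t pos = path.find_last_of("/\\"); // forward or back slashes
        return pos == std::string::npos ? path : path.substr(pos + 1);
    }

    int main() {
        const std::string p = "models\\llama2c\\model.bin"; // hypothetical Windows path
        std::cout << basename_fwd(p) << "\n"; // prints the full path unchanged
        std::cout << basename_any(p) << "\n"; // prints "model.bin"
        return 0;
    }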
