Skip to content

Commit 5a4477a

Browse files
llama_save_model_to_file
1 parent 35f9b28 commit 5a4477a

File tree

8 files changed

+462
-151
lines changed

8 files changed

+462
-151
lines changed

common/common.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1959,3 +1959,19 @@ common_control_vector_data common_control_vector_load(const std::vector<common_c
19591959
return result;
19601960
}
19611961

1962+
// Build a ggml optimization dataset from a flat token sequence.
// Each datapoint is a window of n_ctx tokens; its label is the same window
// shifted right by one token (next-token prediction). Consecutive windows
// start `stride` tokens apart.
//
// ctx    : llama context; its context size (llama_n_ctx) fixes the datapoint length
// tokens : tokenized training text
// stride : offset in tokens between the starts of consecutive datapoints
// returns: dataset with I32 data/labels; caller frees with ggml_opt_dataset_free
ggml_opt_dataset_t common_opt_dataset_init(struct llama_context * ctx, const std::vector<llama_token> & tokens, int64_t stride) {
    const int64_t ne_datapoint = llama_n_ctx(ctx);
    // Compute in signed 64-bit: tokens.size() is unsigned, so the original
    // expression underflowed to a huge value when tokens.size() < ne_datapoint + 1.
    // Clamp at 0 so a too-short token sequence yields an empty dataset instead of UB.
    const int64_t ndata = std::max<int64_t>(0, ((int64_t) tokens.size() - ne_datapoint - 1) / stride);
    ggml_opt_dataset_t result = ggml_opt_dataset_init(
        GGML_TYPE_I32, GGML_TYPE_I32, ne_datapoint, ne_datapoint, ndata, /*ndata_shard =*/ 1);

    llama_token * data   = (llama_token *) ggml_opt_dataset_data(result)->data;
    llama_token * labels = (llama_token *) ggml_opt_dataset_labels(result)->data;

    for (int64_t idata = 0; idata < ndata; ++idata) {
        // labels are the inputs shifted by one token: predict tokens[i+1] from tokens[i]
        memcpy(data   + idata*ne_datapoint, tokens.data() + idata*stride + 0, ne_datapoint*sizeof(llama_token));
        memcpy(labels + idata*ne_datapoint, tokens.data() + idata*stride + 1, ne_datapoint*sizeof(llama_token));
    }

    return result;
}

common/common.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -610,3 +610,9 @@ common_control_vector_data common_control_vector_load(const std::vector<common_c
610610
static const char * const LLM_KV_SPLIT_NO = "split.no";
611611
static const char * const LLM_KV_SPLIT_COUNT = "split.count";
612612
static const char * const LLM_KV_SPLIT_TENSORS_COUNT = "split.tensors.count";
613+
614+
//
615+
// training utils
616+
//
617+
618+
ggml_opt_dataset_t common_opt_dataset_init(struct llama_context * ctx, const std::vector<llama_token> & tokens, int64_t stride);

examples/training/finetune.cpp

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -79,24 +79,35 @@ int main(int argc, char ** argv) {
7979
constexpr float val_split = 0.05f;
8080

8181
std::vector<llama_token> tokens = common_tokenize(ctx, params.prompt, true);
82-
ggml_opt_dataset_t dataset = llama_opt_dataset_init(ctx, tokens.data(), tokens.size(), llama_n_ctx(ctx)/2);
83-
llama_opt_init(ctx);
82+
ggml_opt_dataset_t dataset = common_opt_dataset_init(ctx, tokens, llama_n_ctx(ctx)/2);
83+
84+
struct ggml_opt_optimizer_params optimizer_params = ggml_opt_get_default_optimizer_params(nullptr);
85+
optimizer_params.adamw.alpha = 1e-6f; // learning rate
86+
87+
struct llama_opt_params lopt_params {
88+
/*n_ctx_train =*/ 0,
89+
/*get_opt_pars =*/ ggml_opt_get_constant_optimizer_params,
90+
/*get_opt_pars_ud =*/ &optimizer_params,
91+
};
92+
llama_opt_init(ctx, model, lopt_params);
93+
8494
const int64_t idata_split = ggml_opt_dataset_ndata(dataset) * (1.0f - val_split);
8595

86-
while (true) {
87-
ggml_opt_result_t result_train = ggml_opt_result_init();
88-
ggml_opt_result_t result_eval = ggml_opt_result_init();
96+
ggml_opt_result_t result_train = ggml_opt_result_init();
97+
ggml_opt_result_t result_eval = ggml_opt_result_init();
8998

99+
for (int epoch = 0; epoch < 1; ++epoch) {
90100
llama_opt_epoch(ctx, dataset, result_train, result_eval, idata_split,
91101
ggml_opt_epoch_callback_progress_bar, ggml_opt_epoch_callback_progress_bar);
92102
fprintf(stderr, "\n");
93103

94-
ggml_opt_result_free(result_train);
95-
ggml_opt_result_free(result_eval);
104+
ggml_opt_result_reset(result_train);
105+
ggml_opt_result_reset(result_eval);
96106
}
107+
ggml_opt_result_free(result_train);
108+
ggml_opt_result_free(result_eval);
97109

98-
LOG("\n");
99-
llama_perf_context_print(ctx);
110+
llama_save_model_to_file(model, "finetuned-model.gguf");
100111

101112
llama_free(ctx);
102113
llama_free_model(model);

ggml/include/ggml-opt.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,10 +90,13 @@ extern "C" {
9090
// userdata can be used to pass arbitrary data
9191
typedef struct ggml_opt_optimizer_params (*ggml_opt_get_optimizer_params)(void * userdata);
9292

93-
// returns the default optimizer params (constant)
93+
// returns the default optimizer params (constant, hard-coded values)
9494
// userdata is not used
9595
GGML_API struct ggml_opt_optimizer_params ggml_opt_get_default_optimizer_params(void * userdata);
9696

97+
// casts userdata to ggml_opt_optimizer_params and returns it
98+
GGML_API struct ggml_opt_optimizer_params ggml_opt_get_constant_optimizer_params(void * userdata);
99+
97100
// parameters for initializing a new optimization context
98101
struct ggml_opt_params {
99102
ggml_backend_sched_t backend_sched; // defines which backends are used to construct the compute graphs

ggml/src/ggml-opt.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,10 @@ struct ggml_opt_optimizer_params ggml_opt_get_default_optimizer_params(void * us
229229
return result;
230230
}
231231

232+
// Optimizer-parameter callback that ignores nothing dynamic: userdata must point
// at a ggml_opt_optimizer_params struct, which is returned by value unchanged.
struct ggml_opt_optimizer_params ggml_opt_get_constant_optimizer_params(void * userdata) {
    const struct ggml_opt_optimizer_params * params =
        (const struct ggml_opt_optimizer_params *) userdata;
    return *params;
}
235+
232236
struct ggml_opt_params ggml_opt_default_params(
233237
ggml_backend_sched_t backend_sched,
234238
struct ggml_context * ctx_compute,

include/llama.h

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -413,6 +413,10 @@ extern "C" {
413413
const char * path_model,
414414
struct llama_model_params params);
415415

416+
LLAMA_API void llama_save_model_to_file(
417+
const struct llama_model * model,
418+
const char * path_model);
419+
416420
LLAMA_API void llama_free_model(struct llama_model * model);
417421

418422
// TODO: rename to llama_init_from_model
@@ -1255,9 +1259,14 @@ extern "C" {
12551259
// training
12561260
//
12571261

1258-
LLAMA_API ggml_opt_dataset_t llama_opt_dataset_init(struct llama_context * ctx, const llama_token * tokens, int64_t n_tokens, int32_t stride);
1262+
struct llama_opt_params {
1263+
uint32_t n_ctx_train; // assumed context size post training, use context size specified in llama_context if 0
1264+
1265+
ggml_opt_get_optimizer_params get_opt_pars; // callback for calculating optimizer parameters
1266+
void * get_opt_pars_ud; // userdata for calculating optimizer parameters
1267+
};
12591268

1260-
LLAMA_API void llama_opt_init(struct llama_context * lctx);
1269+
LLAMA_API void llama_opt_init(struct llama_context * lctx, struct llama_model * model, struct llama_opt_params lopt_params);
12611270

12621271
LLAMA_API void llama_opt_epoch(
12631272
struct llama_context * lctx,

src/llama-vocab.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,9 @@ struct llama_vocab {
6161
// set of all tokens that cause "end of generation"
6262
std::set<id> special_eog_ids;
6363

64+
std::string tokenizer_model;
65+
std::string tokenizer_pre;
66+
6467
// tokenizer flags
6568
bool tokenizer_add_space_prefix = false;
6669
bool tokenizer_add_bos = false;

0 commit comments

Comments (0)