
Commit 16e819d

sync : pass custom graph sizes in training examples
1 parent 815f44e commit 16e819d

File tree: 3 files changed, +15 −12 lines

    common/train.cpp
    examples/finetune/finetune.cpp
    examples/train-text-from-scratch/train-text-from-scratch.cpp

common/train.cpp

Lines changed: 1 addition & 0 deletions

@@ -32,6 +32,7 @@ struct train_state * init_train_state() {
     state->opt = new struct ggml_opt_context;
     state->opt->ctx = NULL;
     state->opt->params = ggml_opt_default_params(GGML_OPT_ADAM);
+    state->opt->params.graph_size = LLAMA_TRAIN_MAX_NODES;
     state->opt->loss_after = 0.0f;

     return state;
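
Note on the common/train.cpp hunk: ggml_opt_default_params() fills opt->params with ggml's defaults, and init_train_state() now overrides the graph_size field so the optimizer sizes its internal forward/backward graphs for the large training graphs rather than ggml's default node capacity. A minimal sketch of the resulting setup, assuming LLAMA_TRAIN_MAX_NODES as defined in common/train.h and showing only the fields this commit touches:

    #include "ggml.h"
    #include "train.h"   // provides LLAMA_TRAIN_MAX_NODES

    // Sketch only: optimizer-context setup after this commit.
    static struct ggml_opt_context * make_train_opt() {
        struct ggml_opt_context * opt = new struct ggml_opt_context;
        opt->ctx    = NULL;
        opt->params = ggml_opt_default_params(GGML_OPT_ADAM);
        // Without this override the optimizer would fall back to ggml's
        // default graph capacity, which is too small for these examples.
        opt->params.graph_size = LLAMA_TRAIN_MAX_NODES;
        return opt;
    }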

examples/finetune/finetune.cpp

Lines changed: 7 additions & 6 deletions

@@ -1615,6 +1615,7 @@ int main(int argc, char ** argv) {
     opt->params = ggml_opt_default_params(GGML_OPT_ADAM);
     opt->params.print_forward_graph = false;
     opt->params.print_backward_graph = false;
+    opt->params.graph_size = LLAMA_TRAIN_MAX_NODES;
     opt->params.n_threads = params.common.n_threads;
     opt->params.past = params.common.opt_past;
     opt->params.delta = params.common.opt_delta;
@@ -1768,11 +1769,11 @@ int main(int argc, char ** argv) {
     for (unsigned order = 0; order < (unsigned) GGML_CGRAPH_EVAL_ORDER_COUNT; ++order) {
         ctx_compute = ggml_init(ctx_compute_params);
         alloc = ggml_allocr_new_measure(tensor_alignment);
-        gf = ggml_new_graph(ctx_compute);
+        gf = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
         gf->order = (enum ggml_cgraph_eval_order) order;
-        gb = ggml_new_graph(ctx_compute);
+        gb = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, false);
         gb_tmp = params.common.use_checkpointing
-            ? ggml_new_graph(ctx_compute)
+            ? ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, false)
             : NULL;
         loss = llama_build_lora_finetune_graphs(
             &model, &lora, alloc, ctx_compute,
@@ -1801,11 +1802,11 @@ int main(int argc, char ** argv) {
     mem_compute_data.resize(max_compute_size);
     ctx_compute = ggml_init(ctx_compute_params);
     alloc = ggml_allocr_new(mem_compute_data.data(), mem_compute_data.size(), tensor_alignment);
-    gf = ggml_new_graph(ctx_compute);
+    gf = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
     gf->order = best_order;
-    gb = ggml_new_graph(ctx_compute);
+    gb = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, false);
     gb_tmp = params.common.use_checkpointing
-        ? ggml_new_graph(ctx_compute)
+        ? ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, false)
         : NULL;
     loss = llama_build_lora_finetune_graphs(
         &model, &lora, alloc, ctx_compute,
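
Note on the graph construction above: in this revision of ggml, ggml_new_graph(ctx) is shorthand for ggml_new_graph_custom(ctx, GGML_DEFAULT_GRAPH_SIZE, false), i.e. a graph with the default node capacity and no gradient storage. The finetune graphs can exceed that capacity, hence the explicit LLAMA_TRAIN_MAX_NODES, and only the forward graph gf asks for gradient slots, since backpropagation attaches gradients to the forward graph's nodes. A short sketch of the pattern, with signatures as declared in ggml.h of this vintage:

    #include "ggml.h"
    #include "train.h"   // LLAMA_TRAIN_MAX_NODES

    // Sketch: size the training graphs explicitly instead of relying on
    // ggml_new_graph()'s default capacity.
    static void build_train_graphs(struct ggml_context * ctx_compute) {
        // forward graph: grads = true so gradient tensors can be attached to its nodes
        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
        // backward graph: same capacity, but it needs no gradient slots of its own
        struct ggml_cgraph * gb = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, false);
        (void) gf; (void) gb;
    }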

examples/train-text-from-scratch/train-text-from-scratch.cpp

Lines changed: 7 additions & 6 deletions

@@ -1006,6 +1006,7 @@ int main(int argc, char ** argv) {
     opt->params = ggml_opt_default_params(GGML_OPT_ADAM);
     opt->params.print_forward_graph = false;
     opt->params.print_backward_graph = false;
+    opt->params.graph_size = LLAMA_TRAIN_MAX_NODES;
     opt->params.n_threads = params.common.n_threads;
     opt->params.past = params.common.opt_past;
     opt->params.delta = params.common.opt_delta;
@@ -1135,11 +1136,11 @@ int main(int argc, char ** argv) {
     for (unsigned order = 0; order < (unsigned) GGML_CGRAPH_EVAL_ORDER_COUNT; ++order) {
         ctx_compute = ggml_init(ctx_compute_params);
         alloc = ggml_allocr_new_measure(tensor_alignment);
-        gf = ggml_new_graph(ctx_compute);
+        gf = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
         gf->order = (enum ggml_cgraph_eval_order) order;
-        gb = ggml_new_graph(ctx_compute);
+        gb = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, false);
         gb_tmp = params.common.use_checkpointing
-            ? ggml_new_graph(ctx_compute)
+            ? ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, false)
             : NULL;
         loss = llama_build_train_graphs(
             &model, alloc, ctx_compute,
@@ -1168,11 +1169,11 @@ int main(int argc, char ** argv) {
     mem_compute_data.resize(max_compute_size);
     ctx_compute = ggml_init(ctx_compute_params);
     alloc = ggml_allocr_new(mem_compute_data.data(), mem_compute_data.size(), tensor_alignment);
-    gf = ggml_new_graph(ctx_compute);
+    gf = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
     gf->order = best_order;
-    gb = ggml_new_graph(ctx_compute);
+    gb = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, false);
     gb_tmp = params.common.use_checkpointing
-        ? ggml_new_graph(ctx_compute)
+        ? ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, false)
         : NULL;
     loss = llama_build_train_graphs(
         &model, alloc, ctx_compute,
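
Why only gf is created with grads = true: inside llama_build_train_graphs (and llama_build_lora_finetune_graphs in the finetune example) the forward graph is expanded up to the loss tensor and the backward graph is then derived from it, so the gradient tensors live on the forward graph's nodes. The two calls below are an illustrative sketch rather than the exact call sites; the final keep argument in particular is an assumption:

    // Rough shape of what the graph builders do with gf and gb (illustrative only):
    ggml_build_forward_expand(gf, loss);                    // record every node up to the loss; gf must have enough capacity
    ggml_build_backward_expand(ctx_compute, gf, gb, true);  // derive gb from gf, storing grads on gf's nodes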
