Commit 548ec46

train : allocate grads for gb_tmp
1 parent a4de804

2 files changed: +4 −4 lines changed


examples/finetune/finetune.cpp

Lines changed: 2 additions & 2 deletions

@@ -1771,7 +1771,7 @@ int main(int argc, char ** argv) {
         gf->order = (enum ggml_cgraph_eval_order) order;
         gb = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
         gb_tmp = params.common.use_checkpointing
-            ? ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, false)
+            ? ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true)
             : NULL;
         loss = llama_build_lora_finetune_graphs(
             &model, &lora, alloc, ctx_compute,
@@ -1804,7 +1804,7 @@ int main(int argc, char ** argv) {
         gf->order = best_order;
         gb = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
         gb_tmp = params.common.use_checkpointing
-            ? ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, false)
+            ? ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true)
             : NULL;
         loss = llama_build_lora_finetune_graphs(
             &model, &lora, alloc, ctx_compute,
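For context, the third argument of ggml_new_graph_custom is the grads flag, which controls whether the graph also reserves storage for per-node gradient tensors. A minimal sketch of the declaration (as it appears in ggml.h around this commit) next to the corrected call:

    // ggml.h: size bounds the node count; grads controls whether the
    // graph additionally allocates its per-node gradient array.
    GGML_API struct ggml_cgraph * ggml_new_graph_custom(
            struct ggml_context * ctx,
            size_t                size,
            bool                  grads);

    // After this commit, gb_tmp requests gradient storage as well:
    gb_tmp = params.common.use_checkpointing
        ? ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, /*grads=*/ true)
        : NULL;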

examples/train-text-from-scratch/train-text-from-scratch.cpp

Lines changed: 2 additions & 2 deletions

@@ -1138,7 +1138,7 @@ int main(int argc, char ** argv) {
         gf->order = (enum ggml_cgraph_eval_order) order;
         gb = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
         gb_tmp = params.common.use_checkpointing
-            ? ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, false)
+            ? ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true)
             : NULL;
         loss = llama_build_train_graphs(
             &model, alloc, ctx_compute,
@@ -1171,7 +1171,7 @@ int main(int argc, char ** argv) {
         gf->order = best_order;
         gb = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
         gb_tmp = params.common.use_checkpointing
-            ? ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, false)
+            ? ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true)
             : NULL;
         loss = llama_build_train_graphs(
             &model, alloc, ctx_compute,
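Why gb_tmp needs gradients at all: with use_checkpointing enabled, both examples first expand the full backward pass into gb_tmp and only then rewrite it into gb, recomputing activations from checkpoints instead of storing them. The expansion step writes gradient tensors into the graph, so a gb_tmp created with grads == false lacks the required storage. A hedged sketch of the call pattern (the checkpoints vector and surrounding control flow are paraphrased from the graph-building helpers, not shown in this diff):

    // Sketch, assuming the ggml graph-building helpers as of this commit:
    if (params.common.use_checkpointing) {
        // Expands the backward pass into gb_tmp (which needs grads),
        // then rebuilds gb so activations between checkpoints are
        // recomputed during the backward pass rather than kept.
        ggml_build_backward_gradient_checkpointing(
            ctx_compute, gf, gb, gb_tmp,
            checkpoints.data(), (int) checkpoints.size());
    } else {
        ggml_graph_cpy(gf, gb);
        ggml_build_backward_expand(ctx_compute, gf, gb, true);
    }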
