
Commit a4de804

train : allocate grads for backward graphs
1 parent aa1f36c
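
What changed: in ggml, the third argument of ggml_new_graph_custom controls whether the graph also allocates per-node gradient storage. Both training examples created their backward graphs gb with that flag set to false, so only the forward graphs gf had gradient slots; this commit flips gb to true. A minimal sketch of the pattern, where the comments are my reading of the ggml API rather than part of the commit:

    // struct ggml_cgraph * ggml_new_graph_custom(struct ggml_context * ctx,
    //                                            size_t size, bool grads);
    // grads == true also allocates the graph's grads[] array, so each
    // node can carry a gradient tensor during backpropagation.
    gf = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true); // forward
    gb = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true); // backward: was false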

2 files changed: +4 −4 lines

examples/finetune/finetune.cpp

Lines changed: 2 additions & 2 deletions
@@ -1769,7 +1769,7 @@ int main(int argc, char ** argv) {
         alloc = ggml_allocr_new_measure(tensor_alignment);
         gf = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
         gf->order = (enum ggml_cgraph_eval_order) order;
-        gb = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, false);
+        gb = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
         gb_tmp = params.common.use_checkpointing
             ? ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, false)
             : NULL;
@@ -1802,7 +1802,7 @@ int main(int argc, char ** argv) {
         alloc = ggml_allocr_new(mem_compute_data.data(), mem_compute_data.size(), tensor_alignment);
         gf = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
         gf->order = best_order;
-        gb = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, false);
+        gb = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
         gb_tmp = params.common.use_checkpointing
             ? ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, false)
             : NULL;

examples/train-text-from-scratch/train-text-from-scratch.cpp

Lines changed: 2 additions & 2 deletions
@@ -1136,7 +1136,7 @@ int main(int argc, char ** argv) {
         alloc = ggml_allocr_new_measure(tensor_alignment);
         gf = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
         gf->order = (enum ggml_cgraph_eval_order) order;
-        gb = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, false);
+        gb = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
         gb_tmp = params.common.use_checkpointing
             ? ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, false)
             : NULL;
@@ -1169,7 +1169,7 @@ int main(int argc, char ** argv) {
         alloc = ggml_allocr_new(mem_compute_data.data(), mem_compute_data.size(), tensor_alignment);
         gf = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
         gf->order = best_order;
-        gb = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, false);
+        gb = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
         gb_tmp = params.common.use_checkpointing
             ? ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, false)
             : NULL;
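
Note that in both files the scratch graph gb_tmp, used only when gradient checkpointing is enabled, is still created with grads = false. The commit touches only gb, the graph the backward pass writes gradients into, which is presumably the only one that needs the extra storage.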
