@@ -17272,7 +17272,6 @@ void ggml_graph_compute(struct ggml_graph_compute_plan * plan, struct ggml_cgrap
17272
17272
GGML_ASSERT(plan->n_tasks[i] > 0);
17273
17273
}
17274
17274
}
17275
-
17276
17275
}
17277
17276
17278
17277
const int n_threads = plan->n_threads;
@@ -17339,6 +17338,20 @@ void ggml_graph_compute(struct ggml_graph_compute_plan * plan, struct ggml_cgrap
17339
17338
}
17340
17339
}
17341
17340
17341
+ static void ggml_graph_compute_sugar(struct ggml_cgraph * cgraph, int n_threads) {
17342
+ struct ggml_graph_compute_plan plan = ggml_graph_compute_make_plan(cgraph, n_threads);
17343
+ if (plan.work_size > 0) {
17344
+ plan.work_data = malloc(plan.work_size);
17345
+ GGML_ASSERT(plan.work_data);
17346
+ }
17347
+
17348
+ ggml_graph_compute(&plan, cgraph);
17349
+
17350
+ if (plan.work_data) {
17351
+ free(plan.work_data);
17352
+ }
17353
+ }
17354
+
17342
17355
void ggml_graph_reset(struct ggml_cgraph * cgraph) {
17343
17356
for (int i = 0; i < cgraph->n_nodes; i++) {
17344
17357
struct ggml_tensor * grad = cgraph->grads[i];
@@ -18193,17 +18206,7 @@ static enum ggml_opt_result ggml_opt_adam(
18193
18206
ggml_graph_reset (gf);
18194
18207
ggml_set_f32 (f->grad, 1.0f);
18195
18208
18196
- {
18197
- struct ggml_graph_compute_plan plan = ggml_graph_compute_make_plan(gb, params.n_threads);
18198
- if (plan.work_size > 0) {
18199
- plan.work_data = malloc(plan.work_size);
18200
- GGML_ASSERT(plan.work_data);
18201
- }
18202
- ggml_graph_compute(&plan, gb);
18203
- if (plan.work_data) {
18204
- free(plan.work_data);
18205
- }
18206
- }
18209
+ ggml_graph_compute_sugar(gb, params.n_threads);
18207
18210
18208
18211
opt->adam.fx_prev = ggml_get_f32_1d(f, 0);
18209
18212
opt->adam.fx_best = opt->adam.fx_prev;
@@ -18284,17 +18287,7 @@ static enum ggml_opt_result ggml_opt_adam(
18284
18287
ggml_graph_reset (gf);
18285
18288
ggml_set_f32 (f->grad, 1.0f);
18286
18289
18287
- {
18288
- struct ggml_graph_compute_plan plan = ggml_graph_compute_make_plan(gb, params.n_threads);
18289
- if (plan.work_size > 0) {
18290
- plan.work_data = malloc(plan.work_size);
18291
- GGML_ASSERT(plan.work_data);
18292
- }
18293
- ggml_graph_compute(&plan, gb);
18294
- if (plan.work_data) {
18295
- free(plan.work_data);
18296
- }
18297
- }
18290
+ ggml_graph_compute_sugar(gb, params.n_threads);
18298
18291
18299
18292
const float fx = ggml_get_f32_1d(f, 0);
18300
18293
@@ -18416,17 +18409,7 @@ static enum ggml_opt_result linesearch_backtracking(
18416
18409
ggml_graph_reset (gf);
18417
18410
ggml_set_f32 (f->grad, 1.0f);
18418
18411
18419
- {
18420
- struct ggml_graph_compute_plan plan = ggml_graph_compute_make_plan(gb, params->n_threads);
18421
- if (plan.work_size > 0) {
18422
- plan.work_data = malloc(plan.work_size);
18423
- GGML_ASSERT(plan.work_data);
18424
- }
18425
- ggml_graph_compute(&plan, gb);
18426
- if (plan.work_data) {
18427
- free(plan.work_data);
18428
- }
18429
- }
18412
+ ggml_graph_compute_sugar(gb, params->n_threads);
18430
18413
18431
18414
ggml_opt_get_grad(np, ps, g);
18432
18415
@@ -18545,17 +18528,8 @@ static enum ggml_opt_result ggml_opt_lbfgs(
18545
18528
18546
18529
ggml_graph_reset (gf);
18547
18530
ggml_set_f32 (f->grad, 1.0f);
18548
- {
18549
- struct ggml_graph_compute_plan plan = ggml_graph_compute_make_plan(gb, params.n_threads);
18550
- if (plan.work_size > 0) {
18551
- plan.work_data = malloc(plan.work_size);
18552
- GGML_ASSERT(plan.work_data);
18553
- }
18554
- ggml_graph_compute(&plan, gb);
18555
- if (plan.work_data) {
18556
- free(plan.work_data);
18557
- }
18558
- }
18531
+
18532
+ ggml_graph_compute_sugar(gb, params.n_threads);
18559
18533
18560
18534
ggml_opt_get_grad(np, ps, g);
18561
18535
0 commit comments