Skip to content

Commit 2b502c3

Browse files
committed
add static ggml_graph_compute_sugar()
1 parent db81f33 commit 2b502c3

File tree

1 file changed

+19
-45
lines changed

1 file changed

+19
-45
lines changed

ggml.c

Lines changed: 19 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -16424,7 +16424,6 @@ void ggml_graph_compute(struct ggml_graph_compute_plan * plan, struct ggml_cgrap
1642416424
GGML_ASSERT(plan->n_tasks[i] > 0);
1642516425
}
1642616426
}
16427-
1642816427
}
1642916428

1643016429
const int n_threads = plan->n_threads;
@@ -16491,6 +16490,20 @@ void ggml_graph_compute(struct ggml_graph_compute_plan * plan, struct ggml_cgrap
1649116490
}
1649216491
}
1649316492

16493+
static void ggml_graph_compute_sugar(struct ggml_cgraph * cgraph, int n_threads) {
16494+
struct ggml_graph_compute_plan plan = ggml_graph_compute_make_plan(cgraph, n_threads);
16495+
if (plan.work_size > 0) {
16496+
plan.work_data = malloc(plan.work_size);
16497+
GGML_ASSERT(plan.work_data);
16498+
}
16499+
16500+
ggml_graph_compute(&plan, cgraph);
16501+
16502+
if (plan.work_data) {
16503+
free(plan.work_data);
16504+
}
16505+
}
16506+
1649416507
void ggml_graph_reset(struct ggml_cgraph * cgraph) {
1649516508
for (int i = 0; i < cgraph->n_nodes; i++) {
1649616509
struct ggml_tensor * grad = cgraph->grads[i];
@@ -17327,17 +17340,7 @@ static enum ggml_opt_result ggml_opt_adam(
1732717340
ggml_graph_reset (gf);
1732817341
ggml_set_f32 (f->grad, 1.0f);
1732917342

17330-
{
17331-
struct ggml_graph_compute_plan plan = ggml_graph_compute_make_plan(gb, params.n_threads);
17332-
if (plan.work_size > 0) {
17333-
plan.work_data = malloc(plan.work_size);
17334-
GGML_ASSERT(plan.work_data);
17335-
}
17336-
ggml_graph_compute(&plan, gb);
17337-
if (plan.work_data) {
17338-
free(plan.work_data);
17339-
}
17340-
}
17343+
ggml_graph_compute_sugar(gb, params.n_threads);
1734117344

1734217345
opt->adam.fx_prev = ggml_get_f32_1d(f, 0);
1734317346
opt->adam.fx_best = opt->adam.fx_prev;
@@ -17418,17 +17421,7 @@ static enum ggml_opt_result ggml_opt_adam(
1741817421
ggml_graph_reset (gf);
1741917422
ggml_set_f32 (f->grad, 1.0f);
1742017423

17421-
{
17422-
struct ggml_graph_compute_plan plan = ggml_graph_compute_make_plan(gb, params.n_threads);
17423-
if (plan.work_size > 0) {
17424-
plan.work_data = malloc(plan.work_size);
17425-
GGML_ASSERT(plan.work_data);
17426-
}
17427-
ggml_graph_compute(&plan, gb);
17428-
if (plan.work_data) {
17429-
free(plan.work_data);
17430-
}
17431-
}
17424+
ggml_graph_compute_sugar(gb, params.n_threads);
1743217425

1743317426
const float fx = ggml_get_f32_1d(f, 0);
1743417427

@@ -17550,17 +17543,7 @@ static enum ggml_opt_result linesearch_backtracking(
1755017543
ggml_graph_reset (gf);
1755117544
ggml_set_f32 (f->grad, 1.0f);
1755217545

17553-
{
17554-
struct ggml_graph_compute_plan plan = ggml_graph_compute_make_plan(gb, params->n_threads);
17555-
if (plan.work_size > 0) {
17556-
plan.work_data = malloc(plan.work_size);
17557-
GGML_ASSERT(plan.work_data);
17558-
}
17559-
ggml_graph_compute(&plan, gb);
17560-
if (plan.work_data) {
17561-
free(plan.work_data);
17562-
}
17563-
}
17546+
ggml_graph_compute_sugar(gb, params->n_threads);
1756417547

1756517548
ggml_opt_get_grad(np, ps, g);
1756617549

@@ -17679,17 +17662,8 @@ static enum ggml_opt_result ggml_opt_lbfgs(
1767917662

1768017663
ggml_graph_reset (gf);
1768117664
ggml_set_f32 (f->grad, 1.0f);
17682-
{
17683-
struct ggml_graph_compute_plan plan = ggml_graph_compute_make_plan(gb, params.n_threads);
17684-
if (plan.work_size > 0) {
17685-
plan.work_data = malloc(plan.work_size);
17686-
GGML_ASSERT(plan.work_data);
17687-
}
17688-
ggml_graph_compute(&plan, gb);
17689-
if (plan.work_data) {
17690-
free(plan.work_data);
17691-
}
17692-
}
17665+
17666+
ggml_graph_compute_sugar(gb, params.n_threads);
1769317667

1769417668
ggml_opt_get_grad(np, ps, g);
1769517669

0 commit comments

Comments
 (0)