
Commit 2313c54

ggml : add ggml_graph_compute_with_ctx()
- backwards compatible API
- deduplicates a lot of copy-paste
1 parent 8e1f0b6

2 files changed (+20, -18 lines)


ggml.c

Lines changed: 15 additions & 17 deletions
@@ -16493,21 +16493,17 @@ void ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan)
     }
 }
 
-// TODO: avoid allocating memory frequently.
-// TODO: make part of public API - use different name and put warning that it makes allocations
-static void ggml_graph_compute_helper(struct ggml_cgraph * cgraph, int n_threads) {
+// same as ggml_graph_compute() but the work data is allocated as a part of the context
+// note: the drawback of this API is that you must have ensured that the context has enough memory for the work data
+void ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct ggml_cgraph * cgraph, int n_threads) {
     struct ggml_cplan cplan = ggml_graph_plan(cgraph, n_threads);
 
-    if (cplan.work_size > 0) {
-        cplan.work_data = malloc(cplan.work_size);
-        GGML_ASSERT(cplan.work_data);
-    }
+    struct ggml_tensor * buf = ggml_new_tensor_1d(ctx, GGML_TYPE_I8, cplan.work_size);
+    GGML_ASSERT(buf);
 
-    ggml_graph_compute(cgraph, &cplan);
+    cplan.work_data = buf->data;
 
-    if (cplan.work_data) {
-        free(cplan.work_data);
-    }
+    ggml_graph_compute(cgraph, &cplan);
 }
 
 void ggml_graph_reset(struct ggml_cgraph * cgraph) {
@@ -17292,6 +17288,7 @@ static void ggml_opt_get_grad(int np, struct ggml_tensor * const ps[], float * g
 //
 
 static enum ggml_opt_result ggml_opt_adam(
+        struct ggml_context * ctx,
         struct ggml_opt_context * opt,
         struct ggml_opt_params params,
         struct ggml_tensor * f,
@@ -17346,7 +17343,7 @@ static enum ggml_opt_result ggml_opt_adam(
     ggml_graph_reset  (gf);
     ggml_set_f32      (f->grad, 1.0f);
 
-    ggml_graph_compute_helper(gb, params.n_threads);
+    ggml_graph_compute_with_ctx(ctx, gb, params.n_threads);
 
     opt->adam.fx_prev = ggml_get_f32_1d(f, 0);
     opt->adam.fx_best = opt->adam.fx_prev;
@@ -17427,7 +17424,7 @@ static enum ggml_opt_result ggml_opt_adam(
         ggml_graph_reset  (gf);
         ggml_set_f32      (f->grad, 1.0f);
 
-        ggml_graph_compute_helper(gb, params.n_threads);
+        ggml_graph_compute_with_ctx(ctx, gb, params.n_threads);
 
         const float fx = ggml_get_f32_1d(f, 0);
 
@@ -17498,6 +17495,7 @@ struct ggml_lbfgs_iteration_data {
 };
 
 static enum ggml_opt_result linesearch_backtracking(
+        struct ggml_context * ctx,
         const struct ggml_opt_params * params,
         int nx,
         float * x,
@@ -17549,7 +17547,7 @@ static enum ggml_opt_result linesearch_backtracking(
         ggml_graph_reset  (gf);
         ggml_set_f32      (f->grad, 1.0f);
 
-        ggml_graph_compute_helper(gb, params->n_threads);
+        ggml_graph_compute_with_ctx(ctx, gb, params->n_threads);
 
         ggml_opt_get_grad(np, ps, g);
 
@@ -17669,7 +17667,7 @@ static enum ggml_opt_result ggml_opt_lbfgs(
     ggml_graph_reset  (gf);
     ggml_set_f32      (f->grad, 1.0f);
 
-    ggml_graph_compute_helper(gb, params.n_threads);
+    ggml_graph_compute_with_ctx(ctx, gb, params.n_threads);
 
     ggml_opt_get_grad(np, ps, g);
 
@@ -17728,7 +17726,7 @@ static enum ggml_opt_result ggml_opt_lbfgs(
         ggml_vec_cpy_f32(nx, xp, x);
         ggml_vec_cpy_f32(nx, gp, g);
 
-        ls = linesearch_backtracking(&params, nx, x, &fx, g, d, step, xp, f, gf, gb, np, ps);
+        ls = linesearch_backtracking(ctx, &params, nx, x, &fx, g, d, step, xp, f, gf, gb, np, ps);
 
         if (ls < 0) {
             // linesearch failed - go back to the previous point and return
@@ -18030,7 +18028,7 @@ enum ggml_opt_result ggml_opt_resume_g(
     switch (opt->params.type) {
         case GGML_OPT_ADAM:
             {
-                result = ggml_opt_adam(opt, opt->params, f, gf, gb);
+                result = ggml_opt_adam(ctx, opt, opt->params, f, gf, gb);
            } break;
        case GGML_OPT_LBFGS:
            {
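For context, the copy-paste that this commit deduplicates at each optimizer call site is the explicit ggml_graph_plan() path. A minimal sketch of that pattern, mirroring the removed static helper (the name compute_with_malloc is hypothetical):

#include <stdlib.h>
#include "ggml.h"

// plan the graph, allocate the work buffer yourself, compute, then free -
// this is the boilerplate that ggml_graph_compute_with_ctx() now wraps
static void compute_with_malloc(struct ggml_cgraph * cgraph, int n_threads) {
    struct ggml_cplan cplan = ggml_graph_plan(cgraph, n_threads);

    if (cplan.work_size > 0) {
        cplan.work_data = malloc(cplan.work_size);
        GGML_ASSERT(cplan.work_data);
    }

    ggml_graph_compute(cgraph, &cplan);

    if (cplan.work_data) {
        free(cplan.work_data);
    }
}

The new implementation instead carves the work buffer out of the caller's context as a GGML_TYPE_I8 tensor, which keeps the graph API backwards compatible and avoids per-call heap allocation, at the cost of consuming context memory on every call (hence the note about sizing the context).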

ggml.h

Lines changed: 5 additions & 1 deletion
@@ -1306,7 +1306,7 @@ extern "C" {
 
     GGML_API void ggml_set_param(
             struct ggml_context * ctx,
-            struct ggml_tensor * tensor);
+            struct ggml_tensor  * tensor);
 
     GGML_API void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);
 
@@ -1319,6 +1319,10 @@ extern "C" {
     GGML_API void ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan);
     GGML_API void ggml_graph_reset (struct ggml_cgraph * cgraph);
 
+    // same as ggml_graph_compute() but the work data is allocated as a part of the context
+    // note: the drawback of this API is that you must have ensured that the context has enough memory for the work data
+    GGML_API void ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct ggml_cgraph * cgraph, int n_threads);
+
     GGML_API struct ggml_tensor * ggml_graph_get_tensor(struct ggml_cgraph * cgraph, const char * name);
 
     GGML_API void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname);
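Caller-side, the new entry point collapses plan/alloc/compute/free into one call. A hedged end-to-end sketch (the mem_size and tensor shapes are illustrative; per the header note, the context must be sized to hold the graph tensors plus cplan.work_size bytes of work data):

#include <stdio.h>
#include "ggml.h"

int main(void) {
    // illustrative size: must cover the tensors, the graph, and the work buffer
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8);
    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8);
    struct ggml_tensor * c = ggml_add(ctx, a, b);

    ggml_set_f32(a, 1.0f);
    ggml_set_f32(b, 2.0f);

    struct ggml_cgraph gf = ggml_build_forward(c);

    // one call instead of ggml_graph_plan() + malloc + ggml_graph_compute() + free
    ggml_graph_compute_with_ctx(ctx, &gf, /*n_threads =*/ 4);

    printf("c[0] = %f\n", ggml_get_f32_1d(c, 0)); // expect 3.0

    ggml_free(ctx);
    return 0;
}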
