@@ -16493,21 +16493,17 @@ void ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan)
     }
 }
 
-// TODO: avoid allocating memory frequently.
-// TODO: make part of public API - use different name and put warning that it makes allocations
-static void ggml_graph_compute_helper(struct ggml_cgraph * cgraph, int n_threads) {
+// same as ggml_graph_compute() but the work data is allocated as a part of the context
+// note: the drawback of this API is that you must have ensured that the context has enough memory for the work data
+void ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct ggml_cgraph * cgraph, int n_threads) {
     struct ggml_cplan cplan = ggml_graph_plan(cgraph, n_threads);
 
-    if (cplan.work_size > 0) {
-        cplan.work_data = malloc(cplan.work_size);
-        GGML_ASSERT(cplan.work_data);
-    }
+    struct ggml_tensor * buf = ggml_new_tensor_1d(ctx, GGML_TYPE_I8, cplan.work_size);
+    GGML_ASSERT(buf);
 
-    ggml_graph_compute(cgraph, &cplan);
+    cplan.work_data = buf->data;
 
-    if (cplan.work_data) {
-        free(cplan.work_data);
-    }
+    ggml_graph_compute(cgraph, &cplan);
 }
 
 void ggml_graph_reset(struct ggml_cgraph * cgraph) {
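
For context, a minimal sketch of how a caller might use the new entry point. This is not part of the commit: the 64 MB context size, the tensor shapes, the thread count, and the use of the era's ggml_build_forward() graph builder are illustrative assumptions.

// usage_sketch.c -- hypothetical example, not part of this diff
#include <stdio.h>
#include "ggml.h"

int main(void) {
    // the context must be created with enough spare memory for both the
    // graph tensors and the work buffer that ggml_graph_compute_with_ctx()
    // carves out of it (64 MB here is an illustrative over-allocation)
    struct ggml_init_params iparams = {
        /*.mem_size   =*/ 64*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(iparams);

    struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 16);
    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 16);
    struct ggml_tensor * c = ggml_add(ctx, a, b);

    struct ggml_cgraph gf = ggml_build_forward(c);

    ggml_set_f32(a, 1.0f);
    ggml_set_f32(b, 2.0f);

    // work data comes out of ctx instead of a malloc/free pair per call
    ggml_graph_compute_with_ctx(ctx, &gf, /*n_threads =*/ 4);

    printf("c[0] = %f\n", ggml_get_f32_1d(c, 0)); // expect 3.0

    ggml_free(ctx);
    return 0;
}

The design trade-off is visible in the hunk above: the work buffer becomes a GGML_TYPE_I8 tensor inside the caller's context, so the per-call malloc/free of the old static helper disappears, at the cost of the caller having to budget context memory for it.
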
@@ -17292,6 +17288,7 @@ static void ggml_opt_get_grad(int np, struct ggml_tensor * const ps[], float * g
 //
 
 static enum ggml_opt_result ggml_opt_adam(
+        struct ggml_context * ctx,
         struct ggml_opt_context * opt,
         struct ggml_opt_params params,
         struct ggml_tensor * f,
@@ -17346,7 +17343,7 @@ static enum ggml_opt_result ggml_opt_adam(
     ggml_graph_reset  (gf);
     ggml_set_f32      (f->grad, 1.0f);
 
-    ggml_graph_compute_helper(gb, params.n_threads);
+    ggml_graph_compute_with_ctx(ctx, gb, params.n_threads);
 
     opt->adam.fx_prev = ggml_get_f32_1d(f, 0);
     opt->adam.fx_best = opt->adam.fx_prev;
@@ -17427,7 +17424,7 @@ static enum ggml_opt_result ggml_opt_adam(
         ggml_graph_reset  (gf);
         ggml_set_f32      (f->grad, 1.0f);
 
-        ggml_graph_compute_helper(gb, params.n_threads);
+        ggml_graph_compute_with_ctx(ctx, gb, params.n_threads);
 
         const float fx = ggml_get_f32_1d(f, 0);
 
@@ -17498,6 +17495,7 @@ struct ggml_lbfgs_iteration_data {
 };
 
 static enum ggml_opt_result linesearch_backtracking(
+        struct ggml_context * ctx,
         const struct ggml_opt_params * params,
         int nx,
         float * x,
@@ -17549,7 +17547,7 @@ static enum ggml_opt_result linesearch_backtracking(
         ggml_graph_reset  (gf);
         ggml_set_f32      (f->grad, 1.0f);
 
-        ggml_graph_compute_helper(gb, params->n_threads);
+        ggml_graph_compute_with_ctx(ctx, gb, params->n_threads);
 
         ggml_opt_get_grad(np, ps, g);
 
@@ -17669,7 +17667,7 @@ static enum ggml_opt_result ggml_opt_lbfgs(
     ggml_graph_reset  (gf);
     ggml_set_f32      (f->grad, 1.0f);
 
-    ggml_graph_compute_helper(gb, params.n_threads);
+    ggml_graph_compute_with_ctx(ctx, gb, params.n_threads);
 
     ggml_opt_get_grad(np, ps, g);
 
@@ -17728,7 +17726,7 @@ static enum ggml_opt_result ggml_opt_lbfgs(
         ggml_vec_cpy_f32(nx, xp, x);
         ggml_vec_cpy_f32(nx, gp, g);
 
-        ls = linesearch_backtracking(&params, nx, x, &fx, g, d, step, xp, f, gf, gb, np, ps);
+        ls = linesearch_backtracking(ctx, &params, nx, x, &fx, g, d, step, xp, f, gf, gb, np, ps);
 
         if (ls < 0) {
             // linesearch failed - go back to the previous point and return
@@ -18030,7 +18028,7 @@ enum ggml_opt_result ggml_opt_resume_g(
     switch (opt->params.type) {
         case GGML_OPT_ADAM:
             {
-                result = ggml_opt_adam(opt, opt->params, f, gf, gb);
+                result = ggml_opt_adam(ctx, opt, opt->params, f, gf, gb);
             } break;
         case GGML_OPT_LBFGS:
             {