Skip to content

Commit 219385c

Browse files
refactor ggml_opt, fix test-opt
1 parent d707261 commit 219385c

File tree

5 files changed: +221 additions, −189 deletions

ggml/include/ggml-opt.h

Lines changed: 24 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -69,9 +69,9 @@ extern "C" {
6969
// ====== Model / Context ======
7070

7171
enum ggml_opt_build_type {
72-
GGML_OPT_BUILD_TYPE_FORWARD,
73-
GGML_OPT_BUILD_TYPE_GRAD,
74-
GGML_OPT_BUILD_TYPE_OPT,
72+
GGML_OPT_BUILD_TYPE_FORWARD = 10,
73+
GGML_OPT_BUILD_TYPE_GRAD = 20,
74+
GGML_OPT_BUILD_TYPE_OPT = 30,
7575
};
7676

7777
// parameters that control which optimizer is used and how said optimizer tries to find the minimal loss
@@ -101,13 +101,11 @@ extern "C" {
101101
struct ggml_opt_params {
102102
ggml_backend_sched_t backend_sched; // defines which backends are used to construct the compute graphs
103103

104-
struct ggml_context * ctx_compute; // created in user code, holds non-static tensors
105-
106-
// the forward graph is defined by inputs and outputs
107-
// the outputs and all tensors between inputs and outputs that have not been statically allocated
108-
// are not intended to be reusable between multiple optimization contexts
109-
struct ggml_tensor * inputs;
110-
struct ggml_tensor * outputs;
104+
// by default the forward graph needs to be reconstructed for each eval
105+
// if ctx_compute, inputs, and outputs are set the graphs are instead allocated statically
106+
struct ggml_context * ctx_compute;
107+
struct ggml_tensor * inputs;
108+
struct ggml_tensor * outputs;
111109

112110
enum ggml_opt_loss_type loss_type;
113111
enum ggml_opt_build_type build_type;
@@ -121,11 +119,8 @@ extern "C" {
121119
// get parameters for an optimization context with defaults set where possible
122120
// parameters for which no sensible defaults exist are supplied as arguments to this function
123121
GGML_API struct ggml_opt_params ggml_opt_default_params(
124-
ggml_backend_sched_t backend_sched,
125-
struct ggml_context * ctx_compute,
126-
struct ggml_tensor * inputs,
127-
struct ggml_tensor * outputs,
128-
enum ggml_opt_loss_type loss_type);
122+
ggml_backend_sched_t backend_sched,
123+
enum ggml_opt_loss_type loss_type);
129124

130125
GGML_API ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params);
131126
GGML_API void ggml_opt_free(ggml_opt_context_t opt_ctx);
@@ -134,13 +129,15 @@ extern "C" {
134129
GGML_API void ggml_opt_reset(ggml_opt_context_t opt_ctx, bool optimizer);
135130

136131
// get underlying tensors that store data
132+
// if not using static graphs these pointers become invalid with the next call to ggml_opt_alloc
137133
GGML_API struct ggml_tensor * ggml_opt_inputs( ggml_opt_context_t opt_ctx); // forward graph input tensor
138134
GGML_API struct ggml_tensor * ggml_opt_outputs( ggml_opt_context_t opt_ctx); // forward graph output tensor
139135
GGML_API struct ggml_tensor * ggml_opt_labels( ggml_opt_context_t opt_ctx); // labels to compare outputs against
140136
GGML_API struct ggml_tensor * ggml_opt_loss( ggml_opt_context_t opt_ctx); // scalar tensor that contains the loss
141137
GGML_API struct ggml_tensor * ggml_opt_pred( ggml_opt_context_t opt_ctx); // predictions made by outputs
142138
GGML_API struct ggml_tensor * ggml_opt_ncorrect(ggml_opt_context_t opt_ctx); // number of matching predictions between outputs and labels
143139

140+
// get the gradient accumulator for a node from the forward graph
144141
GGML_API struct ggml_tensor * ggml_opt_grad_acc(ggml_opt_context_t opt_ctx, struct ggml_tensor * node);
145142

146143
// ====== Optimization Result ======
@@ -157,15 +154,20 @@ extern "C" {
157154

158155
// ====== Computation ======
159156

160-
GGML_API void ggml_opt_set_forward_graph(
161-
ggml_opt_context_t opt_ctx, struct ggml_context * ctx_compute, struct ggml_cgraph * gf,
162-
struct ggml_tensor * inputs, struct ggml_tensor * outputs, bool backward);
157+
// if not using static graphs, this function must be called prior to ggml_opt_alloc
158+
GGML_API void ggml_opt_prepare_alloc(
159+
ggml_opt_context_t opt_ctx,
160+
struct ggml_context * ctx_compute,
161+
struct ggml_cgraph * gf,
162+
struct ggml_tensor * inputs,
163+
struct ggml_tensor * outputs);
163164

164-
// do forward pass, increment result if not NULL
165-
GGML_API void ggml_opt_forward(ggml_opt_context_t opt_ctx, ggml_opt_result_t result);
165+
// allocate the next graph for evaluation, either forward or forward + backward
166+
// must be called exactly once prior to calling ggml_opt_eval
167+
GGML_API void ggml_opt_alloc(ggml_opt_context_t opt_ctx, bool backward);
166168

167-
// do forward pass, increment result if not NULL, do backward pass
168-
GGML_API void ggml_opt_forward_backward(ggml_opt_context_t opt_ctx, ggml_opt_result_t result);
169+
// do forward pass, increment result if not NULL, do backward pass if allocated
170+
GGML_API void ggml_opt_eval(ggml_opt_context_t opt_ctx, ggml_opt_result_t result);
169171

170172
// ############################################################################
171173
// ## The high-level functions start here. They do not depend on any private ##

0 commit comments

Comments (0)