@@ -69,9 +69,9 @@ extern "C" {
69
69
// ====== Model / Context ======
70
70
71
71
enum ggml_opt_build_type {
72
- GGML_OPT_BUILD_TYPE_FORWARD ,
73
- GGML_OPT_BUILD_TYPE_GRAD ,
74
- GGML_OPT_BUILD_TYPE_OPT ,
72
+ GGML_OPT_BUILD_TYPE_FORWARD = 10 ,
73
+ GGML_OPT_BUILD_TYPE_GRAD = 20 ,
74
+ GGML_OPT_BUILD_TYPE_OPT = 30 ,
75
75
};
76
76
77
77
// parameters that control which optimizer is used and how said optimizer tries to find the minimal loss
@@ -101,13 +101,11 @@ extern "C" {
101
101
struct ggml_opt_params {
102
102
ggml_backend_sched_t backend_sched ; // defines which backends are used to construct the compute graphs
103
103
104
- struct ggml_context * ctx_compute ; // created in user code, holds non-static tensors
105
-
106
- // the forward graph is defined by inputs and outputs
107
- // the outputs and all tensors between inputs and outputs that have not been statically allocated
108
- // are not intended to be reusable between multiple optimization contexts
109
- struct ggml_tensor * inputs ;
110
- struct ggml_tensor * outputs ;
104
+ // by default the forward graph needs to be reconstructed for each eval
105
+ // if ctx_compute, inputs, and outputs are set the graphs are instead allocated statically
106
+ struct ggml_context * ctx_compute ;
107
+ struct ggml_tensor * inputs ;
108
+ struct ggml_tensor * outputs ;
111
109
112
110
enum ggml_opt_loss_type loss_type ;
113
111
enum ggml_opt_build_type build_type ;
@@ -121,11 +119,8 @@ extern "C" {
121
119
// get parameters for an optimization context with defaults set where possible
122
120
// parameters for which no sensible defaults exist are supplied as arguments to this function
123
121
GGML_API struct ggml_opt_params ggml_opt_default_params (
124
- ggml_backend_sched_t backend_sched ,
125
- struct ggml_context * ctx_compute ,
126
- struct ggml_tensor * inputs ,
127
- struct ggml_tensor * outputs ,
128
- enum ggml_opt_loss_type loss_type );
122
+ ggml_backend_sched_t backend_sched ,
123
+ enum ggml_opt_loss_type loss_type );
129
124
130
125
GGML_API ggml_opt_context_t ggml_opt_init (struct ggml_opt_params params );
131
126
GGML_API void ggml_opt_free (ggml_opt_context_t opt_ctx );
@@ -134,13 +129,15 @@ extern "C" {
134
129
GGML_API void ggml_opt_reset (ggml_opt_context_t opt_ctx , bool optimizer );
135
130
136
131
// get underlying tensors that store data
132
+ // if not using static graphs these pointers become invalid with the next call to ggml_opt_alloc
137
133
GGML_API struct ggml_tensor * ggml_opt_inputs ( ggml_opt_context_t opt_ctx ); // forward graph input tensor
138
134
GGML_API struct ggml_tensor * ggml_opt_outputs ( ggml_opt_context_t opt_ctx ); // forward graph output tensor
139
135
GGML_API struct ggml_tensor * ggml_opt_labels ( ggml_opt_context_t opt_ctx ); // labels to compare outputs against
140
136
GGML_API struct ggml_tensor * ggml_opt_loss ( ggml_opt_context_t opt_ctx ); // scalar tensor that contains the loss
141
137
GGML_API struct ggml_tensor * ggml_opt_pred ( ggml_opt_context_t opt_ctx ); // predictions made by outputs
142
138
GGML_API struct ggml_tensor * ggml_opt_ncorrect (ggml_opt_context_t opt_ctx ); // number of matching predictions between outputs and labels
143
139
140
+ // get the gradient accumulator for a node from the forward graph
144
141
GGML_API struct ggml_tensor * ggml_opt_grad_acc (ggml_opt_context_t opt_ctx , struct ggml_tensor * node );
145
142
146
143
// ====== Optimization Result ======
@@ -157,15 +154,20 @@ extern "C" {
157
154
158
155
// ====== Computation ======
159
156
160
- GGML_API void ggml_opt_set_forward_graph (
161
- ggml_opt_context_t opt_ctx , struct ggml_context * ctx_compute , struct ggml_cgraph * gf ,
162
- struct ggml_tensor * inputs , struct ggml_tensor * outputs , bool backward );
157
+ // if not using static graphs, this function must be called prior to ggml_opt_alloc
158
+ GGML_API void ggml_opt_prepare_alloc (
159
+ ggml_opt_context_t opt_ctx ,
160
+ struct ggml_context * ctx_compute ,
161
+ struct ggml_cgraph * gf ,
162
+ struct ggml_tensor * inputs ,
163
+ struct ggml_tensor * outputs );
163
164
164
- // do forward pass, increment result if not NULL
165
- GGML_API void ggml_opt_forward (ggml_opt_context_t opt_ctx , ggml_opt_result_t result );
165
+ // allocate the next graph for evaluation, either forward or forward + backward
166
+ // must be called exactly once prior to calling ggml_opt_eval
167
+ GGML_API void ggml_opt_alloc (ggml_opt_context_t opt_ctx , bool backward );
166
168
167
- // do forward pass, increment result if not NULL, do backward pass
168
- GGML_API void ggml_opt_forward_backward (ggml_opt_context_t opt_ctx , ggml_opt_result_t result );
169
+ // do forward pass, increment result if not NULL, do backward pass if allocated
170
+ GGML_API void ggml_opt_eval (ggml_opt_context_t opt_ctx , ggml_opt_result_t result );
169
171
170
172
// ############################################################################
171
173
// ## The high-level functions start here. They do not depend on any private ##
0 commit comments