Skip to content

Commit 7254cdf

Browse files
JohannesGaessler authored and ggerganov committed
ggml: fix gradient allocation logic (ggml/966)
* ggml: fix gradient allocation logic
* gradient allocation in ggml_build_backward_expand
* fixup
* fix test-backend-ops grad
* suggestions by slaren
* fix test1.c
* fix legacy opt API
* fix test-grad0
* remove keep arg
1 parent cad341d commit 7254cdf

File tree

4 files changed

+491
-1066
lines changed

4 files changed

+491
-1066
lines changed

ggml/include/ggml.h

Lines changed: 20 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -577,10 +577,10 @@ extern "C" {
577577

578578
// this tensor...
579579
enum ggml_tensor_flag {
580-
GGML_TENSOR_FLAG_INPUT = 1, // ...is an input for the GGML compute graph
581-
GGML_TENSOR_FLAG_OUTPUT = 2, // ...is an output for the GGML compute graph
582-
GGML_TENSOR_FLAG_PARAM = 4, // ...contains trainable parameters
583-
GGML_TENSOR_FLAG_LOSS = 8, // ...defines loss for numerical optimization (multiple loss tensors add up)
580+
GGML_TENSOR_FLAG_INPUT = 1, // ...is an input for the GGML compute graph
581+
GGML_TENSOR_FLAG_OUTPUT = 2, // ...is an output for the GGML compute graph
582+
GGML_TENSOR_FLAG_PARAM = 4, // ...contains trainable parameters
583+
GGML_TENSOR_FLAG_LOSS = 8, // ...defines loss for numerical optimization (multiple loss tensors add up)
584584
};
585585

586586
// n-dimensional tensor
@@ -1410,14 +1410,14 @@ extern "C" {
14101410
// supports 3D: a->ne[2] == b->ne[1]
14111411
GGML_API struct ggml_tensor * ggml_get_rows(
14121412
struct ggml_context * ctx,
1413-
struct ggml_tensor * a,
1414-
struct ggml_tensor * b);
1413+
struct ggml_tensor * a, // data
1414+
struct ggml_tensor * b); // row indices
14151415

14161416
GGML_API struct ggml_tensor * ggml_get_rows_back(
14171417
struct ggml_context * ctx,
1418-
struct ggml_tensor * a,
1419-
struct ggml_tensor * b,
1420-
struct ggml_tensor * c);
1418+
struct ggml_tensor * a, // gradients of ggml_get_rows result
1419+
struct ggml_tensor * b, // row indices
1420+
struct ggml_tensor * c); // data for ggml_get_rows, only used for its shape
14211421

14221422
GGML_API struct ggml_tensor * ggml_diag(
14231423
struct ggml_context * ctx,
@@ -1568,9 +1568,9 @@ extern "C" {
15681568
// a - dy
15691569
GGML_API struct ggml_tensor * ggml_rope_back(
15701570
struct ggml_context * ctx,
1571-
struct ggml_tensor * a,
1572-
struct ggml_tensor * b,
1573-
struct ggml_tensor * c,
1571+
struct ggml_tensor * a, // gradients of ggml_rope result
1572+
struct ggml_tensor * b, // positions
1573+
struct ggml_tensor * c, // freq factors
15741574
int n_dims,
15751575
int mode,
15761576
int n_ctx_orig,
@@ -2036,15 +2036,15 @@ extern "C" {
20362036
// loss function
20372037

20382038
GGML_API struct ggml_tensor * ggml_cross_entropy_loss(
2039-
struct ggml_context * ctx,
2040-
struct ggml_tensor * a,
2041-
struct ggml_tensor * b);
2039+
struct ggml_context * ctx,
2040+
struct ggml_tensor * a, // logits
2041+
struct ggml_tensor * b); // labels
20422042

20432043
GGML_API struct ggml_tensor * ggml_cross_entropy_loss_back(
2044-
struct ggml_context * ctx,
2045-
struct ggml_tensor * a,
2046-
struct ggml_tensor * b,
2047-
struct ggml_tensor * c);
2044+
struct ggml_context * ctx,
2045+
struct ggml_tensor * a, // logits
2046+
struct ggml_tensor * b, // labels
2047+
struct ggml_tensor * c); // gradients of cross_entropy_loss result
20482048

20492049
// AdamW optimizer step
20502050
// Paper: https://arxiv.org/pdf/1711.05101v3.pdf
@@ -2066,7 +2066,7 @@ extern "C" {
20662066
GGML_API void ggml_set_loss(struct ggml_tensor * tensor);
20672067

20682068
GGML_API void ggml_build_forward_expand (struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);
2069-
GGML_API void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, bool accumulate, bool keep);
2069+
GGML_API void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, bool accumulate);
20702070

20712071
GGML_API void ggml_build_opt_adamw(
20722072
struct ggml_context * ctx,

0 commit comments

Comments
 (0)