Skip to content

Commit 899f7a2

Browse files
gqa_mode arg for repeat_back
1 parent 5a4477a commit 899f7a2

File tree

2 files changed

+9
-7
lines changed

2 files changed

+9
-7
lines changed

ggml/include/ggml.h

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -943,7 +943,8 @@ extern "C" {
943943
GGML_API struct ggml_tensor * ggml_repeat_back(
944944
struct ggml_context * ctx,
945945
struct ggml_tensor * a,
946-
struct ggml_tensor * b);
946+
struct ggml_tensor * b,
947+
bool gqa_mode); // if true, use the memory pattern for the backward pass of matrix multiplication with grouped-query attention (GQA)
947948

948949
// concat a and b along dim
949950
// used in stable-diffusion

ggml/src/ggml.c

Lines changed: 7 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -2303,13 +2303,15 @@ struct ggml_tensor * ggml_repeat(
23032303
struct ggml_tensor * ggml_repeat_back(
23042304
struct ggml_context * ctx,
23052305
struct ggml_tensor * a,
2306-
struct ggml_tensor * b) {
2306+
struct ggml_tensor * b,
2307+
bool gqa_mode) {
23072308
GGML_ASSERT(ggml_can_repeat(b, a));
23082309

23092310
struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, GGML_MAX_DIMS, b->ne);
23102311

23112312
result->op = GGML_OP_REPEAT_BACK;
23122313
result->src[0] = a;
2314+
result->op_params[1] = gqa_mode ? 1 : 0;
23132315

23142316
return result;
23152317
}
@@ -5129,7 +5131,7 @@ static void ggml_compute_backward(
51295131
if (src1_needs_grads) {
51305132
struct ggml_tensor * tmp = grad;
51315133
if (!ggml_are_same_shape(src0, src1)) {
5132-
tmp = ggml_repeat_back(ctx, tmp, src1);
5134+
tmp = ggml_repeat_back(ctx, tmp, src1, false);
51335135
}
51345136
ggml_add_or_set(ctx, cgraph, isrc1, tmp);
51355137
}
@@ -5174,7 +5176,7 @@ static void ggml_compute_backward(
51745176
if (src1_needs_grads) {
51755177
struct ggml_tensor * tmp = ggml_mul(ctx, src0, grad);
51765178
if (!ggml_are_same_shape(src0, src1)) {
5177-
tmp = ggml_repeat_back(ctx, tmp, src1);
5179+
tmp = ggml_repeat_back(ctx, tmp, src1, false);
51785180
}
51795181
ggml_add_or_set(ctx, cgraph, isrc1, tmp);
51805182
}
@@ -5229,7 +5231,7 @@ static void ggml_compute_backward(
52295231
} break;
52305232
case GGML_OP_REPEAT: {
52315233
if (src0_needs_grads) {
5232-
ggml_add_or_set(ctx, cgraph, isrc0, ggml_repeat_back(ctx, grad, src0));
5234+
ggml_add_or_set(ctx, cgraph, isrc0, ggml_repeat_back(ctx, grad, src0, false));
52335235
}
52345236
} break;
52355237
case GGML_OP_REPEAT_BACK: {
@@ -5268,8 +5270,7 @@ static void ggml_compute_backward(
52685270
if (!ggml_are_same_shape(tmp, src0)) {
52695271
GGML_ASSERT(tmp->ne[0] == src0->ne[0]);
52705272
GGML_ASSERT(tmp->ne[1] == src0->ne[1]);
5271-
tmp = ggml_repeat_back(ctx, tmp, src0);
5272-
tmp->op_params[0] = 1; // FIXME
5273+
tmp = ggml_repeat_back(ctx, tmp, src0, true);
52735274
}
52745275
ggml_add_or_set(ctx, cgraph, isrc0, tmp);
52755276
}

0 commit comments

Comments
 (0)