
Commit 07aaa0f

ggml : fix ggml_flash_attn to use op_params (#2387)
* ggml : fix ggml_flash_attn to use op_params
1 parent fce48ca · commit 07aaa0f

File tree

1 file changed: +11 −7 lines
ggml.c

Lines changed: 11 additions & 7 deletions
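Summary of the change: the `masked` flag of the flash-attention ops used to be smuggled into the graph as an extra one-element I32 source tensor (`src[3]` for GGML_OP_FLASH_ATTN, `src[4]` for GGML_OP_FLASH_ATTN_BACK, created with `ggml_new_i32`); it now lives in the node's inline `op_params` buffer. The diff also replaces a hardcoded rank of 4 with the input's `n_dims` when allocating result tensors. For orientation, here is a minimal sketch of the two op_params helpers the hunks below rely on, as they are defined inside ggml.c around this commit (the bodies are reconstructed here, so treat the details as assumed):

// Sketch of the helpers used in the hunks below; op_params is a small
// fixed-size int32 buffer embedded in struct ggml_tensor (see ggml.h).
#include <string.h>
#include "ggml.h"

static void ggml_set_op_params(struct ggml_tensor * tensor, const void * params, size_t params_size) {
    GGML_ASSERT(params_size <= GGML_MAX_OP_PARAMS); // params must fit the inline buffer
    memcpy(tensor->op_params, params, params_size);
}

static int32_t ggml_get_op_params_i32(const struct ggml_tensor * tensor, uint32_t i) {
    GGML_ASSERT(i < GGML_MAX_OP_PARAMS / sizeof(int32_t));
    return ((const int32_t *) tensor->op_params)[i];
}

Because op_params is part of struct ggml_tensor itself, the flag needs no allocation in the context and no data buffer at evaluation time.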
@@ -7030,14 +7030,16 @@ struct ggml_tensor * ggml_flash_attn(
     }

     //struct ggml_tensor * result = ggml_dup_tensor(ctx, q);
-    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, q->ne);
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, q->n_dims, q->ne);
+
+    int32_t t = masked ? 1 : 0;
+    ggml_set_op_params(result, &t, sizeof(t));

     result->op   = GGML_OP_FLASH_ATTN;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = q;
     result->src[1] = k;
     result->src[2] = v;
-    result->src[3] = ggml_new_i32(ctx, masked ? 1 : 0);

     return result;
 }
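The caller-facing `ggml_flash_attn(ctx, q, k, v, masked)` signature is unchanged; only where the flag is stored changes. A minimal round-trip sketch under assumed shapes (head_dim 64, sequence length 32, and 8 heads are hypothetical, and poking at `op_params` from user code is for illustration only):

#include <assert.h>
#include "ggml.h"

int main(void) {
    struct ggml_init_params ip = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(ip);

    // hypothetical shapes: head_dim=64, seq_len=32, n_head=8
    struct ggml_tensor * q = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 64, 32, 8);
    struct ggml_tensor * k = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 64, 32, 8);
    struct ggml_tensor * v = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 32, 64, 8);

    struct ggml_tensor * out = ggml_flash_attn(ctx, q, k, v, /*masked=*/ true);

    // the flag now travels inside the node; src[3] no longer carries it
    assert(out->op_params[0] == 1);
    assert(out->src[3] == NULL);

    ggml_free(ctx);
    return 0;
}

Before this commit the same call would also have allocated a one-element I32 tensor in `ctx` and wired it into `out->src[3]`.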
@@ -7061,7 +7063,7 @@ struct ggml_tensor * ggml_flash_ff(
     }

     //struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
-    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, a->ne);
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, a->n_dims, a->ne);

     result->op   = GGML_OP_FLASH_FF;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
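The `4` → `n_dims` edits in both constructors fix the result's rank metadata rather than its layout: unused `ne` entries are 1, so `ggml_new_tensor(ctx, type, 4, src->ne)` produces the same byte layout but stamps every result as 4-D even for a 2-D input. A small sketch, reusing a context `ctx` like the one above:

struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64, 32); // 2-D input

struct ggml_tensor * old_style = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, a->ne);         // pre-commit behavior
struct ggml_tensor * new_style = ggml_new_tensor(ctx, GGML_TYPE_F32, a->n_dims, a->ne); // post-commit behavior

assert(old_style->n_dims == 4); // rank inflated by the hardcoded 4
assert(new_style->n_dims == 2); // rank follows the input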
@@ -7127,13 +7129,15 @@ struct ggml_tensor * ggml_flash_attn_back(

     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);

+    int32_t masked_i = masked ? 1 : 0;
+    ggml_set_op_params(result, &masked_i, sizeof(masked_i));
+
     result->op   = GGML_OP_FLASH_ATTN_BACK;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = q;
     result->src[1] = k;
     result->src[2] = v;
     result->src[3] = d;
-    result->src[4] = ggml_new_i32(ctx, masked ? 1 : 0);

     return result;
 }
@@ -14773,7 +14777,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             } break;
         case GGML_OP_FLASH_ATTN:
             {
-                const int32_t t = ggml_get_i32_1d(tensor->src[3], 0);
+                const int32_t t = ggml_get_op_params_i32(tensor, 0);
                 GGML_ASSERT(t == 0 || t == 1);
                 const bool masked = t != 0;
                 ggml_compute_forward_flash_attn(params, tensor->src[0], tensor->src[1], tensor->src[2], masked, tensor);
@@ -14784,7 +14788,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             } break;
         case GGML_OP_FLASH_ATTN_BACK:
             {
-                int32_t t = ggml_get_i32_1d(tensor->src[4], 0);
+                int32_t t = ggml_get_op_params_i32(tensor, 0);
                 GGML_ASSERT(t == 0 || t == 1);
                 bool masked = t != 0;
                 ggml_compute_forward_flash_attn_back(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], masked, tensor);
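The dispatch side mirrors the constructor change: previously the flag was read out of the extra source tensor's data, so that tensor's buffer had to be allocated and host-readable when the graph ran; reading from `op_params` touches only the node struct. In sketch form (a fragment from inside `ggml_compute_forward`, where `tensor` is the op node):

// before: dereferences the scalar tensor's data buffer, which must be
// allocated and host-visible when the graph is evaluated
//   const int32_t t = ggml_get_i32_1d(tensor->src[3], 0);

// after: reads the inline per-op parameter block; no extra buffer involved
const int32_t t = ggml_get_op_params_i32(tensor, 0);
GGML_ASSERT(t == 0 || t == 1);
const bool masked = t != 0;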
@@ -15402,7 +15406,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 struct ggml_tensor * flash_grad = NULL;
                 if (src0->grad || src1->grad || tensor->src[2]->grad) {
-                    int32_t t = ggml_get_i32_1d(tensor->src[3], 0);
+                    int32_t t = ggml_get_op_params_i32(tensor, 0);
                     GGML_ASSERT(t == 0 || t == 1);
                     bool masked = t != 0;
                     flash_grad =
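For context only: the assignment this last hunk cuts off plausibly continues by rebuilding the gradient through the backward op with the recovered flag (reconstructed from the surrounding ggml.c of that era, not shown in this diff, so treat the exact call as assumed):

// assumed continuation of the truncated "flash_grad =" line above:
flash_grad = ggml_flash_attn_back(ctx, src0, src1, tensor->src[2], tensor->grad, masked);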
