@@ -7030,14 +7030,16 @@ struct ggml_tensor * ggml_flash_attn(
7030
7030
}
7031
7031
7032
7032
//struct ggml_tensor * result = ggml_dup_tensor(ctx, q);
7033
- struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, q->ne);
7033
+ struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, q->n_dims, q->ne);
7034
+
7035
+ int32_t t = masked ? 1 : 0;
7036
+ ggml_set_op_params(result, &t, sizeof(t));
7034
7037
7035
7038
result->op = GGML_OP_FLASH_ATTN;
7036
7039
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
7037
7040
result->src[0] = q;
7038
7041
result->src[1] = k;
7039
7042
result->src[2] = v;
7040
- result->src[3] = ggml_new_i32(ctx, masked ? 1 : 0);
7041
7043
7042
7044
return result;
7043
7045
}
@@ -7061,7 +7063,7 @@ struct ggml_tensor * ggml_flash_ff(
7061
7063
}
7062
7064
7063
7065
//struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
7064
- struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4 , a->ne);
7066
+ struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, a->n_dims , a->ne);
7065
7067
7066
7068
result->op = GGML_OP_FLASH_FF;
7067
7069
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -7127,13 +7129,15 @@ struct ggml_tensor * ggml_flash_attn_back(
7127
7129
7128
7130
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
7129
7131
7132
+ int32_t masked_i = masked ? 1 : 0;
7133
+ ggml_set_op_params(result, &masked_i, sizeof(masked_i));
7134
+
7130
7135
result->op = GGML_OP_FLASH_ATTN_BACK;
7131
7136
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
7132
7137
result->src[0] = q;
7133
7138
result->src[1] = k;
7134
7139
result->src[2] = v;
7135
7140
result->src[3] = d;
7136
- result->src[4] = ggml_new_i32(ctx, masked ? 1 : 0);
7137
7141
7138
7142
return result;
7139
7143
}
@@ -14773,7 +14777,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
14773
14777
} break;
14774
14778
case GGML_OP_FLASH_ATTN:
14775
14779
{
14776
- const int32_t t = ggml_get_i32_1d (tensor->src[3] , 0);
14780
+ const int32_t t = ggml_get_op_params_i32 (tensor, 0);
14777
14781
GGML_ASSERT(t == 0 || t == 1);
14778
14782
const bool masked = t != 0;
14779
14783
ggml_compute_forward_flash_attn(params, tensor->src[0], tensor->src[1], tensor->src[2], masked, tensor);
@@ -14784,7 +14788,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
14784
14788
} break;
14785
14789
case GGML_OP_FLASH_ATTN_BACK:
14786
14790
{
14787
- int32_t t = ggml_get_i32_1d (tensor->src[4] , 0);
14791
+ int32_t t = ggml_get_op_params_i32 (tensor, 0);
14788
14792
GGML_ASSERT(t == 0 || t == 1);
14789
14793
bool masked = t != 0;
14790
14794
ggml_compute_forward_flash_attn_back(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], masked, tensor);
@@ -15402,7 +15406,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
15402
15406
{
15403
15407
struct ggml_tensor * flash_grad = NULL;
15404
15408
if (src0->grad || src1->grad || tensor->src[2]->grad) {
15405
- int32_t t = ggml_get_i32_1d (tensor->src[3] , 0);
15409
+ int32_t t = ggml_get_op_params_i32 (tensor, 0);
15406
15410
GGML_ASSERT(t == 0 || t == 1);
15407
15411
bool masked = t != 0;
15408
15412
flash_grad =
0 commit comments