Skip to content

ggml : fa without mask + add asserts #7278

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 38 additions & 31 deletions ggml-metal.m
Original file line number Diff line number Diff line change
Expand Up @@ -2512,13 +2512,14 @@ static enum ggml_status ggml_metal_graph_compute(
} break;
case GGML_OP_FLASH_ATTN_EXT:
{
GGML_ASSERT(ne00 % 4 == 0);
GGML_ASSERT(ne00 % 4 == 0);
GGML_ASSERT(ne11 % 32 == 0);

GGML_ASSERT(src0->type == GGML_TYPE_F32);

struct ggml_tensor * src3 = gf->nodes[i]->src[3];
GGML_ASSERT(ggml_are_same_shape (src1, src2));

GGML_ASSERT(ggml_are_same_shape(src1, src2));
GGML_ASSERT(src3);
struct ggml_tensor * src3 = gf->nodes[i]->src[3];

size_t offs_src3 = 0;

Expand All @@ -2528,6 +2529,11 @@ static enum ggml_status ggml_metal_graph_compute(
GGML_ASSERT(!src3 || src3->ne[1] >= GGML_PAD(src0->ne[1], 8) &&
"the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big");

const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20);
const uint64_t nb21 = src2 ? src2->nb[1] : 0;
const uint64_t nb22 = src2 ? src2->nb[2] : 0;
const uint64_t nb23 = src2 ? src2->nb[3] : 0;

const int64_t ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30);
//const int64_t ne31 = src3 ? src3->ne[1] : 0;
const int64_t ne32 = src3 ? src3->ne[2] : 0; GGML_UNUSED(ne32);
Expand Down Expand Up @@ -2590,34 +2596,35 @@ static enum ggml_status ggml_metal_graph_compute(
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
[encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
if (id_src3) {
[encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
} else {
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:3];
}
[encoder setBuffer:id_dst offset:offs_dst atIndex:4];
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:5];
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:6];
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:7];
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:8];
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:9];
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:10];
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:11];
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:12];
[encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:13];
[encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:14];
[encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:15];
[encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:16];
[encoder setBytes:&nb10 length:sizeof(uint64_t) atIndex:17];
[encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:18];
[encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:19];
[encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:20];
[encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:21];
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:22];
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:23];
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:24];
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:25];
[encoder setBytes:&scale length:sizeof( float) atIndex:26];
[encoder setBytes:&max_bias length:sizeof( float) atIndex:27];
[encoder setBytes:&m0 length:sizeof(m0) atIndex:28];
[encoder setBytes:&m1 length:sizeof(m1) atIndex:29];
[encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:30];
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5];
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6];
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7];
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8];
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9];
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10];
[encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:11];
[encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:12];
[encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:13];
[encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:14];
[encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:15];
[encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:16];
[encoder setBytes:&nb21 length:sizeof(uint64_t) atIndex:17];
[encoder setBytes:&nb22 length:sizeof(uint64_t) atIndex:18];
[encoder setBytes:&nb23 length:sizeof(uint64_t) atIndex:19];
[encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:20];
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:21];
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:22];
[encoder setBytes:&scale length:sizeof( float) atIndex:23];
[encoder setBytes:&max_bias length:sizeof( float) atIndex:24];
[encoder setBytes:&m0 length:sizeof(m0) atIndex:25];
[encoder setBytes:&m1 length:sizeof(m1) atIndex:26];
[encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:27];

if (!use_vec_kernel) {
// half8x8 kernel
Expand Down
53 changes: 20 additions & 33 deletions ggml-metal.metal
Original file line number Diff line number Diff line change
Expand Up @@ -2049,27 +2049,24 @@ typedef void (flash_attn_ext_f16_t)(
device const char * v,
device const char * mask,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
constant int64_t & ne03,
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant uint64_t & nb03,
constant int64_t & ne10,
constant int64_t & ne11,
constant int64_t & ne12,
constant int64_t & ne13,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb12,
constant uint64_t & nb13,
constant uint64_t & nb21,
constant uint64_t & nb22,
constant uint64_t & nb23,
constant uint64_t & nb31,
constant int64_t & ne0,
constant int64_t & ne1,
constant int64_t & ne2,
constant int64_t & ne3,
constant float & scale,
constant float & max_bias,
constant float & m0,
Expand All @@ -2090,27 +2087,24 @@ kernel void kernel_flash_attn_ext_f16(
device const char * v,
device const char * mask,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
constant int64_t & ne03,
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant uint64_t & nb03,
constant int64_t & ne10,
constant int64_t & ne11,
constant int64_t & ne12,
constant int64_t & ne13,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb12,
constant uint64_t & nb13,
constant uint64_t & nb21,
constant uint64_t & nb22,
constant uint64_t & nb23,
constant uint64_t & nb31,
constant int64_t & ne0,
constant int64_t & ne1,
constant int64_t & ne2,
constant int64_t & ne3,
constant float & scale,
constant float & max_bias,
constant float & m0,
Expand Down Expand Up @@ -2180,10 +2174,6 @@ kernel void kernel_flash_attn_ext_f16(
const short ne22 = ne12;
const short ne23 = ne13;

const uint nb21 = nb11;
const uint nb22 = nb12;
const uint nb23 = nb13;

// broadcast
const short rk2 = ne02/ne12;
const short rk3 = ne03/ne13;
Expand Down Expand Up @@ -2247,11 +2237,16 @@ kernel void kernel_flash_attn_ext_f16(
simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
}

// mqk = mqk*scale + mask*slope
simdgroup_half8x8 mm;
simdgroup_load(mm, mp + ic + 8*cc, nb31/sizeof(half), 0, false);
simdgroup_multiply(mm, mslope, mm);
simdgroup_multiply_accumulate(mqk, mqk, mscale, mm);
if (mask != q) {
// mqk = mqk*scale + mask*slope
simdgroup_half8x8 mm;
simdgroup_load(mm, mp + ic + 8*cc, nb31/sizeof(half), 0, false);
simdgroup_multiply(mm, mslope, mm);
simdgroup_multiply_accumulate(mqk, mqk, mscale, mm);
} else {
// mqk = mqk*scale
simdgroup_multiply(mqk, mscale, mqk);
}

simdgroup_store(mqk, ss + 8*cc, TF, 0, false);
}
Expand Down Expand Up @@ -2425,27 +2420,24 @@ kernel void kernel_flash_attn_ext_vec_f16(
device const char * v,
device const char * mask,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
constant int64_t & ne03,
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant uint64_t & nb03,
constant int64_t & ne10,
constant int64_t & ne11,
constant int64_t & ne12,
constant int64_t & ne13,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb12,
constant uint64_t & nb13,
constant uint64_t & nb21,
constant uint64_t & nb22,
constant uint64_t & nb23,
constant uint64_t & nb31,
constant int64_t & ne0,
constant int64_t & ne1,
constant int64_t & ne2,
constant int64_t & ne3,
constant float & scale,
constant float & max_bias,
constant float & m0,
Expand Down Expand Up @@ -2521,10 +2513,6 @@ kernel void kernel_flash_attn_ext_vec_f16(
const short ne22 = ne12;
const short ne23 = ne13;

const uint nb21 = nb11;
const uint nb22 = nb12;
const uint nb23 = nb13;

// broadcast
const short rk2 = ne02/ne12;
const short rk3 = ne03/ne13;
Expand Down Expand Up @@ -2589,8 +2577,7 @@ kernel void kernel_flash_attn_ext_vec_f16(

// mqk = mqk*scale + mask*slope
if (tiisg == 0) {
float4 mm = (float4) mp4[ic/4 + cc];
mqk = mqk*scale + mm*slope;
mqk = mqk*scale + ((mask != q) ? ((float4) mp4[ic/4 + cc])*slope : (float4) 0.0f);

ss4[cc] = mqk;
}
Expand Down
10 changes: 10 additions & 0 deletions ggml.c
Original file line number Diff line number Diff line change
Expand Up @@ -2822,6 +2822,16 @@ bool ggml_are_same_shape(const struct ggml_tensor * t0, const struct ggml_tensor
(t0->ne[3] == t1->ne[3] );
}

bool ggml_are_same_stride(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");

return
(t0->nb[0] == t1->nb[0] ) &&
(t0->nb[1] == t1->nb[1] ) &&
(t0->nb[2] == t1->nb[2] ) &&
(t0->nb[3] == t1->nb[3] );
}

// check if t1 can be represented as a repeatition of t0
static inline bool ggml_can_repeat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
Expand Down
3 changes: 2 additions & 1 deletion ggml.h
Original file line number Diff line number Diff line change
Expand Up @@ -766,7 +766,8 @@ extern "C" {
GGML_API bool ggml_is_3d (const struct ggml_tensor * tensor);
GGML_API int ggml_n_dims (const struct ggml_tensor * tensor); // returns 1 for scalars

GGML_API bool ggml_are_same_shape(const struct ggml_tensor * t0, const struct ggml_tensor * t1);
GGML_API bool ggml_are_same_shape (const struct ggml_tensor * t0, const struct ggml_tensor * t1);
GGML_API bool ggml_are_same_stride(const struct ggml_tensor * t0, const struct ggml_tensor * t1);

// use this to compute the memory overhead of a tensor
GGML_API size_t ggml_tensor_overhead(void);
Expand Down
25 changes: 15 additions & 10 deletions tests/test-backend-ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1487,25 +1487,27 @@ struct test_flash_attn_ext : public test_case {
const int64_t kv; // kv size
const int64_t nb; // batch size

const bool mask; // use mask

const float max_bias; // ALiBi

std::string vars() override {
return VARS_TO_STR5(hs, nh, kv, nb, max_bias);
return VARS_TO_STR6(hs, nh, kv, nb, mask, max_bias);
}

double max_nmse_err() override {
return 5e-4;
}

test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8, float max_bias = 0.0f)
: hs(hs), nh(nh), kv(kv), nb(nb), max_bias(max_bias) {}
test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8, bool mask = true, float max_bias = 0.0f)
: hs(hs), nh(nh), kv(kv), nb(nb), mask(mask), max_bias(max_bias) {}

ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, hs, nb, nh, 1);
ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1);
ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1);
ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, GGML_PAD(nb, GGML_KQ_MASK_PAD), 1, 1);
ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, mask, 1.0f/sqrtf(hs), max_bias);
ggml_tensor * m = mask ? ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, GGML_PAD(nb, GGML_KQ_MASK_PAD), 1, 1) : nullptr;
ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, m, 1.0f/sqrtf(hs), max_bias);
return out;
}
};
Expand Down Expand Up @@ -2175,11 +2177,14 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
test_cases.emplace_back(new test_leaky_relu());

for (int hs : { 64, 80, 128, 256, }) {
for (float max_bias : {0.0f, 8.0f}) {
for (int nh : { 32, }) {
for (int kv : { 512, 1024, }) {
for (int nb : { 1, 2, 4, 8, }) {
test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb, max_bias));
for (bool mask : { true, false } ) {
for (float max_bias : { 0.0f, 8.0f }) {
if (!mask && max_bias > 0.0f) continue;
for (int nh : { 32, }) {
for (int kv : { 512, 1024, }) {
for (int nb : { 1, 2, 4, 8, }) {
test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb, mask, max_bias));
}
}
}
}
Expand Down
Loading