Commit 71b69aa

cuda : fix flash_attn kernel to produce same results as CPU
1 parent fd878f7 · commit 71b69aa

File tree

ggml-cuda.cu
tests/test-backend-ops.cpp

2 files changed: +42 -26 lines


ggml-cuda.cu

Lines changed: 41 additions & 25 deletions
@@ -6445,7 +6445,7 @@ static __global__ void flash_attn_ext_f16(
     const int D16 = D/16;
     const int Q16 = Q/16;
     const int NW = WARP_SIZE;
-    const int SH = (C + Q); // shared memory per simdgroup in (half)
+    const int SH = (C + 2*Q); // shared memory per simdgroup in (half)
 
     const int T = D + num_warps*SH; // shared memory size per query in (half)
     const int T2 = T/2; // shared memory size per query in (half2)
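
The extra Q columns per simdgroup (SH goes from C + Q to C + 2*Q) are the scratch area used further down to round-trip accumulator fragments through shared memory when converting them into matrix_a/matrix_b operands; the host-side shmem computation at the bottom of this file grows by the same amount.
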
@@ -6526,11 +6526,16 @@ static __global__ void flash_attn_ext_f16(
         }
     }
 
-    const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1;
-
     // pointer to the mask
     const half * mp = mask ? (const half *) (mask + iq1*nb31) : nullptr;
 
+    // prepare diagonal scale matrix
+    half16x16_b mscale;
+    for (int i = 0; i < 16; ++i) {
+        ss[i*T + i] = __float2half(scale);
+    }
+    nvcuda::wmma::load_matrix_sync(mscale, ss, T);
+
     // loop over the KV cache
     // each simdgroup handles blocks of Q rows and C columns
     for (int64_t ic = C*warp_id; ic < ne11; ic += C*num_warps) {
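
The diagonal matrix is staged through the same ss scratch area and loaded once per warp as a matrix_b fragment. Right-multiplying any tile S by diag(scale) gives scale*S, so in the next hunk the scaling and the mask addition collapse into a single mma_sync, and the mask is now actually applied (the removed path left it as a TODO).
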
@@ -6555,10 +6560,15 @@ static __global__ void flash_attn_ext_f16(
 
                 // mqk = mqk*scale + mask
                 for (int64_t j = 0; j < Q16; ++j) {
-                    for (uint32_t i = 0; i < mqk[j].num_elements; i++) {
-                        // TODO: process mask
-                        mqk[j].x[i] = __float2half(scale) * mqk[j].x[i];
-                    }
+                    half16x16_a mqka;
+                    half16x16_acc mm;
+
+                    // convert accumulator to matrix_a
+                    nvcuda::wmma::store_matrix_sync( ss + 16*j*T + 16*cc, mqk[j], T, nvcuda::wmma::mem_row_major);
+                    nvcuda::wmma::load_matrix_sync (mqka, ss + 16*j*T + 16*cc, T);
+
+                    nvcuda::wmma::load_matrix_sync(mm, mp + 16*j*(nb31/sizeof(half)) + ic + 16*cc, nb31/sizeof(half), nvcuda::wmma::mem_row_major);
+                    nvcuda::wmma::mma_sync(mqk[j], mqka, mscale, mm);
 
                     nvcuda::wmma::store_matrix_sync(ss + 16*j*T + 16*cc, mqk[j], T, nvcuda::wmma::mem_row_major);
                 }
             }
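
For readers unfamiliar with the WMMA constraint being worked around here: an accumulator fragment cannot be passed to mma_sync as the matrix_a operand, so it has to be stored to shared memory and re-loaded with the right fragment type. Below is a minimal standalone sketch of that round trip plus the diag(scale) trick, not taken from the commit; the kernel name and pointer parameters are made up, and the raw nvcuda::wmma fragment types are spelled out instead of the half16x16_* typedefs used in ggml-cuda.cu.

    // Illustration only. Launch with a single warp, e.g.
    //   scale_mask_16x16<<<1, 32>>>(acc_in, mask, out, scale);
    #include <cuda_fp16.h>
    #include <mma.h>
    using namespace nvcuda;

    __global__ void scale_mask_16x16(const half * acc_in, const half * mask, half * out, float scale) {
        __shared__ half ss[16*16]; // scratch tile, also used to stage diag(scale)

        wmma::fragment<wmma::matrix_a,    16, 16, 16, half, wmma::row_major> a;
        wmma::fragment<wmma::matrix_b,    16, 16, 16, half, wmma::row_major> mscale;
        wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc;
        wmma::fragment<wmma::accumulator, 16, 16, 16, half> mm;

        // build diag(scale) in shared memory and load it as a matrix_b fragment
        for (int i = threadIdx.x; i < 16*16; i += 32) {
            ss[i] = (i/16 == i%16) ? __float2half(scale) : __float2half(0.0f);
        }
        __syncwarp();
        wmma::load_matrix_sync(mscale, ss, 16);

        // acc stands in for a previously computed Q*K^T tile
        wmma::load_matrix_sync(acc, acc_in, 16, wmma::mem_row_major);

        // accumulator -> matrix_a conversion via the shared-memory round trip
        wmma::store_matrix_sync(ss, acc, 16, wmma::mem_row_major);
        __syncwarp();
        wmma::load_matrix_sync(a, ss, 16);

        // acc = acc*diag(scale) + mask, i.e. scale*acc + mask, in one tensor-core op
        wmma::load_matrix_sync(mm, mask, 16, wmma::mem_row_major);
        wmma::mma_sync(acc, a, mscale, mm);

        wmma::store_matrix_sync(out, acc, 16, wmma::mem_row_major);
    }

mma_sync computes d = a*b + c, so with b = diag(scale) and c = the mask tile this is exactly the "mqk = mqk*scale + mask" step from the comment above.
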
@@ -6631,18 +6641,19 @@ static __global__ void flash_attn_ext_f16(
 
             // O = diag(ms)*O
             for (int64_t j = 0; j < Q16; ++j) {
-                // half16x16_a mm;
-                // half16x16_b zro;
+                half16x16_a mm;
+                half16x16_b lob;
 
-                // nvcuda::wmma::fill_fragment(zro, 0.0);
-                // nvcuda::wmma::load_matrix_sync(mm, ss + 16*j*T + C + 16*j, T);
+                nvcuda::wmma::load_matrix_sync(mm, ss + 16*j*T + C + 16*j, T);
 
                 for (int64_t i = 0; i < D16; ++i) {
-                    //nvcuda::wmma::mma_sync(lo[j][i], mm, zro, lo[j][i]);
-                    for (uint32_t k = 0; k < 16*16; k++) {
-                        half tmp = ss[(16*j + k%16)*T + C + 16*j + k%16];
-                        lo[j][i].x[k] = tmp * lo[j][i].x[k];
-                    }
+                    // convert accumulator to matrix_b
+                    // TODO: try to avoid the extra QxQ matrix in shared memory needed for this conversion
+                    nvcuda::wmma::store_matrix_sync( ss + 16*j*T + C + Q, lo[j][i], T, nvcuda::wmma::mem_row_major);
+                    nvcuda::wmma::load_matrix_sync (lob, ss + 16*j*T + C + Q, T);
+
+                    nvcuda::wmma::fill_fragment(lo[j][i], 0.0);
+                    nvcuda::wmma::mma_sync(lo[j][i], mm, lob, lo[j][i]);
                 }
             }
 
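
Same round trip as above, this time producing a matrix_b operand: lo[j][i] is staged in the extra QxQ scratch slot at ss + 16*j*T + C + Q (the reason SH grew to C + 2*Q), reloaded as lob, and diag(ms)*O becomes a single mma_sync after clearing lo[j][i]. The removed per-element loop indexed accumulator elements directly, and WMMA leaves that element-to-thread mapping unspecified, so replacing it, together with the mask now being applied, is presumably what brings the CUDA output in line with the CPU reference.
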
@@ -6732,10 +6743,11 @@ static __global__ void flash_attn_ext_f16(
                     nvcuda::wmma::fill_fragment(t2, 0.0);
                     nvcuda::wmma::load_matrix_sync(t, sq + 16*j*T + i*16, T);
                     nvcuda::wmma::mma_sync(t2, ms1, t, t2);
-                    // store temporally 'lo' data
-                    nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major);
-                    // load 'lo' data into t
-                    nvcuda::wmma::load_matrix_sync(t, sq + 16*j*T + i*16, T);
+
+                    // convert accumulator to matrix_b
+                    nvcuda::wmma::store_matrix_sync( sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major);
+                    nvcuda::wmma::load_matrix_sync (t, sq + 16*j*T + i*16, T);
+
                     nvcuda::wmma::mma_sync(lo[j][i], ms0, t, t2);
                 }
             }
@@ -10897,8 +10909,8 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
 
     GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
     GGML_ASSERT(!mask || mask->backend == GGML_BACKEND_GPU);
-    GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 8) &&
-        "the Flash-Attention CUDA kernel requires the mask to be padded to 8 and at least n_queries big");
+    GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
+        "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");
 
     ggml_cuda_set_device(g_main_device);
     const cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
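
The requirement tightens from 8 to 16 because the kernel now reads the mask in 16-wide WMMA tiles (the load_matrix_sync on mp above). A small self-check of the arithmetic, with the round-up written out explicitly rather than pulling in the GGML_PAD macro from ggml.h (which rounds up the same way):

    #include <cassert>

    // round x up to the next multiple of n (same result as GGML_PAD(x, n))
    static int pad_up(int x, int n) { return (x + n - 1) / n * n; }

    int main() {
        assert(pad_up(  7, 16) ==  16); // nb = 7 from the new test cases below
        assert(pad_up( 15, 16) ==  16); // nb = 15
        assert(pad_up( 16, 16) ==  16);
        assert(pad_up(512, 16) == 512);
        return 0;
    }
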
@@ -10914,13 +10926,17 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
 
     const int nqpb = 16; // queries per block
     const int ncpw = 32; // cache values per warp (does not work for other values)
-    // const int nwarps = Q->ne[1] <= nqpb ? MAX(4, MIN(K->ne[1]/ncpw, 32)) : 4;
-    const int nwarps = 1;
+
+    const int nwarps_max = 8; // TODO: we don't want to launch too much warps. how much is too much?
+    const int nwarps = Q->ne[1] <= nqpb ? MAX(4, MIN(K->ne[1]/ncpw, nwarps_max)) : 4;
 
     dim3 blocks_num((Q->ne[1] + nqpb - 1) / nqpb, Q->ne[2], Q->ne[3]);
     dim3 block_dim(32, nwarps, 1);
 
-    int shmem = nqpb*(Q->ne[0] + nwarps*(ncpw + nqpb))*(sizeof(float)/2);
+    // TODO: compare to Metal, here we need extra `nqpb` space in order to do the diag(ms)*O scaling
+    // try to avoid this
+    const size_t shmem = nqpb*(Q->ne[0] + nwarps*(ncpw + 2*nqpb))*(sizeof(float)/2);
+
     switch (Q->ne[0])
         {
             case 16:
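
For a sense of the footprint, a small sketch (not from the commit) that evaluates the new shmem expression, assuming a head size Q->ne[0] of 128 (the size the tests below exercise) and the minimum of 4 warps:

    #include <cstdio>

    int main() {
        const int nqpb = 16, ncpw = 32;   // queries per block, cache values per warp (as above)
        const int head = 128, nwarps = 4; // assumed head size and warp count
        // sizeof(float)/2 == 2, i.e. bytes per half
        const int shmem = nqpb*(head + nwarps*(ncpw + 2*nqpb))*2;
        printf("%d bytes\n", shmem);      // 12288; the old (ncpw + nqpb) layout gave 10240
        return 0;
    }

With nwarps at the new nwarps_max of 8 the same expression gives 20480 bytes, still well under the default 48 KiB of shared memory per block.
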

tests/test-backend-ops.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2214,7 +2214,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
22142214
for (int hs : { 128, }) {
22152215
for (int nh : { 32, }) {
22162216
for (int kv : { 512, 1024, }) {
2217-
for (int nb : { 1, 2, 4, 8, 512 }) {
2217+
for (int nb : { 1, 2, 4, 7, 8, 15, 16, 512 }) {
22182218
test_cases.emplace_back(new test_attn (hs, nh, kv, nb));
22192219
test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb));
22202220
}
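
The added batch sizes 7 and 15 cover query counts that are not multiples of 8 or 16, which exercises the tightened mask-padding assertion and checks the new kernel path against the CPU backend for both test_attn and test_flash_attn_ext.
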
