Commit 7c34655

cuda : use int instead of int64_t
Noticeably improves performance (thanks to Johannes)
1 parent b150abe
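
Why the type change helps: NVIDIA GPUs execute integer arithmetic natively at 32 bits, so 64-bit increments, compares, and index computations in a hot loop are emulated with multi-instruction sequences; narrower induction variables also give the compiler more room to unroll and to fold index math into addressing. The loop bounds in this kernel (Q, Q16, D16, C, SH, ne11) are expected to fit comfortably in an int, so the substitution is safe. A minimal sketch of the pattern, with hypothetical standalone kernels that are not part of this commit:

__global__ void scale_i64(float * x, const int64_t n, const float s) {
    // 64-bit counter: each increment, compare, and index computation
    // is emulated with multiple 32-bit instructions on the GPU
    for (int64_t i = threadIdx.x; i < n; i += blockDim.x) {
        x[i] *= s;
    }
}

__global__ void scale_i32(float * x, const int n, const float s) {
    // 32-bit counter: single-instruction index math; same result
    // as long as n fits in an int
    for (int i = threadIdx.x; i < n; i += blockDim.x) {
        x[i] *= s;
    }
}

The diff below applies exactly this substitution to each strided and blocked loop in flash_attn_ext_f16.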

1 file changed: ggml-cuda.cu (35 additions, 35 deletions)
@@ -6462,10 +6462,10 @@ static __global__ void flash_attn_ext_f16(
     half16x16_acc lo[Q16][D16];
 
     // load heads from Q to shared memory
-    for (int64_t j = warp_id; j < Q; j += num_warps) {
+    for (int j = warp_id; j < Q; j += num_warps) {
         const float2 * q2 = (const float2 *) (q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03));
 
-        for (int64_t i = lane_id; i < D2; i += NW) {
+        for (int i = lane_id; i < D2; i += NW) {
             if (iq1 + j < ne01) {
                 sq2[j*T2 + i] = __float22half2_rn(q2[i]);
             } else {
@@ -6477,15 +6477,15 @@ static __global__ void flash_attn_ext_f16(
     nvcuda::wmma::fill_fragment(zr, 0.0);
 
     // zero out lo
-    for (int64_t j = 0; j < Q16; ++j) {
-        for (int64_t i = 0; i < D16; ++i) {
+    for (int j = 0; j < Q16; ++j) {
+        for (int i = 0; i < D16; ++i) {
             nvcuda::wmma::fill_fragment(lo[j][i], 0.0);
         }
     }
 
     // zero out shared memory SH
-    for (int64_t j = 0; j < Q; ++j) {
-        for (int64_t i = lane_id; i < SH; i += NW) {
+    for (int j = 0; j < Q; ++j) {
+        for (int i = lane_id; i < SH; i += NW) {
             ss[j*T + i] = 0.0;
         }
     }
@@ -6526,8 +6526,8 @@ static __global__ void flash_attn_ext_f16(
 
     // load the queries from shared memory into local memory
     half16x16_a mq[Q16][D16];
-    for (int64_t j = 0; j < Q16; ++j) {
-        for (int64_t i = 0; i < D16; ++i) {
+    for (int j = 0; j < Q16; ++j) {
+        for (int i = 0; i < D16; ++i) {
             nvcuda::wmma::load_matrix_sync(mq[j][i], sq + 16*j*T + i*16, T);
         }
     }
@@ -6544,28 +6544,28 @@ static __global__ void flash_attn_ext_f16(
 
     // loop over the KV cache
     // each simdgroup handles blocks of Q rows and C columns
-    for (int64_t ic = C*warp_id; ic < ne11; ic += C*num_warps) {
+    for (int ic = C*warp_id; ic < ne11; ic += C*num_warps) {
         // Q*K^T
         {
             for (int cc = 0; cc < C/16; ++cc) {
                 half16x16_acc mqk[Q16];
-                for (int64_t j = 0; j < Q16; ++j) {
+                for (int j = 0; j < Q16; ++j) {
                     nvcuda::wmma::fill_fragment(mqk[j], 0);
                 }
 
                 const half * pk = (const half *) ((const char *) k + ((ic + 16*cc)*nb11 + ik2*nb12 + ik3*nb13));
 
-                for (int64_t i = 0; i < D16; ++i) {
+                for (int i = 0; i < D16; ++i) {
                     half16x16_bT mk; // transposed key
                     nvcuda::wmma::load_matrix_sync(mk, pk + i*16, nb11/sizeof(half));
 
-                    for (int64_t j = 0; j < Q16; ++j) {
+                    for (int j = 0; j < Q16; ++j) {
                         nvcuda::wmma::mma_sync(mqk[j], mq[j][i], mk, mqk[j]);
                     }
                 }
 
                 // mqk = mqk*scale + mask
-                for (int64_t j = 0; j < Q16; ++j) {
+                for (int j = 0; j < Q16; ++j) {
                     half16x16_a mqka;
                     half16x16_acc mm;
 
@@ -6588,8 +6588,8 @@ static __global__ void flash_attn_ext_f16(
 
         // online softmax
         if (C == 32) {
-            for (int64_t j = 0; j < Q; ++j) {
-                const int64_t p = lane_id;
+            for (int j = 0; j < Q; ++j) {
+                const int p = lane_id;
 
                 const half m = M[j];
                 const half s = ss[j*T + p];
@@ -6611,10 +6611,10 @@ static __global__ void flash_attn_ext_f16(
                 ss[j*T + p] = vs;
             }
         } else {
-            for (int64_t j = 0; j < Q; ++j) {
+            for (int j = 0; j < Q; ++j) {
                 const half m = M[j];
 
-                for (int64_t p = lane_id; p < C; p += NW) {
+                for (int p = lane_id; p < C; p += NW) {
                     const half s = ss[j*T + p];
 
                     smax = __hmax(smax, s);
@@ -6633,7 +6633,7 @@ static __global__ void flash_attn_ext_f16(
                 // local sum
                 half ls = 0.0f;
 
-                for (int64_t p = lane_id; p < C; p += NW) {
+                for (int p = lane_id; p < C; p += NW) {
                     const half s = ss[j*T + p];
 
                     const half vs = __hisinf(s) == -1 ? __float2half(0.0f) : hexp(s - M[j]);
@@ -6656,13 +6656,13 @@ static __global__ void flash_attn_ext_f16(
         }
 
         // O = diag(ms)*O
-        for (int64_t j = 0; j < Q16; ++j) {
+        for (int j = 0; j < Q16; ++j) {
             half16x16_a mm;
             half16x16_b lob;
 
             nvcuda::wmma::load_matrix_sync(mm, ss + 16*j*T + C + 16*j, T);
 
-            for (int64_t i = 0; i < D16; ++i) {
+            for (int i = 0; i < D16; ++i) {
                 // convert accumulator to matrix_b
                 nvcuda::wmma::store_matrix_sync( ss + 16*j*T + C + 16*j, lo[j][i], T, nvcuda::wmma::mem_row_major);
                 nvcuda::wmma::load_matrix_sync (lob, ss + 16*j*T + C + 16*j, T);
@@ -6680,17 +6680,17 @@ static __global__ void flash_attn_ext_f16(
                 const half * pv = (const half *) ((const char *) v + ((ic + 16*cc)*nb21 + iv2*nb22 + iv3*nb23));
 
                 half16x16_b mk[D16];
-                for (int64_t i = 0; i < D16; ++i) {
+                for (int i = 0; i < D16; ++i) {
                     nvcuda::wmma::load_matrix_sync(mk[i], pv + i*16, nb21/sizeof(half));
                 }
 
                 half16x16_a mv[Q16];
-                for (int64_t j = 0; j < Q16; ++j) {
+                for (int j = 0; j < Q16; ++j) {
                     nvcuda::wmma::load_matrix_sync(mv[j], ss + 16*j*T + 16*cc, T);
                 }
 
-                for (int64_t j = 0; j < Q16; ++j) {
-                    for (int64_t i = 0; i < D16; ++i) {
+                for (int j = 0; j < Q16; ++j) {
+                    for (int i = 0; i < D16; ++i) {
                         nvcuda::wmma::mma_sync(lo[j][i], mv[j], mk[i], lo[j][i]);
                     }
                 }
@@ -6699,7 +6699,7 @@ static __global__ void flash_attn_ext_f16(
     }
 
     // these are needed for reducing the results from the simdgroups (reuse the ss buffer)
-    for (int64_t j = 0; j < Q; ++j) {
+    for (int j = 0; j < Q; ++j) {
         if (lane_id == 0) {
             ss[j*T + 0] = S[j];
             ss[j*T + 1] = M[j];
@@ -6708,16 +6708,16 @@ static __global__ void flash_attn_ext_f16(
     }
 
     // reduce the warps sequentially
-    for (int64_t sg = 1; sg < num_warps; ++sg) {
+    for (int sg = 1; sg < num_warps; ++sg) {
         half S = __float2half(0.0f);
         half M = __float2half(-INFINITY);
 
         __syncthreads();
 
         // each simdgroup stores its output to shared memory, reusing sq
         if (warp_id == sg) {
-            for (int64_t j = 0; j < Q16; ++j) {
-                for (int64_t i = 0; i < D16; ++i) {
+            for (int j = 0; j < Q16; ++j) {
+                for (int i = 0; i < D16; ++i) {
                     nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major);
                 }
             }
@@ -6727,7 +6727,7 @@ static __global__ void flash_attn_ext_f16(
 
         // the first simdgroup accumulates the results from the other simdgroups
        if (warp_id == 0) {
-            for (int64_t j = 0; j < Q; ++j) {
+            for (int j = 0; j < Q; ++j) {
                 const half S0 = ss[j*T + 0];
                 const half S1 = ss[j*T + sg*SH + 0];
 
@@ -6751,7 +6751,7 @@ static __global__ void flash_attn_ext_f16(
             }
 
             // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
-            for (int64_t j = 0; j < Q16; ++j) {
+            for (int j = 0; j < Q16; ++j) {
                 half16x16_a ms0;
                 half16x16_a ms1;
                 half16x16_b t;
@@ -6760,7 +6760,7 @@ static __global__ void flash_attn_ext_f16(
                 nvcuda::wmma::load_matrix_sync(ms0, ss + 16*j*T + C + 16*j, T);
                 nvcuda::wmma::load_matrix_sync(ms1, ss + 16*j*T + C + 16*j + sg*SH, T);
 
-                for (int64_t i = 0; i < D16; ++i) {
+                for (int i = 0; i < D16; ++i) {
                     nvcuda::wmma::load_matrix_sync(t, sq + 16*j*T + i*16, T);
                     nvcuda::wmma::mma_sync(t2, ms1, t, zr);
 
@@ -6776,19 +6776,19 @@ static __global__ void flash_attn_ext_f16(
 
     // store result to shared memory (reuse sq)
    if (warp_id == 0) {
-        for (int64_t j = 0; j < Q16; ++j) {
-            for (int64_t i = 0; i < D16; ++i) {
+        for (int j = 0; j < Q16; ++j) {
+            for (int i = 0; i < D16; ++i) {
                 nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major);
             }
         }
    }
 
     // final rescale with 1/S and store to global memory
    if (warp_id == 0) {
-        for (int64_t j = 0; j < Q && iq1 + j < ne01; ++j) {
+        for (int j = 0; j < Q && iq1 + j < ne01; ++j) {
             const half S = ss[j*T + 0];
 
-            for (int64_t i = lane_id; i < D; i += NW) {
+            for (int i = lane_id; i < D; i += NW) {
                 dst[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D + i] = __half2float(sq[j*T + i] / S);
             }
         }
    }
