Skip to content

Commit 1f8a592

Browse files
committed
cuda : make loops use the same loop values
Thanks Johannes again for the tip
1 parent 7c34655 commit 1f8a592

File tree

2 files changed

+37 -8 lines changed

ggml-cuda.cu

Lines changed: 36 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -6462,10 +6462,20 @@ static __global__ void flash_attn_ext_f16(
64626462
half16x16_acc lo[Q16][D16];
64636463

64646464
// load heads from Q to shared memory
6465-
for (int j = warp_id; j < Q; j += num_warps) {
6465+
for (int j0 = 0; j0 < Q; j0 += num_warps) {
6466+
const int j = j0 + warp_id;
6467+
if (j >= Q) {
6468+
break;
6469+
}
6470+
64666471
const float2 * q2 = (const float2 *) (q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03));
64676472

6468-
for (int i = lane_id; i < D2; i += NW) {
6473+
for (int i0 = 0; i0 < D2; i0 += NW) {
6474+
const int i = i0 + lane_id;
6475+
if (i >= D2) {
6476+
break;
6477+
}
6478+
64696479
if (iq1 + j < ne01) {
64706480
sq2[j*T2 + i] = __float22half2_rn(q2[i]);
64716481
} else {
@@ -6485,7 +6495,12 @@ static __global__ void flash_attn_ext_f16(
64856495

64866496
// zero out shared memory SH
64876497
for (int j = 0; j < Q; ++j) {
6488-
for (int i = lane_id; i < SH; i += NW) {
6498+
for (int i0 = 0; i0 < SH; i0 += NW) {
6499+
const int i = i0 + lane_id;
6500+
if (i >= SH) {
6501+
break;
6502+
}
6503+
64896504
ss[j*T + i] = 0.0;
64906505
}
64916506
}
@@ -6544,7 +6559,12 @@ static __global__ void flash_attn_ext_f16(
65446559

65456560
// loop over the KV cache
65466561
// each simdgroup handles blocks of Q rows and C columns
6547-
for (int ic = C*warp_id; ic < ne11; ic += C*num_warps) {
6562+
for (int ic0 = 0; ic0 < ne11; ic0 += C*num_warps) {
6563+
const int ic = ic0 + warp_id*C;
6564+
if (ic >= ne11) {
6565+
break;
6566+
}
6567+
65486568
// Q*K^T
65496569
{
65506570
for (int cc = 0; cc < C/16; ++cc) {
@@ -6614,7 +6634,9 @@ static __global__ void flash_attn_ext_f16(
66146634
for (int j = 0; j < Q; ++j) {
66156635
const half m = M[j];
66166636

6617-
for (int p = lane_id; p < C; p += NW) {
6637+
for (int p0 = 0; p0 < C; p0 += NW) {
6638+
const int p = p0 + lane_id;
6639+
66186640
const half s = ss[j*T + p];
66196641

66206642
smax = __hmax(smax, s);
@@ -6633,7 +6655,9 @@ static __global__ void flash_attn_ext_f16(
66336655
// local sum
66346656
half ls = 0.0f;
66356657

6636-
for (int p = lane_id; p < C; p += NW) {
6658+
for (int p0 = 0; p0 < C; p0 += NW) {
6659+
const int p = p0 + lane_id;
6660+
66376661
const half s = ss[j*T + p];
66386662

66396663
const half vs = __hisinf(s) == -1 ? __float2half(0.0f) : hexp(s - M[j]);
@@ -6788,7 +6812,12 @@ static __global__ void flash_attn_ext_f16(
67886812
for (int j = 0; j < Q && iq1 + j < ne01; ++j) {
67896813
const half S = ss[j*T + 0];
67906814

6791-
for (int i = lane_id; i < D; i += NW) {
6815+
for (int i0 = 0; i0 < D; i0 += NW) {
6816+
const int i = i0 + lane_id;
6817+
if (i >= D) {
6818+
break;
6819+
}
6820+
67926821
dst[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D + i] = __half2float(sq[j*T + i] / S);
67936822
}
67946823
}

tests/test-backend-ops.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2210,7 +2210,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
22102210
test_cases.emplace_back(new test_leaky_relu());
22112211

22122212
#if 1
2213-
for (int hs : { 64, 80, 128, }) {
2213+
for (int hs : { 128, 64, 80, }) {
22142214
for (int nh : { 32, }) {
22152215
for (int kv : { 512, 1024, 2048, 4096, }) {
22162216
for (int nb : { 1, 2, 4, 8, 512, 1024, 2048, }) {

0 commit comments

Comments
 (0)