Commit be0a0f8
vulkan: Implement grouped query attention in the coopmat2 FA shader (#12559)
When adjacent batches of Q share the same batches of K/V, batch them into the same workgroup. For example, for dst(128,32,1,1) = FA(q(128,1,32,1), k(128,16640,8,1), v(128,16640,8,1)), we previously ran 32 workgroups computing 1 result each; now we run 8 workgroups computing 4 results each. This doesn't directly translate to better performance (at least when you have >=32 SMs), but in a subsequent change I'll enable split_k, which will scale much better with 4x fewer workgroups.
1 parent 92e3006 commit be0a0f8
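
To make the workgroup math concrete, here is a minimal standalone C++ sketch (illustrative only, not the ggml code; it applies the shapes from the example above to a trimmed-down version of the condition added below):

#include <cstdint>
#include <cstdio>

int main() {
    // Shapes from the example: q(128,1,32,1), k/v(128,16640,8,1).
    uint32_t neq1 = 1, neq2 = 32, nek2 = 8;

    uint32_t N = neq1;                 // rows of Q per K/V batch
    uint32_t gqa_ratio = 1;
    uint32_t qk_ratio = neq2 / nek2;   // 32 / 8 = 4 Q heads per K/V head
    uint32_t workgroups_y = neq2;      // previously 32 workgroups, 1 result each

    bool pow2 = qk_ratio > 1 && (qk_ratio & (qk_ratio - 1)) == 0;
    if (N == 1 && pow2 && qk_ratio * nek2 == neq2) {
        gqa_ratio = qk_ratio;
        N = gqa_ratio;                 // each workgroup now computes 4 results
        workgroups_y /= N;             // 32 / 4 = 8 workgroups
    }
    printf("gqa_ratio=%u workgroups_y=%u\n", gqa_ratio, workgroups_y); // 4, 8
    return 0;
}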

2 files changed, +71 −20 lines

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 22 additions & 3 deletions
@@ -31,6 +31,7 @@
 
 #define ROUNDUP_POW2(M, N) (((M) + (N) - 1) & ~((N) - 1))
 #define CEIL_DIV(M, N) (((M) + (N)-1) / (N))
+static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; }
 
 #define VK_VENDOR_ID_AMD 0x1002
 #define VK_VENDOR_ID_APPLE 0x106b
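
As a quick sanity check on the new helper, this standalone snippet (hypothetical test code, not part of the commit) shows that is_pow2 deliberately rejects 1, since a qk_ratio of 1 means there is no grouping to exploit:

#include <cassert>
#include <cstdint>

// Same definition as the one added above.
static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; }

int main() {
    assert(!is_pow2(0) && !is_pow2(1) && !is_pow2(6));  // 1 is intentionally excluded
    assert(is_pow2(2) && is_pow2(4) && is_pow2(64));
    return 0;
}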
@@ -501,6 +502,8 @@ struct vk_flash_attn_push_constants {
     uint32_t n_head_log2;
     float m0;
     float m1;
+
+    uint32_t gqa_ratio;
 };
 
 struct vk_op_push_constants {
@@ -5402,7 +5405,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     const uint32_t nbm1 = mask ? mask->nb[1] : 0;
 
     const uint32_t D = neq0;
-    const uint32_t N = neq1;
+    uint32_t N = neq1;
     const uint32_t KV = nek1;
 
     GGML_ASSERT(ne0 == D);
@@ -5460,6 +5463,22 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     vk_pipeline pipeline = pipelines[aligned];
     assert(pipeline);
 
+    uint32_t gqa_ratio = 1;
+    uint32_t qk_ratio = neq2 / nek2;
+    uint32_t workgroups_x = (uint32_t)neq1;
+    uint32_t workgroups_y = (uint32_t)neq2;
+    uint32_t workgroups_z = (uint32_t)neq3;
+
+    if (N == 1 && qk_ratio > 1 && is_pow2(qk_ratio) && gqa_ratio <= flash_attention_num_small_rows &&
+        qk_ratio * nek2 == neq2 && nek2 == nev2 && neq3 == 1 && nek3 == 1 && nev3 == 1) {
+        // grouped query attention - make the N dimension equal to gqa_ratio, reduce
+        // workgroups proportionally in y dimension. The shader will detect gqa_ratio > 1
+        // and change addressing calculations to index Q's dimension 2.
+        gqa_ratio = qk_ratio;
+        N = gqa_ratio;
+        workgroups_y /= N;
+    }
+
     if (dryrun) {
         // Request descriptor sets
         ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1);
@@ -5549,7 +5568,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
         v_stride, (uint32_t)nbv2, (uint32_t)nbv3,
         nbm1,
         scale, max_bias, logit_softcap,
-        mask != nullptr, n_head_log2, m0, m1 };
+        mask != nullptr, n_head_log2, m0, m1, gqa_ratio };
     ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
         {
             vk_subbuffer{d_Q, q_buf_offset, VK_WHOLE_SIZE},
@@ -5558,7 +5577,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
             vk_subbuffer{d_M, m_buf_offset, VK_WHOLE_SIZE},
             vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
         },
-        sizeof(vk_flash_attn_push_constants), &pc, { (uint32_t)neq1, (uint32_t)neq2, (uint32_t)neq3 });
+        sizeof(vk_flash_attn_push_constants), &pc, { workgroups_x, workgroups_y, workgroups_z });
 }
 
 static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op) {

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp

Lines changed: 49 additions & 17 deletions
@@ -61,6 +61,8 @@ layout (push_constant) uniform parameter {
     uint32_t n_head_log2;
     float m0;
     float m1;
+
+    uint32_t gqa_ratio;
 } p;
 
 layout (binding = 0) readonly buffer Q {uint8_t data_q[];};
@@ -103,6 +105,28 @@ ACC_TYPE Max(const in uint32_t row, const in uint32_t col, const in ACC_TYPE ele
 #define DECODEFUNC
 #endif
 
+// Store the output when doing grouped query attention.
+// Rows index by Q's dimension 2, and the first N rows are valid.
+D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
+{
+    if (r < N && c < D) {
+        uint32_t offset = (iq2 + r) * D + c;
+        data_o[o_offset + offset] = D_TYPE(elem);
+    }
+    return elem;
+}
+
+// Load the slope matrix, indexed by Q's dimension 2.
+ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2)
+{
+    const uint32_t h = iq2 + (r & (p.gqa_ratio - 1));
+
+    const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1);
+    const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1);
+
+    return ACC_TYPE(pow(base, ACC_TYPE(exph)));
+}
+
 void main() {
 #ifdef NEEDS_INIT_IQ_SHMEM
     init_iq_shmem(gl_WorkGroupSize);
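
For reference, this plain C++ sketch (illustrative values; not part of the commit) emulates what coopMatPerElementNV does with perElemOpComputeSlope above: every element of a Br x Bc slope matrix receives the ALiBi slope of head iq2 + (r & (gqa_ratio - 1)), where m0/m1 are assumed to follow the usual ggml ALiBi parameterization:

#include <cmath>
#include <cstdint>
#include <cstdio>

int main() {
    // Illustrative values only: Br x Bc tile, 4 grouped heads starting at iq2 = 8.
    const uint32_t Br = 4, Bc = 2, gqa_ratio = 4, iq2 = 8, n_head_log2 = 8;
    const float m0 = 0.5f;                   // 2^(-max_bias / n_head_log2) for max_bias = 8
    const float m1 = std::pow(2.0f, -0.5f);  // 2^(-max_bias/2 / n_head_log2)
    float slopeMat[Br][Bc];
    for (uint32_t r = 0; r < Br; r++) {
        for (uint32_t c = 0; c < Bc; c++) {
            const uint32_t h    = iq2 + (r & (gqa_ratio - 1));   // head for this row
            const float    base = h < n_head_log2 ? m0 : m1;
            const int      exph = h < n_head_log2 ? int(h) + 1 : 2 * int(h - n_head_log2) + 1;
            slopeMat[r][c] = std::pow(base, float(exph));        // same formula as the shader
        }
    }
    printf("row 0 slope (head %u): %f\n", iq2, slopeMat[0][0]);
    return 0;
}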
@@ -116,7 +140,9 @@ void main() {
 
     const uint32_t i = gl_WorkGroupID.x;
 
-    const uint32_t iq2 = gl_WorkGroupID.y;
+    // When not using grouped query attention, all rows share the same iq2, equal to gl_WorkGroupID.y.
+    // When using grouped query attention, each workgroup does gqa_ratio consecutive values of iq2.
+    const uint32_t iq2 = gl_WorkGroupID.y * p.gqa_ratio;
     const uint32_t iq3 = gl_WorkGroupID.z;
 
     // broadcast factors
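
A worked example of this workgroup-to-head mapping (hypothetical shapes: 32 Q heads, 8 K/V heads, so gqa_ratio = 4 and 8 workgroups in y):

#include <cstdint>
#include <cstdio>

int main() {
    const uint32_t gqa_ratio = 4, workgroups_y = 8;
    for (uint32_t wg_y = 0; wg_y < workgroups_y; wg_y++) {
        const uint32_t iq2 = wg_y * gqa_ratio;   // first Q head handled by this workgroup
        printf("workgroup y=%u: Q heads %u..%u share K/V head %u\n",
               wg_y, iq2, iq2 + gqa_ratio - 1, wg_y);
    }
    return 0;
}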
@@ -149,8 +175,10 @@ void main() {
     tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, D);
     tensorLayoutV = setTensorLayoutDimensionNV(tensorLayoutV, KV, D);
 
-    // nb?1 are already divided by the type size and are in units of elements
-    uint32_t q_stride = p.nb01;
+    // nb?1 are already divided by the type size and are in units of elements.
+    // When using grouped query attention, Q is indexed by iq2, so the stride
+    // should be nb02 (which is in bytes).
+    uint32_t q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01;
     uint32_t k_stride = p.nb11;
     uint32_t v_stride = p.nb21;
     // hint to the compiler that strides are aligned for the aligned variant of the shader
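
A small sketch of the stride units involved (illustrative, and assuming 4-byte Q elements, which is why the shader divides nb02 by 4):

#include <cstdint>
#include <cstdio>

int main() {
    // Hypothetical contiguous f32 Q of shape (D=128, N=1 row, 32 heads).
    const uint32_t D = 128, N = 1, gqa_ratio = 4;
    const uint32_t nb01 = D;          // row stride, already converted to elements
    const uint32_t nb02 = N * D * 4;  // head stride, still in bytes
    // Same selection as the shader: step between heads when batching
    // gqa_ratio heads per workgroup, else step between rows.
    const uint32_t q_stride = gqa_ratio > 1 ? nb02 / 4 : nb01;
    printf("q_stride = %u elements\n", q_stride);
    return 0;
}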
@@ -182,16 +210,11 @@ void main() {
     L = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
     M = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(-1.0/0.0);
 
-    ACC_TYPE slope = ACC_TYPE(1.0);
+    coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> slopeMat = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(1.0);
 
     // ALiBi
     if (p.max_bias > 0.0f) {
-        const uint32_t h = iq2;
-
-        const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1);
-        const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1);
-
-        slope = pow(base, ACC_TYPE(exph));
+        coopMatPerElementNV(slopeMat, slopeMat, perElemOpComputeSlope, iq2);
     }
 
     [[dont_unroll]]
@@ -215,12 +238,16 @@ void main() {
         if (p.mask != 0) {
             tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutM = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
             tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
+            // When using grouped query attention, all rows use the same mask.
+            if (p.gqa_ratio > 1) {
+                tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, 0, 1);
+            }
 
             coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
 
             coopMatLoadTensorNV(mv, data_m, 0, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
 
-            S += slope*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
+            S += slopeMat*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
         }
 
         // Clear padding elements to -inf, so they don't contribute to rowmax
@@ -297,13 +324,18 @@ void main() {
 
     O = Ldiag*O;
 
-    tensorLayoutNV<3, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(3, gl_CooperativeMatrixClampModeConstantNV);
-    tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.ne2, p.ne1, D);
-
-    // permute dimensions
-    tensorViewNV<3, false, 1, 0, 2> tensorViewPermute = createTensorViewNV(3, false, 1, 0, 2);
     uint32_t o_offset = iq3*p.ne2*p.ne1;
 
     coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(O);
-    coopMatStoreTensorNV(O_D, data_o, o_offset, sliceTensorLayoutNV(tensorLayoutD, i * Br, Br, iq2, 1, 0, D), tensorViewPermute);
+    if (p.gqa_ratio > 1) {
+        coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);
+    } else {
+        tensorLayoutNV<3, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(3, gl_CooperativeMatrixClampModeConstantNV);
+        tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.ne2, p.ne1, D);
+
+        // permute dimensions
+        tensorViewNV<3, false, 1, 0, 2> tensorViewPermute = createTensorViewNV(3, false, 1, 0, 2);
+
+        coopMatStoreTensorNV(O_D, data_o, o_offset, sliceTensorLayoutNV(tensorLayoutD, i * Br, Br, iq2, N, 0, D), tensorViewPermute);
+    }
 }
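
Finally, a sketch of the flattened addressing perElemOpGqaStore uses (illustrative values only): coopmat row r lands at output head iq2 + r, so a workgroup with gqa_ratio = 4 writes 4 consecutive heads.

#include <cstdint>
#include <cstdio>

int main() {
    // Illustrative values: D = 128, 4 valid rows, first head iq2 = 8.
    const uint32_t D = 128, N = 4, iq2 = 8, o_offset = 0;
    for (uint32_t r = 0; r < N; r++) {
        const uint32_t offset = (iq2 + r) * D;   // column c = 0, as in the shader formula
        printf("coopmat row %u -> data_o[%u..%u]\n",
               r, o_offset + offset, o_offset + offset + D - 1);
    }
    return 0;
}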
