
Commit 891c639

vulkan: Pad N dimension of B matrix for coopmat2 perf, to avoid bounds checking (#12273)

1 parent 2f21123 commit 891c639

File tree

2 files changed: 34 additions, 22 deletions

ggml/src/ggml-vulkan/ggml-vulkan.cpp
ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 26 additions & 17 deletions
```diff
@@ -29,6 +29,7 @@
 
 #include "ggml-vulkan-shaders.hpp"
 
+#define ROUNDUP_POW2(M, N) (((M) + (N) - 1) & ~((N) - 1))
 #define CEIL_DIV(M, N) (((M) + (N)-1) / (N))
 
 #define VK_VENDOR_ID_AMD 0x1002
```
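For reference, the new `ROUNDUP_POW2(M, N)` macro rounds `M` up to the next multiple of `N`, where `N` must be a power of two. A minimal standalone sketch of its behavior (the `main` harness is illustrative, not part of the commit):

```c++
#include <cassert>

// Adding N - 1 and masking with ~(N - 1) clears the low bits,
// rounding M up to the next multiple of N (N must be a power of two).
#define ROUNDUP_POW2(M, N) (((M) + (N) - 1) & ~((N) - 1))

int main() {
    assert(ROUNDUP_POW2(13, 8) == 16);  // rounds up to the next multiple of 8
    assert(ROUNDUP_POW2(16, 8) == 16);  // already-aligned values are unchanged
    assert(ROUNDUP_POW2(1, 32) == 32);
    return 0;
}
```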
```diff
@@ -368,6 +369,7 @@ struct vk_mat_mat_push_constants {
     uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d;
     uint32_t k_split;
     uint32_t ne02; uint32_t ne12; uint32_t broadcast2; uint32_t broadcast3;
+    uint32_t padded_N;
 };
 struct vk_mat_vec_push_constants {
     uint32_t ncols; uint32_t stride_a; uint32_t stride_b; uint32_t stride_d;
```
```diff
@@ -380,6 +382,7 @@ struct vk_mat_mat_id_push_constants {
     uint32_t stride_a; uint32_t stride_b; uint32_t stride_d;
     uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d;
     uint32_t nei0; uint32_t nei1; uint32_t nbi1; uint32_t ne11;
+    uint32_t padded_N;
 };
 struct vk_mat_vec_id_push_constants {
     uint32_t ncols; uint32_t stride_a; uint32_t stride_b; uint32_t stride_d;
```
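Both push-constant structs gain `padded_N` as their final member, mirroring the `uint padded_N` appended to the shader's `push_constant` block in `mul_mm_cm2.comp` below; the host struct and the shader block must stay member-for-member identical for the constants to line up. An abridged sketch of the host side after this change (the member names ahead of `batch_stride_a` are assumed from the initializers later in the diff):

```c++
#include <cstdint>

// Abridged host-side push constants; the member order must mirror the
// GLSL push_constant block exactly, so padded_N is appended at the end.
struct vk_mat_mat_push_constants {
    uint32_t M; uint32_t N; uint32_t K;
    uint32_t stride_a; uint32_t stride_b; uint32_t stride_d;
    uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d;
    uint32_t k_split;
    uint32_t ne02; uint32_t ne12; uint32_t broadcast2; uint32_t broadcast3;
    uint32_t padded_N;
};

// All members are uint32_t, so the struct is tightly packed.
static_assert(sizeof(vk_mat_mat_push_constants) == 15 * sizeof(uint32_t),
              "no implicit padding expected");
```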
```diff
@@ -3882,18 +3885,19 @@ static void ggml_vk_matmul(
         vk_subbuffer&& a, vk_subbuffer&& b, vk_subbuffer&& d, vk_subbuffer&& split_k_buffer,
         uint32_t m, uint32_t n, uint32_t k, uint32_t stride_a, uint32_t stride_b, uint32_t stride_d,
         uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d,
-        uint32_t split_k, uint32_t batch, uint32_t ne02, uint32_t ne12, uint32_t broadcast2, uint32_t broadcast3) {
+        uint32_t split_k, uint32_t batch, uint32_t ne02, uint32_t ne12, uint32_t broadcast2, uint32_t broadcast3,
+        uint32_t padded_n) {
     VK_LOG_DEBUG("ggml_vk_matmul(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), split_k: (" << (split_k_buffer.buffer != nullptr ? split_k_buffer.buffer->buffer : VK_NULL_HANDLE) << ", " << split_k_buffer.offset << ", " << split_k_buffer.size << "), m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", split_k: " << split_k << ", batch: " << batch << ", ne02: " << ne02 << ", ne12: " << ne12 << ", broadcast2: " << broadcast2 << ", broadcast3: " << broadcast3 << ")");
     ggml_vk_sync_buffers(subctx);
     if (split_k == 1) {
-        const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k, ne02, ne12, broadcast2, broadcast3 };
+        const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k, ne02, ne12, broadcast2, broadcast3, padded_n };
         ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d }, sizeof(vk_mat_mat_push_constants), &pc, { m, n, batch });
         return;
     }
 
     GGML_ASSERT(batch_stride_d == m * n);
 
-    const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, CEIL_DIV(k, split_k), ne02, ne12, broadcast2, broadcast3 };
+    const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, CEIL_DIV(k, split_k), ne02, ne12, broadcast2, broadcast3, padded_n };
     // Make sure enough workgroups get assigned for split k to work
     ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, split_k_buffer }, sizeof(vk_mat_mat_push_constants), &pc1, { (CEIL_DIV(m, pipeline->wg_denoms[0]) * pipeline->wg_denoms[0]) * split_k, n, batch });
     ggml_vk_sync_buffers(subctx);
```
```diff
@@ -3937,14 +3941,15 @@ static void ggml_vk_matmul_id(
         vk_subbuffer&& a, vk_subbuffer&& b, vk_subbuffer&& d, vk_subbuffer&& ids,
         uint32_t m, uint32_t n, uint32_t k, uint32_t stride_a, uint32_t stride_b, uint32_t stride_d,
         uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d,
-        uint32_t n_as, uint32_t nei0, uint32_t nei1, uint32_t nbi1, uint32_t ne11) {
+        uint32_t n_as, uint32_t nei0, uint32_t nei1, uint32_t nbi1, uint32_t ne11,
+        uint32_t padded_n) {
     VK_LOG_DEBUG("ggml_vk_matmul_id(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), ids: (" << ids.buffer->buffer << ", " << ids.offset << ", " << ids.size << "), " <<
         "m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", " <<
         "batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", " <<
         "n_as: " << n_as << ", nei0: " << nei0 << ", nei1: " << nei1 << ", nbi1: " << nbi1 << ", ne11: " << ne11 << ")");
     ggml_vk_sync_buffers(subctx);
     const vk_mat_mat_id_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d,
-                                              nei0, nei1, nbi1, ne11 };
+                                              nei0, nei1, nbi1, ne11, padded_n };
     ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d, ids }, sizeof(vk_mat_mat_id_push_constants), &pc, { m, nei1, n_as });
 }
 
```
```diff
@@ -4106,15 +4111,17 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
     // Not implemented
     GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
 
-    const int x_ne = ne01 * ne00;
-    const int y_ne = ne11 * ne10;
-    const int d_ne = ne11 * ne01;
-
     const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11, qx_needs_dequant ? GGML_TYPE_F16 : src0->type));
     const bool aligned = ne10 == kpad && ne01 > 8 && ne11 > 8;
 
     vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, qx_needs_dequant ? GGML_TYPE_F16 : src0->type);
 
+    // Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking
+    uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) : ne11;
+    const int x_ne = ne01 * ne00;
+    const int y_ne = padded_n * ne10;
+    const int d_ne = ne11 * ne01;
+
     const uint32_t split_k = ggml_vk_guess_split_k(ctx, ne01, ne11, ne10, pipeline);
 
     const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
```
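Note that padding is applied only when `qy_needs_dequant`: in that case the backend itself writes Y into a scratch buffer whose size it controls via `y_ne`, so the padded rows are guaranteed to exist; a caller-provided Y buffer cannot be assumed to have them. A worked example of the computation above, with assumed values:

```c++
#include <cstdint>
#include <cstdio>

#define ROUNDUP_POW2(M, N) (((M) + (N) - 1) & ~((N) - 1))

int main() {
    // Assumed example: K = 4096, a 7-row B matrix, and a pipeline whose
    // workgroup tile is 16 wide in N (i.e. pipeline->wg_denoms[1] == 16).
    const uint32_t ne10 = 4096, ne11 = 7, wg_denom_n = 16;
    const bool qy_needs_dequant = true;

    const uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, wg_denom_n) : ne11;
    const uint32_t y_ne = padded_n * ne10;  // scratch buffer sized for the padded N

    printf("padded_n = %u, y_ne = %u\n", padded_n, y_ne);  // prints 16 and 65536
    return 0;
}
```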
```diff
@@ -4237,7 +4244,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
         { d_D, d_buf_offset, d_sz * ne12 * ne13 }, { ctx->prealloc_split_k, 0, d_sz * ne12 * ne13 * split_k },
         ne01, ne11, ne10,
         ne10, ne10, ne01, stride_batch_x, stride_batch_y, ne20*ne21,
-        split_k, ne12*ne13, ne02, ne12, r2, r3
+        split_k, ne12*ne13, ne02, ne12, r2, r3, padded_n
     ); // NOLINT
 }
 
```
```diff
@@ -4688,15 +4695,17 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
     // Not implemented
     GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
 
-    const uint64_t x_ne = ne01 * ne00;
-    const uint64_t y_ne = ne11 * ne10;
-    const uint64_t d_ne = ne21 * ne20;
-
     const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1, qx_needs_dequant ? GGML_TYPE_F16 : src0->type));
     const bool aligned = ne10 == kpad && ne01 > 8 && nei1 > 8;
 
     vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned, qx_needs_dequant ? GGML_TYPE_F16 : src0->type);
 
+    // Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking
+    uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) : ne11;
+    const uint64_t x_ne = ne01 * ne00;
+    const uint64_t y_ne = padded_n * ne10;
+    const uint64_t d_ne = ne21 * ne20;
+
     const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
     const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
     const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne;
```
```diff
@@ -4815,7 +4824,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
         { d_D, d_buf_offset, d_sz * ne22 * ne23 }, { d_ids, ids_buf_offset, ids_sz },
         ne01, ne21, ne10, ne10, ne10, ne01,
         stride_batch_x, stride_batch_y, ne20*ne21,
-        n_as, nei0, nei1, nbi1 / ggml_type_size(ids->type), ne11
+        n_as, nei0, nei1, nbi1 / ggml_type_size(ids->type), ne11, padded_n
     ); // NOLINT
 }
 
```
```diff
@@ -6775,7 +6784,7 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
             ctx, subctx, p, ggml_vk_subbuffer(d_X), ggml_vk_subbuffer(d_Y), ggml_vk_subbuffer(d_D), ggml_vk_subbuffer(ctx->prealloc_split_k),
             m, n, k,
             k, k, m, k*m, k*n, m*n,
-            split_k, batch, batch, batch, 1, 1
+            split_k, batch, batch, batch, 1, 1, n
         );
     }
     ggml_vk_ctx_end(subctx);
```
```diff
@@ -7120,7 +7129,7 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
             ctx, subctx, p, ggml_vk_subbuffer(qx_buf), ggml_vk_subbuffer(y_buf), ggml_vk_subbuffer(d_buf), ggml_vk_subbuffer(ctx->prealloc_split_k),
             m, n, k,
             k, k, m, k*m, k*n, m*n,
-            split_k, batch, batch, batch, 1, 1
+            split_k, batch, batch, batch, 1, 1, n
         );
     }
     ggml_vk_ctx_end(subctx);
```

ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp

Lines changed: 8 additions & 5 deletions
```diff
@@ -48,6 +48,8 @@ layout (push_constant) uniform parameter
     uint broadcast2;
     uint broadcast3;
 #endif
+    // N dimension for the B matrix can be >= p.N
+    uint padded_N;
 } p;
 
 
```
```diff
@@ -202,18 +204,19 @@ void main() {
 #endif
 
     // Use end_k rather than p.K as the dimension because that's what
-    // we need to bound check against when using split_k
+    // we need to bound check against when using split_k.
+    // Bounds check B against padded_N, but bounds check D against N.
     tensorLayoutA = setTensorLayoutDimensionNV(tensorLayoutA, p.M, end_k);
-    tensorLayoutB = setTensorLayoutDimensionNV(tensorLayoutB, p.N, end_k);
+    tensorLayoutB = setTensorLayoutDimensionNV(tensorLayoutB, p.padded_N, end_k);
     tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.N, p.M);
     tensorLayoutAClamp = setTensorLayoutDimensionNV(tensorLayoutAClamp, p.M, end_k);
-    tensorLayoutBClamp = setTensorLayoutDimensionNV(tensorLayoutBClamp, p.N, end_k);
+    tensorLayoutBClamp = setTensorLayoutDimensionNV(tensorLayoutBClamp, p.padded_N, end_k);
 
     tensorViewNV<2, false, 1, 0> tensorViewTranspose = createTensorViewNV(2, false, 1, 0);
 
 #if !defined(MUL_MAT_ID)
     // Detect a fast path where all loads are entirely in bounds and no clamping is required
-    if ((ir + 1) * BM <= p.M && (ic + 1) * BN <= p.N && (start_k % BK) == 0 && (end_k % BK) == 0 &&
+    if ((ir + 1) * BM <= p.M && (ic + 1) * BN <= p.padded_N && (start_k % BK) == 0 && (end_k % BK) == 0 &&
 #if QUANT_K == 1
         (stride_a % 8) == 0 &&
 #endif
```
```diff
@@ -263,7 +266,7 @@ void main() {
 #ifdef MUL_MAT_ID
         bool unclampedB = true;
 #else
-        bool unclampedB = (ic + 1) * BN <= p.N && block_k + BK <= end_k && (block_k % 8) == 0;
+        bool unclampedB = (ic + 1) * BN <= p.padded_N && block_k + BK <= end_k && (block_k % 8) == 0;
 #endif
         if (unclampedA && unclampedB) {
             coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, (block_k & ~7), BK) DECODEFUNCA);
```
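The net effect of the shader change: tiles of B are bounds-checked against `padded_N`, so a full `BN`-wide tile never needs element-level clamping, while stores to D are still clipped to the true `N`. A scalar C++ sketch of that idea (illustrative only; the shader does this with cooperative-matrix tiles):

```c++
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
    // B is padded from N = 7 to padded_N = 8 columns; the padded column holds
    // defined-but-meaningless data, so a full BN-wide tile can load unclamped.
    const uint32_t M = 4, N = 7, padded_N = 8, K = 4, BN = 8;
    std::vector<float> A(M * K, 1.0f), B(padded_N * K, 2.0f), D(M * N, 0.0f);

    for (uint32_t ic = 0; ic < padded_N / BN; ++ic) {         // tile over columns of B
        for (uint32_t j = ic * BN; j < (ic + 1) * BN; ++j) {  // no per-element clamp on loads
            for (uint32_t i = 0; i < M; ++i) {
                float acc = 0.0f;
                for (uint32_t k = 0; k < K; ++k) {
                    acc += A[i * K + k] * B[j * K + k];
                }
                if (j < N) {                                  // stores clipped to the true N
                    D[i * N + j] = acc;
                }
            }
        }
    }
    assert(D[0] == 8.0f);  // each kept output: sum over K = 4 of 1.0f * 2.0f
    return 0;
}
```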
