#include "ggml-vulkan-shaders.hpp"

+#define ROUNDUP_POW2(M, N) (((M) + (N) - 1) & ~((N) - 1))
#define CEIL_DIV(M, N) (((M) + (N)-1) / (N))

#define VK_VENDOR_ID_AMD 0x1002
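For readers skimming the diff, here is a small standalone sketch (not part of the patch) of what the new ROUNDUP_POW2 macro does next to the existing CEIL_DIV; it assumes the second argument is a power of two, which holds for the wg_denoms values it is applied to later in this change.

// Standalone sketch, not part of the patch: compare ROUNDUP_POW2 and CEIL_DIV.
#include <cassert>
#include <cstdint>

#define ROUNDUP_POW2(M, N) (((M) + (N) - 1) & ~((N) - 1))
#define CEIL_DIV(M, N) (((M) + (N)-1) / (N))

int main() {
    // Rounds 100 up to the next multiple of the power-of-two 64.
    assert(ROUNDUP_POW2(100u, 64u) == 128u);
    // Already-aligned values are left unchanged.
    assert(ROUNDUP_POW2(128u, 64u) == 128u);
    // CEIL_DIV returns the number of 64-wide chunks rather than the padded size.
    assert(CEIL_DIV(100u, 64u) == 2u);
    return 0;
}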
@@ -368,6 +369,7 @@ struct vk_mat_mat_push_constants {
    uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d;
    uint32_t k_split;
    uint32_t ne02; uint32_t ne12; uint32_t broadcast2; uint32_t broadcast3;
+   uint32_t padded_N;
};
struct vk_mat_vec_push_constants {
    uint32_t ncols; uint32_t stride_a; uint32_t stride_b; uint32_t stride_d;
@@ -380,6 +382,7 @@ struct vk_mat_mat_id_push_constants {
    uint32_t stride_a; uint32_t stride_b; uint32_t stride_d;
    uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d;
    uint32_t nei0; uint32_t nei1; uint32_t nbi1; uint32_t ne11;
+   uint32_t padded_N;
};
struct vk_mat_vec_id_push_constants {
    uint32_t ncols; uint32_t stride_a; uint32_t stride_b; uint32_t stride_d;
@@ -3882,18 +3885,19 @@ static void ggml_vk_matmul(
        vk_subbuffer&& a, vk_subbuffer&& b, vk_subbuffer&& d, vk_subbuffer&& split_k_buffer,
        uint32_t m, uint32_t n, uint32_t k, uint32_t stride_a, uint32_t stride_b, uint32_t stride_d,
        uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d,
-       uint32_t split_k, uint32_t batch, uint32_t ne02, uint32_t ne12, uint32_t broadcast2, uint32_t broadcast3) {
+       uint32_t split_k, uint32_t batch, uint32_t ne02, uint32_t ne12, uint32_t broadcast2, uint32_t broadcast3,
+       uint32_t padded_n) {
    VK_LOG_DEBUG("ggml_vk_matmul(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), split_k: (" << (split_k_buffer.buffer != nullptr ? split_k_buffer.buffer->buffer : VK_NULL_HANDLE) << ", " << split_k_buffer.offset << ", " << split_k_buffer.size << "), m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", split_k: " << split_k << ", batch: " << batch << ", ne02: " << ne02 << ", ne12: " << ne12 << ", broadcast2: " << broadcast2 << ", broadcast3: " << broadcast3 << ")");
    ggml_vk_sync_buffers(subctx);
    if (split_k == 1) {
-       const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k, ne02, ne12, broadcast2, broadcast3 };
+       const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k, ne02, ne12, broadcast2, broadcast3, padded_n };
        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d }, sizeof(vk_mat_mat_push_constants), &pc, { m, n, batch });
        return;
    }

    GGML_ASSERT(batch_stride_d == m * n);

-   const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, CEIL_DIV(k, split_k), ne02, ne12, broadcast2, broadcast3 };
+   const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, CEIL_DIV(k, split_k), ne02, ne12, broadcast2, broadcast3, padded_n };
    // Make sure enough workgroups get assigned for split k to work
    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, split_k_buffer }, sizeof(vk_mat_mat_push_constants), &pc1, { (CEIL_DIV(m, pipeline->wg_denoms[0]) * pipeline->wg_denoms[0]) * split_k, n, batch });
    ggml_vk_sync_buffers(subctx);
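As a rough illustration of the split-k dispatch math in the hunk above, the following standalone sketch recomputes the X dimension of the grid with made-up sizes (m, n, batch, wg_denom_m and split_k are example values, not taken from the patch):

// Hypothetical sketch of the "round m up, then replicate split_k times" grid sizing.
#include <cstdint>
#include <cstdio>

#define CEIL_DIV(M, N) (((M) + (N)-1) / (N))

int main() {
    const uint32_t m = 1000, n = 512, batch = 1;  // example problem sizes
    const uint32_t wg_denom_m = 64;               // assumed pipeline->wg_denoms[0]
    const uint32_t split_k = 4;

    // Round m up to a whole number of M-tiles, then replicate the grid split_k
    // times along X so every K-slice gets its own set of workgroups.
    const uint32_t grid_x = (CEIL_DIV(m, wg_denom_m) * wg_denom_m) * split_k;
    printf("dispatch grid: %u x %u x %u\n",
           (unsigned) grid_x, (unsigned) n, (unsigned) batch);  // 4096 x 512 x 1
    return 0;
}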
@@ -3937,14 +3941,15 @@ static void ggml_vk_matmul_id(
        vk_subbuffer&& a, vk_subbuffer&& b, vk_subbuffer&& d, vk_subbuffer&& ids,
        uint32_t m, uint32_t n, uint32_t k, uint32_t stride_a, uint32_t stride_b, uint32_t stride_d,
        uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d,
-       uint32_t n_as, uint32_t nei0, uint32_t nei1, uint32_t nbi1, uint32_t ne11) {
+       uint32_t n_as, uint32_t nei0, uint32_t nei1, uint32_t nbi1, uint32_t ne11,
+       uint32_t padded_n) {
    VK_LOG_DEBUG("ggml_vk_matmul_id(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), ids: (" << ids.buffer->buffer << ", " << ids.offset << ", " << ids.size << "), " <<
        "m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", " <<
        "batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", " <<
        "n_as: " << n_as << ", nei0: " << nei0 << ", nei1: " << nei1 << ", nbi1: " << nbi1 << ", ne11: " << ne11 << ")");
    ggml_vk_sync_buffers(subctx);
    const vk_mat_mat_id_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d,
-                                             nei0, nei1, nbi1, ne11 };
+                                             nei0, nei1, nbi1, ne11, padded_n };
    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d, ids }, sizeof(vk_mat_mat_id_push_constants), &pc, { m, nei1, n_as });
}

@@ -4106,15 +4111,17 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
    // Not implemented
    GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT

-   const int x_ne = ne01 * ne00;
-   const int y_ne = ne11 * ne10;
-   const int d_ne = ne11 * ne01;
-
    const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11, qx_needs_dequant ? GGML_TYPE_F16 : src0->type));
    const bool aligned = ne10 == kpad && ne01 > 8 && ne11 > 8;

    vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, qx_needs_dequant ? GGML_TYPE_F16 : src0->type);

+   // Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking
+   uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) : ne11;
+   const int x_ne = ne01 * ne00;
+   const int y_ne = padded_n * ne10;
+   const int d_ne = ne11 * ne01;
+
    const uint32_t split_k = ggml_vk_guess_split_k(ctx, ne01, ne11, ne10, pipeline);

    const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
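To make the effect of the new padded_n concrete, here is a hypothetical standalone calculation (example sizes, not from the patch) showing how rounding N up to the workgroup denominator enlarges only the intermediate Y buffer, while the output size stays keyed to ne11 (d_ne above is still ne11 * ne01):

// Hypothetical sketch of the Y-buffer padding; ne10, ne11 and wg_denom_n are example values.
#include <cstdint>
#include <cstdio>

#define ROUNDUP_POW2(M, N) (((M) + (N) - 1) & ~((N) - 1))

int main() {
    const uint32_t ne10 = 4096;     // K
    const uint32_t ne11 = 5;        // N (rows of src1)
    const uint32_t wg_denom_n = 8;  // assumed pipeline->wg_denoms[1]

    const uint32_t padded_n = ROUNDUP_POW2(ne11, wg_denom_n);  // 8
    const uint64_t y_ne_old = (uint64_t) ne11 * ne10;          // 20480 elements
    const uint64_t y_ne_new = (uint64_t) padded_n * ne10;      // 32768 elements

    // The padding exists only in the intermediate (dequantized) Y buffer so the
    // shader can load whole N-tiles without per-element bounds checks.
    printf("y_ne: %llu -> %llu\n",
           (unsigned long long) y_ne_old, (unsigned long long) y_ne_new);
    return 0;
}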
@@ -4237,7 +4244,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
        { d_D, d_buf_offset, d_sz * ne12 * ne13 }, { ctx->prealloc_split_k, 0, d_sz * ne12 * ne13 * split_k },
        ne01, ne11, ne10,
        ne10, ne10, ne01, stride_batch_x, stride_batch_y, ne20*ne21,
-       split_k, ne12*ne13, ne02, ne12, r2, r3
+       split_k, ne12*ne13, ne02, ne12, r2, r3, padded_n
    ); // NOLINT
}

@@ -4688,15 +4695,17 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
    // Not implemented
    GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT

-   const uint64_t x_ne = ne01 * ne00;
-   const uint64_t y_ne = ne11 * ne10;
-   const uint64_t d_ne = ne21 * ne20;
-
    const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1, qx_needs_dequant ? GGML_TYPE_F16 : src0->type));
    const bool aligned = ne10 == kpad && ne01 > 8 && nei1 > 8;

    vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned, qx_needs_dequant ? GGML_TYPE_F16 : src0->type);

+   // Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking
+   uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) : ne11;
+   const uint64_t x_ne = ne01 * ne00;
+   const uint64_t y_ne = padded_n * ne10;
+   const uint64_t d_ne = ne21 * ne20;
+
    const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
    const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
    const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne;
@@ -4815,7 +4824,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
        { d_D, d_buf_offset, d_sz * ne22 * ne23 }, { d_ids, ids_buf_offset, ids_sz },
        ne01, ne21, ne10, ne10, ne10, ne01,
        stride_batch_x, stride_batch_y, ne20*ne21,
-       n_as, nei0, nei1, nbi1 / ggml_type_size(ids->type), ne11
+       n_as, nei0, nei1, nbi1 / ggml_type_size(ids->type), ne11, padded_n
    ); // NOLINT
}

@@ -6775,7 +6784,7 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
        ctx, subctx, p, ggml_vk_subbuffer(d_X), ggml_vk_subbuffer(d_Y), ggml_vk_subbuffer(d_D), ggml_vk_subbuffer(ctx->prealloc_split_k),
        m, n, k,
        k, k, m, k*m, k*n, m*n,
-       split_k, batch, batch, batch, 1, 1
+       split_k, batch, batch, batch, 1, 1, n
    );
}
ggml_vk_ctx_end(subctx);
@@ -7120,7 +7129,7 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
        ctx, subctx, p, ggml_vk_subbuffer(qx_buf), ggml_vk_subbuffer(y_buf), ggml_vk_subbuffer(d_buf), ggml_vk_subbuffer(ctx->prealloc_split_k),
        m, n, k,
        k, k, m, k*m, k*n, m*n,
-       split_k, batch, batch, batch, 1, 1
+       split_k, batch, batch, batch, 1, 1, n
    );
}
ggml_vk_ctx_end(subctx);