Skip to content

Commit dafae66

Browse files
authored
vulkan: dynamic subgroup size for the remaining k quants (#10745)
* q5_k q4_k q3_k q2_k q6_k multi-row example * revert as multi-row isn't faster for k quants
1 parent ae4b922 commit dafae66

File tree

6 files changed

+72
-61
lines changed

6 files changed

+72
-61
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 12 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -44,12 +44,6 @@
4444

4545
#define MAX_VK_BUFFERS 256
4646

47-
#ifndef K_QUANTS_PER_ITERATION
48-
#define K_QUANTS_PER_ITERATION 1
49-
#else
50-
static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUANTS_PER_ITERATION must be 1 or 2");
51-
#endif
52-
5347
#define VK_CHECK(err, msg) \
5448
do { \
5549
vk::Result err_ = (err); \
@@ -1792,10 +1786,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
17921786
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_0], "mul_mat_vec_q5_0_f32_f32", mul_mat_vec_q5_0_f32_f32_len, mul_mat_vec_q5_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1, true);
17931787
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_1], "mul_mat_vec_q5_1_f32_f32", mul_mat_vec_q5_1_f32_f32_len, mul_mat_vec_q5_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1, true);
17941788
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q8_0], "mul_mat_vec_q8_0_f32_f32", mul_mat_vec_q8_0_f32_f32_len, mul_mat_vec_q8_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size, 1}, 1, true);
1795-
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q2_K], "mul_mat_vec_q2_k_f32_f32", mul_mat_vec_q2_k_f32_f32_len, mul_mat_vec_q2_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true);
1796-
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q3_K], "mul_mat_vec_q3_k_f32_f32", mul_mat_vec_q3_k_f32_f32_len, mul_mat_vec_q3_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true);
1797-
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_K], "mul_mat_vec_q4_k_f32_f32", mul_mat_vec_q4_k_f32_f32_len, mul_mat_vec_q4_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true);
1798-
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_K], "mul_mat_vec_q5_k_f32_f32", mul_mat_vec_q5_k_f32_f32_len, mul_mat_vec_q5_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true);
1789+
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q2_K], "mul_mat_vec_q2_k_f32_f32", mul_mat_vec_q2_k_f32_f32_len, mul_mat_vec_q2_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
1790+
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q3_K], "mul_mat_vec_q3_k_f32_f32", mul_mat_vec_q3_k_f32_f32_len, mul_mat_vec_q3_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
1791+
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_K], "mul_mat_vec_q4_k_f32_f32", mul_mat_vec_q4_k_f32_f32_len, mul_mat_vec_q4_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
1792+
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_K], "mul_mat_vec_q5_k_f32_f32", mul_mat_vec_q5_k_f32_f32_len, mul_mat_vec_q5_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
17991793
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_k_f32_f32", mul_mat_vec_q6_k_f32_f32_len, mul_mat_vec_q6_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
18001794
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_iq4_nl_f32_f32", mul_mat_vec_iq4_nl_f32_f32_len, mul_mat_vec_iq4_nl_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1, true);
18011795

@@ -1806,10 +1800,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
18061800
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_0], "mul_mat_vec_q5_0_f16_f32", mul_mat_vec_q5_0_f16_f32_len, mul_mat_vec_q5_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1, true);
18071801
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_1], "mul_mat_vec_q5_1_f16_f32", mul_mat_vec_q5_1_f16_f32_len, mul_mat_vec_q5_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1, true);
18081802
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q8_0], "mul_mat_vec_q8_0_f16_f32", mul_mat_vec_q8_0_f16_f32_len, mul_mat_vec_q8_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size, 1}, 1, true);
1809-
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q2_K], "mul_mat_vec_q2_k_f16_f32", mul_mat_vec_q2_k_f16_f32_len, mul_mat_vec_q2_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true);
1810-
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q3_K], "mul_mat_vec_q3_k_f16_f32", mul_mat_vec_q3_k_f16_f32_len, mul_mat_vec_q3_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true);
1811-
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_K], "mul_mat_vec_q4_k_f16_f32", mul_mat_vec_q4_k_f16_f32_len, mul_mat_vec_q4_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true);
1812-
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_K], "mul_mat_vec_q5_k_f16_f32", mul_mat_vec_q5_k_f16_f32_len, mul_mat_vec_q5_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true);
1803+
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q2_K], "mul_mat_vec_q2_k_f16_f32", mul_mat_vec_q2_k_f16_f32_len, mul_mat_vec_q2_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
1804+
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q3_K], "mul_mat_vec_q3_k_f16_f32", mul_mat_vec_q3_k_f16_f32_len, mul_mat_vec_q3_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
1805+
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_K], "mul_mat_vec_q4_k_f16_f32", mul_mat_vec_q4_k_f16_f32_len, mul_mat_vec_q4_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
1806+
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_K], "mul_mat_vec_q5_k_f16_f32", mul_mat_vec_q5_k_f16_f32_len, mul_mat_vec_q5_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
18131807
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_k_f16_f32", mul_mat_vec_q6_k_f16_f32_len, mul_mat_vec_q6_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
18141808
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_iq4_nl_f16_f32", mul_mat_vec_iq4_nl_f16_f32_len, mul_mat_vec_iq4_nl_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size}, 1, true);
18151809

@@ -1820,10 +1814,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
18201814
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_f32", mul_mat_vec_id_q5_0_f32_len, mul_mat_vec_id_q5_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1, true);
18211815
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_1], "mul_mat_vec_id_q5_1_f32", mul_mat_vec_id_q5_1_f32_len, mul_mat_vec_id_q5_1_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1, true);
18221816
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q8_0], "mul_mat_vec_id_q8_0_f32", mul_mat_vec_id_q8_0_f32_len, mul_mat_vec_id_q8_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {device->subgroup_size, 1}, 1, true);
1823-
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q2_K], "mul_mat_vec_id_q2_k_f32", mul_mat_vec_id_q2_k_f32_len, mul_mat_vec_id_q2_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true);
1824-
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q3_K], "mul_mat_vec_id_q3_k_f32", mul_mat_vec_id_q3_k_f32_len, mul_mat_vec_id_q3_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true);
1825-
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_f32", mul_mat_vec_id_q4_k_f32_len, mul_mat_vec_id_q4_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true);
1826-
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_f32", mul_mat_vec_id_q5_k_f32_len, mul_mat_vec_id_q5_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true);
1817+
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q2_K], "mul_mat_vec_id_q2_k_f32", mul_mat_vec_id_q2_k_f32_len, mul_mat_vec_id_q2_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
1818+
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q3_K], "mul_mat_vec_id_q3_k_f32", mul_mat_vec_id_q3_k_f32_len, mul_mat_vec_id_q3_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
1819+
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_f32", mul_mat_vec_id_q4_k_f32_len, mul_mat_vec_id_q4_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
1820+
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_f32", mul_mat_vec_id_q5_k_f32_len, mul_mat_vec_id_q5_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
18271821
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_f32", mul_mat_vec_id_q6_k_f32_len, mul_mat_vec_id_q6_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
18281822
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", mul_mat_vec_id_iq4_nl_f32_len, mul_mat_vec_id_iq4_nl_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1, true);
18291823

ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@
22
#extension GL_EXT_shader_16bit_storage : require
33
#extension GL_EXT_shader_8bit_storage : require
44

5-
#define K_QUANTS_PER_ITERATION 2
6-
75
#ifdef MUL_MAT_ID
86
#define EXPERT_COUNT 8
97
#endif

ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,11 @@
33

44
#include "mul_mat_vec_base.comp"
55

6-
layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
6+
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
77

8-
shared FLOAT_TYPE tmp[32];
8+
layout (constant_id = 0) const uint BLOCK_SIZE = 32;
9+
10+
shared FLOAT_TYPE tmp[BLOCK_SIZE];
911

1012
void main() {
1113
const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
@@ -20,22 +22,25 @@ void main() {
2022
const uint num_blocks_per_row = p.ncols / QUANT_K;
2123
const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row;
2224

23-
const uint tid = gl_LocalInvocationID.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16
24-
const uint ix = gl_LocalInvocationID.x%K_QUANTS_PER_ITERATION; // 0 or 0, 1
25+
// 16 threads are used to process each block
26+
const uint it_size = gl_WorkGroupSize.x/16;
27+
const uint tid = gl_LocalInvocationID.x;
28+
const uint itid = tid%16; // 0...16
29+
const uint ix = tid/16;
2530

26-
const uint step = 16/K_QUANTS_PER_ITERATION; // 16 or 8
31+
const uint step = 8;
2732

28-
const uint v_im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
29-
const uint v_in = tid - step*v_im; // 0...15 or 0...7
33+
const uint v_im = itid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
34+
const uint v_in = itid - step*v_im; // 0...15 or 0...7
3035

31-
const uint l0 = K_QUANTS_PER_ITERATION*v_in; // 0...15
36+
const uint l0 = 2*v_in; // 0...15
3237
const uint q_offset = 32*v_im + l0;
3338
const uint s_offset = 8*v_im;
3439
const uint y_offset = 128*v_im + l0;
3540

3641
FLOAT_TYPE temp = FLOAT_TYPE(0.0); // partial sum for thread in warp
3742

38-
[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
43+
[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) {
3944
const uint y_idx = i * QUANT_K + y_offset;
4045

4146
f16vec2 d = data_a[ib0 + i].d;
@@ -71,7 +76,7 @@ void main() {
7176

7277
FLOAT_TYPE sum1 = FLOAT_TYPE(0.0);
7378
FLOAT_TYPE sum2 = FLOAT_TYPE(0.0);
74-
[[unroll]] for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
79+
[[unroll]] for (int l = 0; l < 2; ++l) {
7580
sum1 = fma(FLOAT_TYPE(b0[l]), FLOAT_TYPE(s0_lo4[0]) * FLOAT_TYPE((qs0[l] >> 0) & 3),
7681
fma(FLOAT_TYPE(b16[l]), FLOAT_TYPE(s0_lo4[1]) * FLOAT_TYPE((qs16[l] >> 0) & 3),
7782
fma(FLOAT_TYPE(b32[l]), FLOAT_TYPE(s0_lo4[2]) * FLOAT_TYPE((qs0[l] >> 2) & 3),
@@ -96,7 +101,7 @@ void main() {
96101

97102
// sum up partial sums and write back result
98103
barrier();
99-
[[unroll]] for (uint s = 16; s > 0; s >>= 1) {
104+
[[unroll]] for (uint s = gl_WorkGroupSize.x/2; s > 0; s >>= 1) {
100105
if (tid < s) {
101106
tmp[tid] += tmp[tid + s];
102107
}

ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,11 @@
33

44
#include "mul_mat_vec_base.comp"
55

6-
layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
6+
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
77

8-
shared FLOAT_TYPE tmp[32];
8+
layout (constant_id = 0) const uint BLOCK_SIZE = 32;
9+
10+
shared FLOAT_TYPE tmp[BLOCK_SIZE];
911

1012
void main() {
1113
const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
@@ -20,25 +22,28 @@ void main() {
2022
const uint num_blocks_per_row = p.ncols / QUANT_K;
2123
const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row;
2224

23-
const uint tid = gl_LocalInvocationID.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16
24-
const uint ix = gl_LocalInvocationID.x%K_QUANTS_PER_ITERATION; // 0 or 0, 1
25+
// 16 threads are used to process each block
26+
const uint it_size = gl_WorkGroupSize.x/16;
27+
const uint tid = gl_LocalInvocationID.x;
28+
const uint itid = tid%16; // 0...16
29+
const uint ix = tid/16;
2530

26-
const uint step = 16/K_QUANTS_PER_ITERATION; // 16 or 8
31+
const uint step = 8;
2732

28-
const uint v_im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
29-
const uint v_in = tid - step*v_im; // 0...15 or 0...7
33+
const uint v_im = itid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
34+
const uint v_in = itid - step*v_im; // 0...15 or 0...7
3035

3136
const uint8_t m = uint8_t(1 << (4 * v_im));
3237

33-
const uint l0 = K_QUANTS_PER_ITERATION*v_in; // 0...15
38+
const uint l0 = 2*v_in; // 0...15
3439
const uint q_offset = 32*v_im + l0;
3540
const uint y_offset = 128*v_im + l0;
3641

3742
FLOAT_TYPE temp = FLOAT_TYPE(0.0); // partial sum for thread in warp
3843

3944
const uint s_shift = 4 * v_im;
4045

41-
[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
46+
[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) {
4247
const uint y_idx = i * QUANT_K + y_offset;
4348

4449
const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d);
@@ -66,7 +71,7 @@ void main() {
6671
u8vec2 s10 = unpack8(s10_16);
6772

6873
FLOAT_TYPE sum = FLOAT_TYPE(0.0);
69-
for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
74+
[[unroll]] for (int l = 0; l < 2; ++l) {
7075
sum = fma(FLOAT_TYPE(b0[l]) * FLOAT_TYPE(int8_t(((s0[0] >> s_shift) & 0xF) | ((s8[0] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] ) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 0)) != 0) ? 0 : 4)),
7176
fma(FLOAT_TYPE(b32[l]) * FLOAT_TYPE(int8_t(((s2[0] >> s_shift) & 0xF) | ((s10[0] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 1)) != 0) ? 0 : 4)),
7277
fma(FLOAT_TYPE(b64[l]) * FLOAT_TYPE(int8_t(((s4[0] >> s_shift) & 0xF) | ((s8[0] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 2)) != 0) ? 0 : 4)),
@@ -83,7 +88,7 @@ void main() {
8388

8489
// sum up partial sums and write back result
8590
barrier();
86-
[[unroll]] for (uint s = 16; s > 0; s >>= 1) {
91+
[[unroll]] for (uint s = gl_WorkGroupSize.x/2; s > 0; s >>= 1) {
8792
if (tid < s) {
8893
tmp[tid] += tmp[tid + s];
8994
}

ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,12 @@
44

55
#include "mul_mat_vec_base.comp"
66

7-
layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
7+
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
88

9-
shared FLOAT_TYPE tmp[32];
9+
layout (constant_id = 0) const uint BLOCK_SIZE = 32;
10+
11+
shared FLOAT_TYPE tmp[BLOCK_SIZE];
1012

11-
// This shader assumes K_QUANTS_PER_ITERATION == 2 for alignment of loads
1213
void main() {
1314
const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
1415

@@ -22,14 +23,17 @@ void main() {
2223
const uint num_blocks_per_row = p.ncols / QUANT_K;
2324
const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row;
2425

25-
const uint tid = gl_LocalInvocationID.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16
26-
const uint ix = gl_LocalInvocationID.x%K_QUANTS_PER_ITERATION; // 0 or 0, 1
26+
// 16 threads are used to process each block
27+
const uint it_size = gl_WorkGroupSize.x/16;
28+
const uint tid = gl_LocalInvocationID.x;
29+
const uint itid = tid%16; // 0...16
30+
const uint ix = tid/16;
2731

28-
const uint step = 8/K_QUANTS_PER_ITERATION; // 8 or 4
32+
const uint step = 4;
2933

30-
const uint il = tid/step; // 0...3
31-
const uint ir = tid - step*il; // 0...7 or 0...3
32-
const uint n = 2 * K_QUANTS_PER_ITERATION; // 2 or 4
34+
const uint il = itid/step; // 0...3
35+
const uint ir = itid - step*il; // 0...7 or 0...3
36+
const uint n = 4;
3337

3438
const uint v_im = il / 2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
3539
const uint v_in = il % 2;
@@ -40,7 +44,7 @@ void main() {
4044

4145
FLOAT_TYPE temp = FLOAT_TYPE(0.0); // partial sum for thread in warp
4246

43-
[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
47+
[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) {
4448
const uint y1_idx = i * QUANT_K + y_offset;
4549
const uint y2_idx = y1_idx + 128;
4650

@@ -115,7 +119,7 @@ void main() {
115119

116120
// sum up partial sums and write back result
117121
barrier();
118-
[[unroll]] for (uint s = 16; s > 0; s >>= 1) {
122+
[[unroll]] for (uint s = gl_WorkGroupSize.x/2; s > 0; s >>= 1) {
119123
if (tid < s) {
120124
tmp[tid] += tmp[tid + s];
121125
}

0 commit comments

Comments
 (0)