Skip to content

Commit c2aa654

Browse files
committed
q5_k
q4_k q3_k q2_k q6_k multi row example
1 parent 26a8406 commit c2aa654

File tree

7 files changed

+159
-127
lines changed

7 files changed

+159
-127
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 15 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -44,12 +44,6 @@
4444

4545
#define MAX_VK_BUFFERS 256
4646

47-
#ifndef K_QUANTS_PER_ITERATION
48-
#define K_QUANTS_PER_ITERATION 1
49-
#else
50-
static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUANTS_PER_ITERATION must be 1 or 2");
51-
#endif
52-
5347
#define VK_CHECK(err, msg) \
5448
do { \
5549
vk::Result err_ = (err); \
@@ -1792,11 +1786,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
17921786
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_0], "mul_mat_vec_q5_0_f32_f32", mul_mat_vec_q5_0_f32_f32_len, mul_mat_vec_q5_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1, true);
17931787
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_1], "mul_mat_vec_q5_1_f32_f32", mul_mat_vec_q5_1_f32_f32_len, mul_mat_vec_q5_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1, true);
17941788
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q8_0], "mul_mat_vec_q8_0_f32_f32", mul_mat_vec_q8_0_f32_f32_len, mul_mat_vec_q8_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size, 1}, 1, true);
1795-
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q2_K], "mul_mat_vec_q2_k_f32_f32", mul_mat_vec_q2_k_f32_f32_len, mul_mat_vec_q2_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true);
1796-
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q3_K], "mul_mat_vec_q3_k_f32_f32", mul_mat_vec_q3_k_f32_f32_len, mul_mat_vec_q3_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true);
1797-
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_K], "mul_mat_vec_q4_k_f32_f32", mul_mat_vec_q4_k_f32_f32_len, mul_mat_vec_q4_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true);
1798-
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_K], "mul_mat_vec_q5_k_f32_f32", mul_mat_vec_q5_k_f32_f32_len, mul_mat_vec_q5_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true);
1799-
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_k_f32_f32", mul_mat_vec_q6_k_f32_f32_len, mul_mat_vec_q6_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
1789+
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q2_K], "mul_mat_vec_q2_k_f32_f32", mul_mat_vec_q2_k_f32_f32_len, mul_mat_vec_q2_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
1790+
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q3_K], "mul_mat_vec_q3_k_f32_f32", mul_mat_vec_q3_k_f32_f32_len, mul_mat_vec_q3_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
1791+
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_K], "mul_mat_vec_q4_k_f32_f32", mul_mat_vec_q4_k_f32_f32_len, mul_mat_vec_q4_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
1792+
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_K], "mul_mat_vec_q5_k_f32_f32", mul_mat_vec_q5_k_f32_f32_len, mul_mat_vec_q5_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
1793+
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_k_f32_f32", mul_mat_vec_q6_k_f32_f32_len, mul_mat_vec_q6_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {subgroup_size_16, 2}, 1, true);
18001794
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_iq4_nl_f32_f32", mul_mat_vec_iq4_nl_f32_f32_len, mul_mat_vec_iq4_nl_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1, true);
18011795

18021796
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F32 ], "mul_mat_vec_f32_f16_f32", mul_mat_vec_f32_f16_f32_len, mul_mat_vec_f32_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
@@ -1806,11 +1800,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
18061800
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_0], "mul_mat_vec_q5_0_f16_f32", mul_mat_vec_q5_0_f16_f32_len, mul_mat_vec_q5_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1, true);
18071801
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_1], "mul_mat_vec_q5_1_f16_f32", mul_mat_vec_q5_1_f16_f32_len, mul_mat_vec_q5_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1, true);
18081802
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q8_0], "mul_mat_vec_q8_0_f16_f32", mul_mat_vec_q8_0_f16_f32_len, mul_mat_vec_q8_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size, 1}, 1, true);
1809-
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q2_K], "mul_mat_vec_q2_k_f16_f32", mul_mat_vec_q2_k_f16_f32_len, mul_mat_vec_q2_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true);
1810-
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q3_K], "mul_mat_vec_q3_k_f16_f32", mul_mat_vec_q3_k_f16_f32_len, mul_mat_vec_q3_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true);
1811-
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_K], "mul_mat_vec_q4_k_f16_f32", mul_mat_vec_q4_k_f16_f32_len, mul_mat_vec_q4_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true);
1812-
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_K], "mul_mat_vec_q5_k_f16_f32", mul_mat_vec_q5_k_f16_f32_len, mul_mat_vec_q5_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true);
1813-
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_k_f16_f32", mul_mat_vec_q6_k_f16_f32_len, mul_mat_vec_q6_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
1803+
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q2_K], "mul_mat_vec_q2_k_f16_f32", mul_mat_vec_q2_k_f16_f32_len, mul_mat_vec_q2_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
1804+
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q3_K], "mul_mat_vec_q3_k_f16_f32", mul_mat_vec_q3_k_f16_f32_len, mul_mat_vec_q3_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
1805+
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_K], "mul_mat_vec_q4_k_f16_f32", mul_mat_vec_q4_k_f16_f32_len, mul_mat_vec_q4_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
1806+
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_K], "mul_mat_vec_q5_k_f16_f32", mul_mat_vec_q5_k_f16_f32_len, mul_mat_vec_q5_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
1807+
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_k_f16_f32", mul_mat_vec_q6_k_f16_f32_len, mul_mat_vec_q6_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {subgroup_size_16, 2}, 1, true);
18141808
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_iq4_nl_f16_f32", mul_mat_vec_iq4_nl_f16_f32_len, mul_mat_vec_iq4_nl_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size}, 1, true);
18151809

18161810
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", mul_mat_vec_id_f32_f32_len, mul_mat_vec_id_f32_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
@@ -1820,11 +1814,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
18201814
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_f32", mul_mat_vec_id_q5_0_f32_len, mul_mat_vec_id_q5_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1, true);
18211815
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_1], "mul_mat_vec_id_q5_1_f32", mul_mat_vec_id_q5_1_f32_len, mul_mat_vec_id_q5_1_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1, true);
18221816
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q8_0], "mul_mat_vec_id_q8_0_f32", mul_mat_vec_id_q8_0_f32_len, mul_mat_vec_id_q8_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {device->subgroup_size, 1}, 1, true);
1823-
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q2_K], "mul_mat_vec_id_q2_k_f32", mul_mat_vec_id_q2_k_f32_len, mul_mat_vec_id_q2_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true);
1824-
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q3_K], "mul_mat_vec_id_q3_k_f32", mul_mat_vec_id_q3_k_f32_len, mul_mat_vec_id_q3_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true);
1825-
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_f32", mul_mat_vec_id_q4_k_f32_len, mul_mat_vec_id_q4_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true);
1826-
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_f32", mul_mat_vec_id_q5_k_f32_len, mul_mat_vec_id_q5_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true);
1827-
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_f32", mul_mat_vec_id_q6_k_f32_len, mul_mat_vec_id_q6_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
1817+
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q2_K], "mul_mat_vec_id_q2_k_f32", mul_mat_vec_id_q2_k_f32_len, mul_mat_vec_id_q2_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
1818+
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q3_K], "mul_mat_vec_id_q3_k_f32", mul_mat_vec_id_q3_k_f32_len, mul_mat_vec_id_q3_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
1819+
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_f32", mul_mat_vec_id_q4_k_f32_len, mul_mat_vec_id_q4_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
1820+
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_f32", mul_mat_vec_id_q5_k_f32_len, mul_mat_vec_id_q5_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
1821+
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_f32", mul_mat_vec_id_q6_k_f32_len, mul_mat_vec_id_q6_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {subgroup_size_16, 2}, 1, true);
18281822
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", mul_mat_vec_id_iq4_nl_f32_len, mul_mat_vec_id_iq4_nl_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1, true);
18291823

18301824
// dequant shaders

ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@
22
#extension GL_EXT_shader_16bit_storage : require
33
#extension GL_EXT_shader_8bit_storage : require
44

5-
#define K_QUANTS_PER_ITERATION 2
6-
75
#ifdef MUL_MAT_ID
86
#define EXPERT_COUNT 8
97
#endif

ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,11 @@
33

44
#include "mul_mat_vec_base.comp"
55

6-
layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
6+
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
77

8-
shared FLOAT_TYPE tmp[32];
8+
layout (constant_id = 0) const uint BLOCK_SIZE = 32;
9+
10+
shared FLOAT_TYPE tmp[BLOCK_SIZE];
911

1012
void main() {
1113
const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
@@ -20,22 +22,25 @@ void main() {
2022
const uint num_blocks_per_row = p.ncols / QUANT_K;
2123
const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row;
2224

23-
const uint tid = gl_LocalInvocationID.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16
24-
const uint ix = gl_LocalInvocationID.x%K_QUANTS_PER_ITERATION; // 0 or 0, 1
25+
// 16 threads are used to process each block
26+
const uint it_size = gl_WorkGroupSize.x/16;
27+
const uint tid = gl_LocalInvocationID.x;
28+
const uint itid = tid%16; // 0...16
29+
const uint ix = tid/16;
2530

26-
const uint step = 16/K_QUANTS_PER_ITERATION; // 16 or 8
31+
const uint step = 8;
2732

28-
const uint v_im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
29-
const uint v_in = tid - step*v_im; // 0...15 or 0...7
33+
const uint v_im = itid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
34+
const uint v_in = itid - step*v_im; // 0...15 or 0...7
3035

31-
const uint l0 = K_QUANTS_PER_ITERATION*v_in; // 0...15
36+
const uint l0 = 2*v_in; // 0...15
3237
const uint q_offset = 32*v_im + l0;
3338
const uint s_offset = 8*v_im;
3439
const uint y_offset = 128*v_im + l0;
3540

3641
FLOAT_TYPE temp = FLOAT_TYPE(0.0); // partial sum for thread in warp
3742

38-
[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
43+
[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) {
3944
const uint y_idx = i * QUANT_K + y_offset;
4045

4146
f16vec2 d = data_a[ib0 + i].d;
@@ -71,7 +76,7 @@ void main() {
7176

7277
FLOAT_TYPE sum1 = FLOAT_TYPE(0.0);
7378
FLOAT_TYPE sum2 = FLOAT_TYPE(0.0);
74-
[[unroll]] for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
79+
[[unroll]] for (int l = 0; l < 2; ++l) {
7580
sum1 = fma(FLOAT_TYPE(b0[l]), FLOAT_TYPE(s0_lo4[0]) * FLOAT_TYPE((qs0[l] >> 0) & 3),
7681
fma(FLOAT_TYPE(b16[l]), FLOAT_TYPE(s0_lo4[1]) * FLOAT_TYPE((qs16[l] >> 0) & 3),
7782
fma(FLOAT_TYPE(b32[l]), FLOAT_TYPE(s0_lo4[2]) * FLOAT_TYPE((qs0[l] >> 2) & 3),
@@ -96,7 +101,7 @@ void main() {
96101

97102
// sum up partial sums and write back result
98103
barrier();
99-
[[unroll]] for (uint s = 16; s > 0; s >>= 1) {
104+
[[unroll]] for (uint s = gl_WorkGroupSize.x/2; s > 0; s >>= 1) {
100105
if (tid < s) {
101106
tmp[tid] += tmp[tid + s];
102107
}

ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,11 @@
33

44
#include "mul_mat_vec_base.comp"
55

6-
layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
6+
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
77

8-
shared FLOAT_TYPE tmp[32];
8+
layout (constant_id = 0) const uint BLOCK_SIZE = 32;
9+
10+
shared FLOAT_TYPE tmp[BLOCK_SIZE];
911

1012
void main() {
1113
const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
@@ -20,25 +22,28 @@ void main() {
2022
const uint num_blocks_per_row = p.ncols / QUANT_K;
2123
const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row;
2224

23-
const uint tid = gl_LocalInvocationID.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16
24-
const uint ix = gl_LocalInvocationID.x%K_QUANTS_PER_ITERATION; // 0 or 0, 1
25+
// 16 threads are used to process each block
26+
const uint it_size = gl_WorkGroupSize.x/16;
27+
const uint tid = gl_LocalInvocationID.x;
28+
const uint itid = tid%16; // 0...16
29+
const uint ix = tid/16;
2530

26-
const uint step = 16/K_QUANTS_PER_ITERATION; // 16 or 8
31+
const uint step = 8;
2732

28-
const uint v_im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
29-
const uint v_in = tid - step*v_im; // 0...15 or 0...7
33+
const uint v_im = itid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
34+
const uint v_in = itid - step*v_im; // 0...15 or 0...7
3035

3136
const uint8_t m = uint8_t(1 << (4 * v_im));
3237

33-
const uint l0 = K_QUANTS_PER_ITERATION*v_in; // 0...15
38+
const uint l0 = 2*v_in; // 0...15
3439
const uint q_offset = 32*v_im + l0;
3540
const uint y_offset = 128*v_im + l0;
3641

3742
FLOAT_TYPE temp = FLOAT_TYPE(0.0); // partial sum for thread in warp
3843

3944
const uint s_shift = 4 * v_im;
4045

41-
[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
46+
[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) {
4247
const uint y_idx = i * QUANT_K + y_offset;
4348

4449
const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d);
@@ -66,7 +71,7 @@ void main() {
6671
u8vec2 s10 = unpack8(s10_16);
6772

6873
FLOAT_TYPE sum = FLOAT_TYPE(0.0);
69-
for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
74+
[[unroll]] for (int l = 0; l < 2; ++l) {
7075
sum = fma(FLOAT_TYPE(b0[l]) * FLOAT_TYPE(int8_t(((s0[0] >> s_shift) & 0xF) | ((s8[0] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] ) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 0)) != 0) ? 0 : 4)),
7176
fma(FLOAT_TYPE(b32[l]) * FLOAT_TYPE(int8_t(((s2[0] >> s_shift) & 0xF) | ((s10[0] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 1)) != 0) ? 0 : 4)),
7277
fma(FLOAT_TYPE(b64[l]) * FLOAT_TYPE(int8_t(((s4[0] >> s_shift) & 0xF) | ((s8[0] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 2)) != 0) ? 0 : 4)),
@@ -83,7 +88,7 @@ void main() {
8388

8489
// sum up partial sums and write back result
8590
barrier();
86-
[[unroll]] for (uint s = 16; s > 0; s >>= 1) {
91+
[[unroll]] for (uint s = gl_WorkGroupSize.x/2; s > 0; s >>= 1) {
8792
if (tid < s) {
8893
tmp[tid] += tmp[tid + s];
8994
}

0 commit comments

Comments
 (0)