
Commit 97c6d27

Add q4_0 x q8_1 matrix matrix multiplication support
1 parent: dccf084

5 files changed: +99 -39 lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 15 additions & 13 deletions
@@ -1897,6 +1897,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f16acc, matmul_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
     CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );

+    CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
     CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );

     CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
@@ -3161,6 +3162,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte
     // MMQ
     if (src1_type == GGML_TYPE_Q8_1) {
         switch (src0_type) {
+        case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q8_0:
             break;
         default:
@@ -3893,8 +3895,8 @@ static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, int m, int
     return split_k;
 }

-static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned, ggml_type src0_type) {
-    VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ")");
+static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned, ggml_type src0_type, ggml_type src1_type) {
+    VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")");

     if (ctx->device->coopmat2) {
         if ((ctx->device->mul_mat_l[src0_type] && (m % mmp->l->wg_denoms[0]) == 0 && (n % mmp->l->wg_denoms[1]) == 0) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_s[src0_type])) {
@@ -3906,7 +3908,7 @@ static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx,
         return aligned ? mmp->a_s : mmp->s;
     }

-    if ((ctx->device->mul_mat_s[src0_type] && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_l[src0_type])) {
+    if ((ctx->device->mul_mat_s[src0_type] && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_l[src0_type]) || src1_type == GGML_TYPE_Q8_1) {
         return aligned ? mmp->a_s : mmp->s;
     }
     if ((ctx->device->mul_mat_m[src0_type] && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_l[src0_type]) {
@@ -3915,9 +3917,9 @@ static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx,
     return aligned ? mmp->a_l : mmp->l;
 }

-static uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type) {
-    VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ", " << ggml_type_name(src0_type) << ")");
-    return ggml_vk_guess_matmul_pipeline(ctx, mmp, m, n, true, src0_type)->align;
+static uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type, ggml_type src1_type) {
+    VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ", " << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")");
+    return ggml_vk_guess_matmul_pipeline(ctx, mmp, m, n, true, src0_type, src1_type)->align;
 }

 static void ggml_vk_matmul(
@@ -4177,10 +4179,10 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
     const int y_ne = ne11 * ne10;
     const int d_ne = ne11 * ne01;

-    const uint32_t kpad = quantize_y ? 0 : ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11, qx_needs_dequant ? GGML_TYPE_F16 : src0->type));
+    const uint32_t kpad = quantize_y ? 0 : ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11, qx_needs_dequant ? GGML_TYPE_F16 : src0->type, quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type)));
     const bool aligned = !quantize_y && ne10 == kpad && ne01 > 8 && ne11 > 8;

-    vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, qx_needs_dequant ? GGML_TYPE_F16 : src0->type);
+    vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, qx_needs_dequant ? GGML_TYPE_F16 : src0->type, quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type));

     const uint32_t split_k = ggml_vk_guess_split_k(ctx, ne01, ne11, ne10, pipeline);

@@ -7306,8 +7308,8 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
     ggml_vk_quantize_data(x, qx, x_ne, quant);

     for (size_t i = 0; i < y_ne; i++) {
-        y[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;
-        // y[i] = (i % k == i / k) ? 1.0f : 0.0f;
+        // y[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;
+        y[i] = (i % k == i / k) ? 1.0f : 0.0f;
     }

     ggml_pipeline_request_descriptor_sets(ctx->device, p, num_it);
@@ -7481,9 +7483,9 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
     };
     const size_t num_it = 100;

-    ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 0, GGML_TYPE_Q8_0, true);
-    ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 1, GGML_TYPE_Q8_0, true);
-    ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 2, GGML_TYPE_Q8_0, true);
+    ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 0, GGML_TYPE_Q4_0, true);
+    ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 1, GGML_TYPE_Q4_0, true);
+    ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 2, GGML_TYPE_Q4_0, true);

     abort();
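Note on the new src1_type parameter threaded through the pipeline-selection helpers above: when quantize_y is set, the activations (src1) are quantized to q8_1 on the device before the MMQ kernel runs, and passing GGML_TYPE_Q8_1 here forces the small-tile variant. As a rough, hedged sketch of what that q8_1 quantization produces for one 32-value block (field names are illustrative, not the exact ggml layout):

```cpp
#include <cmath>
#include <cstdint>

// Illustrative q8_1-style block: 32 int8 quants plus a scale d and
// s = d * sum(qs); the MMQ shader reads the (d, s) pair as "ds".
struct block_q8_1_sketch {
    float  d;
    float  s;
    int8_t qs[32];
};

// Quantize one block of 32 floats, mirroring the usual q8_1 scheme:
// d = max|x| / 127, qs[i] = round(x[i] / d), s = d * sum(qs).
static block_q8_1_sketch quantize_q8_1_block(const float * x) {
    block_q8_1_sketch b{};
    float amax = 0.0f;
    for (int i = 0; i < 32; ++i) {
        amax = std::fmax(amax, std::fabs(x[i]));
    }
    const float d  = amax / 127.0f;
    const float id = d != 0.0f ? 1.0f / d : 0.0f;
    int sum = 0;
    for (int i = 0; i < 32; ++i) {
        const int q = (int) std::lround(x[i] * id);
        b.qs[i] = (int8_t) q;
        sum += q;
    }
    b.d = d;
    b.s = d * (float) sum;
    return b;
}
```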

ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_mmq.comp

Lines changed: 37 additions & 0 deletions
@@ -0,0 +1,37 @@
+#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
+#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
+#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
+
+#include "types.comp"
+
+// Each iqs value maps to a 32-bit integer
+
+#if defined(DATA_A_Q4_0)
+i32vec2 repack(uint ib, uint iqs) {
+    const u16vec2 quants = u16vec2(data_a[ib].qs[iqs * 2 + 1],
+                                   data_a[ib].qs[iqs * 2    ]);
+    const uint32_t vui = pack32(quants);
+    return i32vec2(pack32(i8vec4(i16vec4(unpack8( vui       & 0x0F0F0F0F)) - int16_t(8))),
+                   pack32(i8vec4(i16vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) - int16_t(8))));
+}
+#endif
+
+#if defined(DATA_A_Q8_0)
+int32_t repack(uint ib, uint iqs) {
+    const int16_t v0 = data_a[ib].qs[iqs * 2    ];
+    const int16_t v1 = data_a[ib].qs[iqs * 2 + 1];
+    return pack32(i16vec2(v1, v0));
+}
+#endif
+
+#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
+FLOAT_TYPE get_d(uint ib) {
+    return FLOAT_TYPE(data_a[ib].d);
+}
+#endif
+
+#if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1)
+FLOAT_TYPE_VEC2 get_dm(uint ib) {
+    return FLOAT_TYPE_VEC2(data_a[ib].d, data_a[ib].m);
+}
+#endif
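For context on the q4_0 repack() above: each 32-bit word of packed quants holds eight 4-bit values; the low and high nibbles are split into two groups of four, the q4_0 offset of 8 is subtracted, and each group is re-packed as four signed bytes so it can feed the packed int8 dot product. A hedged scalar sketch of the same bit manipulation (plain C++, ignoring the shader's exact byte ordering):

```cpp
#include <cstdint>
#include <utility>

// Split a 32-bit word holding eight 4-bit q4_0 quants into two words of
// four signed bytes each: one from the low nibbles, one from the high
// nibbles, with the q4_0 offset of 8 removed. Mirrors repack() above at a
// high level; the real shader also swaps the two 16-bit halves first.
static std::pair<uint32_t, uint32_t> repack_q4_0(uint32_t vui) {
    uint32_t lo = 0, hi = 0;
    for (int b = 0; b < 4; ++b) {
        const int8_t lo_q = (int8_t) (((vui >> (8 * b))     & 0x0F) - 8);
        const int8_t hi_q = (int8_t) (((vui >> (8 * b + 4)) & 0x0F) - 8);
        lo |= (uint32_t) (uint8_t) lo_q << (8 * b);
        hi |= (uint32_t) (uint8_t) hi_q << (8 * b);
    }
    return {lo, hi};
}
```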

ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp

Lines changed: 42 additions & 25 deletions
@@ -81,12 +81,12 @@ layout (constant_id = 10) const uint WARP = 32;
 
 // Assumption: BK == 32
 struct block_q8_0_internal {
-    FLOAT_TYPE ds;
+    FLOAT_TYPE d;
     int32_t qs[BK / 4];
 };

 struct block_q8_1_internal {
-    FLOAT_TYPE_VEC2 ds;
+    FLOAT_TYPE_VEC2 dm;
     int32_t qs[BK / 4];
 };

@@ -97,9 +97,10 @@ shared block_q8_0_internal buf_a[BM];
 shared block_q8_1_internal buf_a[BM];
 #endif

-shared block_q8_1_internal buf_b[BN];
+shared block_q8_0_internal buf_b[BN];

-#define LOAD_VEC 4
+#define LOAD_VEC_A (4 * QUANT_R)
+#define LOAD_VEC_B 4

 #ifdef MUL_MAT_ID
 shared u16vec2 row_ids[3072];
@@ -111,6 +112,8 @@ shared u16vec2 row_ids[3072];
 shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS];
 #endif

+#include "dequant_funcs_mmq.comp"
+
 void main() {
 #if defined(DATA_A_IQ4_NL)
     init_iq4nl_shmem();
@@ -162,10 +165,13 @@ void main() {
     const uint warp_r = warp_i % (BM / WM);
     const uint warp_c = warp_i / (BM / WM);

-    const uint loadr = gl_LocalInvocationID.x % (BK / LOAD_VEC);
-    const uint loadc = gl_LocalInvocationID.x / (BK / LOAD_VEC);
+    const uint loadr_a = gl_LocalInvocationID.x % (BK / LOAD_VEC_A);
+    const uint loadc_a = gl_LocalInvocationID.x / (BK / LOAD_VEC_A);
+    const uint loadr_b = gl_LocalInvocationID.x % (BK / LOAD_VEC_B);
+    const uint loadc_b = gl_LocalInvocationID.x / (BK / LOAD_VEC_B);

-    const uint loadstride = BLOCK_SIZE * LOAD_VEC / BK;
+    const uint loadstride_a = BLOCK_SIZE * LOAD_VEC_A / BK;
+    const uint loadstride_b = BLOCK_SIZE * LOAD_VEC_B / BK;

 #ifdef MUL_MAT_ID
     uint _ne1 = 0;
@@ -222,46 +228,52 @@ void main() {
     block_q8_1_internal cache_a[WMITER * TM];
 #endif

-    block_q8_1_internal cache_b[TN];
+    block_q8_0_internal cache_b[TN];

     [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
         sums[i] = ACC_TYPE(0.0f);
     }
 #endif

     for (uint block = start_k; block < end_k; block += BK) {
-        [[unroll]] for (uint l = 0; loadc + l < BM; l += loadstride) {
-#if defined(DATA_A_Q8_0)
-            const uint ib = pos_a_ib + (loadc + l) * p.stride_a / BK;
-            const uint iqs = loadr;
+        [[unroll]] for (uint l = 0; loadc_a + l < BM; l += loadstride_a) {
+            const uint ib = pos_a_ib + (loadc_a + l) * p.stride_a / BK;
+            const uint iqs = loadr_a;

-            const uint buf_ib = loadc + l;
+            const uint buf_ib = loadc_a + l;

             // Should ds be gated to a single thread?
             if (iqs == 0) {
-                buf_a[buf_ib].ds = FLOAT_TYPE(data_a[ib].d);
+#if QUANT_AUXF == 1
+                buf_a[buf_ib].d = get_d(ib);
+#else
+                buf_a[buf_ib].dm = get_dm(ib, 0);
+#endif
             }
-            const int16_t v0 = data_a[ib].qs[iqs * 2    ];
-            const int16_t v1 = data_a[ib].qs[iqs * 2 + 1];
-            buf_a[buf_ib].qs[iqs] = pack32(i16vec2(v1, v0));
+#if QUANT_R == 1
+            buf_a[buf_ib].qs[iqs] = repack(ib, iqs);
+#else
+            const i32vec2 vals = repack(ib, iqs);
+            buf_a[buf_ib].qs[iqs]     = vals.x;
+            buf_a[buf_ib].qs[iqs + 4] = vals.y;
 #endif
         }
-        [[unroll]] for (uint l = 0; loadc + l < BN; l += loadstride) {
+        [[unroll]] for (uint l = 0; loadc_b + l < BN; l += loadstride_b) {
 #ifdef MUL_MAT_ID
-            const u16vec2 row_idx = row_ids[ic * BN + loadc + l];
-            const uint idx = pos_b_ib + row_idx.y * p.batch_stride_b / LOAD_VEC + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC + loadr;
+            const u16vec2 row_idx = row_ids[ic * BN + loadc_b + l];
+            const uint idx = pos_b_ib + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + loadr_b;
             const uint ib = idx / 8;
             const uint iqs = idx & 0x7;
 #else
-            const uint ib = pos_b_ib + (loadc + l) * p.stride_b / BK;
-            const uint iqs = loadr;
+            const uint ib = pos_b_ib + (loadc_b + l) * p.stride_b / BK;
+            const uint iqs = loadr_b;
 #endif

-            const uint buf_ib = loadc + l;
+            const uint buf_ib = loadc_b + l;

             // Should ds be gated to a single thread?
             if (iqs == 0) {
-                buf_b[buf_ib].ds = FLOAT_TYPE_VEC2(data_b[ib].ds);
+                buf_b[buf_ib].d = FLOAT_TYPE(data_b[ib].ds.x);
             }
             const int32_t v0 = int32_t(data_b[ib].qs[iqs * 2    ]);
             const int32_t v1 = int32_t(data_b[ib].qs[iqs * 2 + 1]);
@@ -311,7 +323,12 @@ void main() {
                 q_sum = dotPacked4x8AccSatEXT(cache_a[cache_a_idx].qs[idx_k], cache_b[cc].qs[idx_k], q_sum);
             }

-            const float factor = float(cache_a[cache_a_idx].ds) * float(cache_b[cc].ds.x);
+#if QUANT_AUXF == 1
+            const float factor = float(cache_a[cache_a_idx].d) * float(cache_b[cc].d);
+#else
+            // TODO
+            // const float factor = float(cache_a[cache_a_idx].d) * float(cache_b[cc].d);
+#endif

             sums[sums_idx] = ACC_TYPE(fma(float(q_sum), factor, float(sums[sums_idx])));
         }
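The loop above accumulates a packed int8 dot product (dotPacked4x8AccSatEXT) and only afterwards applies the per-block scales. For a symmetric format such as q4_0 the per-block result is roughly d_a * d_b * sum(qa_i * qb_i), so only the d half of the q8_1 ds pair is needed on the B side, which is why buf_b and cache_b switch to block_q8_0_internal. A hedged scalar reference of that per-block accumulation (the GPU path additionally saturates):

```cpp
#include <cstdint>

// Scalar reference for one 32-value block of the MMQ accumulation:
// integer dot product of the repacked A quants against the q8_1 B quants,
// scaled by the two per-block scales. Illustrative only.
static float mmq_block_dot(const int8_t * qa, float d_a,
                           const int8_t * qb, float d_b) {
    int32_t q_sum = 0;  // what dotPacked4x8AccSatEXT accumulates on the GPU
    for (int i = 0; i < 32; ++i) {
        q_sum += (int32_t) qa[i] * (int32_t) qb[i];
    }
    return d_a * d_b * (float) q_sum;  // the "factor" applied in the shader
}
```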

ggml/src/ggml-vulkan/vulkan-shaders/types.comp

Lines changed: 4 additions & 0 deletions
@@ -49,6 +49,7 @@ struct block_q4_0_packed16
 #if defined(DATA_A_Q4_0)
 #define QUANT_K QUANT_K_Q4_0
 #define QUANT_R QUANT_R_Q4_0
+#define QUANT_AUXF 1
 #define A_TYPE block_q4_0
 #define A_TYPE_PACKED16 block_q4_0_packed16
 #endif
@@ -73,6 +74,7 @@ struct block_q4_1_packed16
 #if defined(DATA_A_Q4_1)
 #define QUANT_K QUANT_K_Q4_1
 #define QUANT_R QUANT_R_Q4_1
+#define QUANT_AUXF 2
 #define A_TYPE block_q4_1
 #define A_TYPE_PACKED16 block_q4_1_packed16
 #endif
@@ -97,6 +99,7 @@ struct block_q5_0_packed16
 #if defined(DATA_A_Q5_0)
 #define QUANT_K QUANT_K_Q5_0
 #define QUANT_R QUANT_R_Q5_0
+#define QUANT_AUXF 1
 #define A_TYPE block_q5_0
 #define A_TYPE_PACKED16 block_q5_0_packed16
 #endif
@@ -123,6 +126,7 @@ struct block_q5_1_packed16
 #if defined(DATA_A_Q5_1)
 #define QUANT_K QUANT_K_Q5_1
 #define QUANT_R QUANT_R_Q5_1
+#define QUANT_AUXF 2
 #define A_TYPE block_q5_1
 #define A_TYPE_PACKED16 block_q5_1_packed16
 #endif
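QUANT_AUXF appears to encode how many auxiliary block constants the MMQ path needs from the A matrix: 1 for symmetric formats that only carry a scale d (q4_0, q5_0), 2 for formats that also carry a min m (q4_1, q5_1). The latter path is still marked TODO in mul_mmq.comp; a hedged sketch of the extra term it would need, using the s half of the q8_1 ds pair:

```cpp
#include <cstdint>

// For a (d, m) format, a dequantized A value is d_a * q + m_a, so a 32-wide
// block dotted against q8_1 B values (d_b * qb) expands to
//   d_a * d_b * sum(qa * qb) + m_a * d_b * sum(qb)
//     = d_a * d_b * q_sum    + m_a * s_b
// where s_b is the precomputed "s" component of the q8_1 ds pair.
static float mmq_block_dot_with_min(int32_t q_sum, float d_a, float m_a,
                                    float d_b, float s_b) {
    // s_b = d_b * sum(qb) is computed when src1 is quantized to q8_1.
    return d_a * d_b * (float) q_sum + m_a * s_b;
}
```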

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 1 addition & 1 deletion
@@ -349,7 +349,7 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool
         string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
     }

-    if (!coopmat2 && !coopmat && tname == "q8_0") {
+    if (!coopmat2 && !coopmat && (tname == "q4_0" || tname == "q8_0")) {
         string_to_spv(shader_name + "_" + tname + "_q8_1", "mul_mmq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"},}), fp16, coopmat, coopmat2, f16acc);
     }
 }
