Skip to content

Commit dccf084

Browse files
committed
Vulkan: Add DP4A MMQ and Q8_1 quantization shader
1 parent 2cc4a5e commit dccf084

File tree

9 files changed

+824
-49
lines changed

9 files changed

+824
-49
lines changed

ggml/src/ggml-quants.c

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2020,6 +2020,13 @@ size_t quantize_q8_0(const float * restrict src, void * restrict dst, int64_t nr
20202020
return nrow * row_size;
20212021
}
20222022

2023+
size_t quantize_q8_1(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
2024+
(void)quant_weights; // not used
2025+
const size_t row_size = ggml_row_size(GGML_TYPE_Q8_1, n_per_row);
2026+
quantize_row_q8_1_ref(src, dst, (int64_t)nrow*n_per_row);
2027+
return nrow * row_size;
2028+
}
2029+
20232030
// ====================== Ternary (de)-quantization (BitNet b1.58 and TriLMs)
20242031

20252032
void quantize_row_tq1_0_ref(const float * restrict x, block_tq1_0 * restrict y, int64_t k) {

ggml/src/ggml-quants.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ GGML_API size_t quantize_q4_1(const float * GGML_RESTRICT src, void * GGML_RESTR
8989
GGML_API size_t quantize_q5_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
9090
GGML_API size_t quantize_q5_1(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
9191
GGML_API size_t quantize_q8_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
92+
GGML_API size_t quantize_q8_1(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
9293

9394
GGML_API void iq2xs_init_impl(enum ggml_type type);
9495
GGML_API void iq2xs_free_impl(enum ggml_type type);

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 289 additions & 35 deletions
Large diffs are not rendered by default.

ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,7 @@ void main() {
212212
#else
213213
ACC_TYPE sums[WMITER * TM * WNITER * TN];
214214
FLOAT_TYPE cache_a[WMITER * TM];
215-
FLOAT_TYPE cache_b[WNITER * TN];
215+
FLOAT_TYPE cache_b[TN];
216216

217217
[[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
218218
sums[i] = ACC_TYPE(0.0f);
@@ -744,16 +744,14 @@ void main() {
744744
}
745745
[[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
746746
[[unroll]] for (uint j = 0; j < TN; j++) {
747-
cache_b[wsic * TN + j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * SHMEM_STRIDE + i];
747+
cache_b[j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * SHMEM_STRIDE + i];
748748
}
749-
}
750749

751-
[[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
752750
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
753751
[[unroll]] for (uint cc = 0; cc < TN; cc++) {
754752
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
755753
const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr;
756-
sums[sums_idx] = fma(ACC_TYPE(cache_a[wsir * TM + cr]), ACC_TYPE(cache_b[wsic * TN + cc]), sums[sums_idx]);
754+
sums[sums_idx] = fma(ACC_TYPE(cache_a[wsir * TM + cr]), ACC_TYPE(cache_b[cc]), sums[sums_idx]);
757755
}
758756
}
759757
}

0 commit comments

Comments
 (0)