Skip to content

Commit 5660976

Browse files
committed
Improve performance with better q4_k and q5_k dequant and store unrolling
1 parent 6805303 commit 5660976

File tree

2 files changed

+34
-23
lines changed

2 files changed

+34
-23
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1842,7 +1842,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
18421842
if ((vk::ComponentTypeKHR)prop.AType == vk::ComponentTypeKHR::eFloat16 &&
18431843
(vk::ComponentTypeKHR)prop.BType == vk::ComponentTypeKHR::eFloat16 &&
18441844
(vk::ComponentTypeKHR)prop.CType == vk::ComponentTypeKHR::eFloat32 &&
1845-
(vk::ComponentTypeKHR)prop.ResultType == vk::ComponentTypeKHR::eFloat32
1845+
(vk::ComponentTypeKHR)prop.ResultType == vk::ComponentTypeKHR::eFloat32 &&
1846+
(vk::ScopeKHR)prop.scope == vk::ScopeKHR::eSubgroup
18461847
) {
18471848
device->coop_mat_m = prop.MSize;
18481849
device->coop_mat_n = prop.NSize;

ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp

Lines changed: 32 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -365,15 +365,20 @@ void main() {
365365

366366
const vec2 loadd = vec2(data_a[ib].d);
367367

368-
uint8_t sc;
369-
uint8_t mbyte;
370-
if (is < 4) {
371-
sc = uint8_t(data_a[ib].scales[is ] & 63);
372-
mbyte = uint8_t(data_a[ib].scales[is + 4] & 63);
373-
} else {
374-
sc = uint8_t((data_a[ib].scales[is + 4] & 0xF) | ((data_a[ib].scales[is - 4] >> 6) << 4));
375-
mbyte = uint8_t((data_a[ib].scales[is + 4] >> 4) | ((data_a[ib].scales[is ] >> 6) << 4));
376-
}
368+
const uint scidx0 = (is < 4) ? is : (is + 4);
369+
const uint scidx1 = (is < 4) ? is : (is - 4);
370+
const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0;
371+
const uint scidxshift1 = (is < 4) ? 0 : 2;
372+
const uint mbidx0 = is + 4;
373+
const uint mbidx1 = (is < 4) ? is + 4 : is;
374+
const uint mbidxmask0 = (is < 4) ? 0xF : 0xF0;
375+
const uint mbidxshift0 = (is < 4) ? 0 : 4;
376+
const uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0;
377+
const uint mbidxshift1 = (is < 4) ? 0 : 2;
378+
379+
const uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1));
380+
const uint8_t mbyte = uint8_t((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1));
381+
377382
const float d = loadd.x * sc;
378383
const float m = -loadd.y * mbyte;
379384

@@ -396,15 +401,20 @@ void main() {
396401

397402
const vec2 loadd = vec2(data_a[ib].d);
398403

399-
uint8_t sc;
400-
uint8_t mbyte;
401-
if (is < 4) {
402-
sc = uint8_t(data_a[ib].scales[is ] & 63);
403-
mbyte = uint8_t(data_a[ib].scales[is + 4] & 63);
404-
} else {
405-
sc = uint8_t((data_a[ib].scales[is + 4] & 0xF) | ((data_a[ib].scales[is - 4] >> 6) << 4));
406-
mbyte = uint8_t((data_a[ib].scales[is + 4] >> 4) | ((data_a[ib].scales[is ] >> 6) << 4));
407-
}
404+
const uint scidx0 = (is < 4) ? is : (is + 4);
405+
const uint scidx1 = (is < 4) ? is : (is - 4);
406+
const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0;
407+
const uint scidxshift1 = (is < 4) ? 0 : 2;
408+
const uint mbidx0 = is + 4;
409+
const uint mbidx1 = (is < 4) ? is + 4 : is;
410+
const uint mbidxmask0 = (is < 4) ? 0xF : 0xF0;
411+
const uint mbidxshift0 = (is < 4) ? 0 : 4;
412+
const uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0;
413+
const uint mbidxshift1 = (is < 4) ? 0 : 2;
414+
415+
const uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1));
416+
const uint8_t mbyte = uint8_t(((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1));
417+
408418
const float d = loadd.x * sc;
409419
const float m = -loadd.y * mbyte;
410420

@@ -547,8 +557,8 @@ void main() {
547557

548558
#ifdef COOPMAT
549559
#ifdef MUL_MAT_ID
550-
for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
551-
for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
560+
[[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
561+
[[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
552562
coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
553563

554564
[[unroll]] for (uint col = 0; col < BN; col += storestride) {
@@ -564,8 +574,8 @@ void main() {
564574
#else
565575
const bool is_aligned = p.stride_d % 4 == 0; // Assumption: D_TYPE == float
566576

567-
for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
568-
for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
577+
[[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
578+
[[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
569579
const bool is_in_bounds = dr + (cm_row + 1) * TM <= p.M && dc + (cm_col + 1) * TN <= p.N;
570580

571581
if (is_aligned && is_in_bounds) {

0 commit comments

Comments
 (0)