Skip to content

Commit 6011ac3

Browse files
committed
iq1_m: Metal - dequantize works, dot product does not
1 parent e5360e4 commit 6011ac3

File tree

2 files changed

+252
-6
lines changed

2 files changed

+252
-6
lines changed

ggml-metal.m

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@
6464
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S,
6565
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S,
6666
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S,
67+
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M,
6768
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL,
6869
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS,
6970
GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
@@ -91,6 +92,7 @@
9192
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32,
9293
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32,
9394
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32,
95+
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32,
9496
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32,
9597
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32,
9698
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,
@@ -114,6 +116,7 @@
114116
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32,
115117
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32,
116118
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32,
119+
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32,
117120
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32,
118121
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32,
119122
GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,
@@ -134,6 +137,7 @@
134137
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32,
135138
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32,
136139
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32,
140+
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32,
137141
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,
138142
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32,
139143
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,
@@ -154,6 +158,7 @@
154158
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32,
155159
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32,
156160
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32,
161+
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32,
157162
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32,
158163
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,
159164
GGML_METAL_KERNEL_TYPE_ROPE_F32,
@@ -490,6 +495,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
490495
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S, get_rows_iq3_s, true);
491496
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S, get_rows_iq2_s, true);
492497
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S, get_rows_iq1_s, true);
498+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M, get_rows_iq1_m, true);
493499
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true);
494500
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true);
495501
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
@@ -517,6 +523,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
517523
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32, mul_mv_iq3_s_f32, ctx->support_simdgroup_reduction);
518524
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32, mul_mv_iq2_s_f32, ctx->support_simdgroup_reduction);
519525
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, mul_mv_iq1_s_f32, ctx->support_simdgroup_reduction);
526+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32, mul_mv_iq1_m_f32, ctx->support_simdgroup_reduction);
520527
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, ctx->support_simdgroup_reduction);
521528
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, mul_mv_iq4_xs_f32, ctx->support_simdgroup_reduction);
522529
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction);
@@ -540,6 +547,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
540547
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32, mul_mv_id_iq3_s_f32, ctx->support_simdgroup_reduction);
541548
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32, mul_mv_id_iq2_s_f32, ctx->support_simdgroup_reduction);
542549
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, mul_mv_id_iq1_s_f32, ctx->support_simdgroup_reduction);
550+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32, mul_mv_id_iq1_m_f32, ctx->support_simdgroup_reduction);
543551
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, ctx->support_simdgroup_reduction);
544552
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, ctx->support_simdgroup_reduction);
545553
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm);
@@ -560,6 +568,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
560568
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32, mul_mm_iq3_s_f32, ctx->support_simdgroup_mm);
561569
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32, mul_mm_iq2_s_f32, ctx->support_simdgroup_mm);
562570
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, mul_mm_iq1_s_f32, ctx->support_simdgroup_mm);
571+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, ctx->support_simdgroup_mm);
563572
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, ctx->support_simdgroup_mm);
564573
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, ctx->support_simdgroup_mm);
565574
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm);
@@ -580,6 +589,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
580589
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32, mul_mm_id_iq3_s_f32, ctx->support_simdgroup_mm);
581590
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32, mul_mm_id_iq2_s_f32, ctx->support_simdgroup_mm);
582591
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, ctx->support_simdgroup_mm);
592+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, mul_mm_id_iq1_m_f32, ctx->support_simdgroup_mm);
583593
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, ctx->support_simdgroup_mm);
584594
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm);
585595
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true);
@@ -1421,6 +1431,7 @@ static enum ggml_status ggml_metal_graph_compute(
14211431
case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32 ].pipeline; break;
14221432
case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32 ].pipeline; break;
14231433
case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32 ].pipeline; break;
1434+
case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32 ].pipeline; break;
14241435
case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32 ].pipeline; break;
14251436
case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32 ].pipeline; break;
14261437
default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
@@ -1575,6 +1586,12 @@ static enum ggml_status ggml_metal_graph_compute(
15751586
nth1 = 16;
15761587
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32].pipeline;
15771588
} break;
1589+
case GGML_TYPE_IQ1_M:
1590+
{
1591+
nth0 = 4;
1592+
nth1 = 16;
1593+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32].pipeline;
1594+
} break;
15781595
case GGML_TYPE_IQ4_NL:
15791596
{
15801597
nth0 = 4;
@@ -1619,9 +1636,9 @@ static enum ggml_status ggml_metal_graph_compute(
16191636
[encoder setBytes:&r2 length:sizeof(r2) atIndex:17];
16201637
[encoder setBytes:&r3 length:sizeof(r3) atIndex:18];
16211638

1622-
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
1623-
src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 ||
1624-
src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ2_S) {
1639+
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 ||
1640+
src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K ||
1641+
src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) {
16251642
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
16261643
}
16271644
else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
@@ -1743,6 +1760,7 @@ static enum ggml_status ggml_metal_graph_compute(
17431760
case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32 ].pipeline; break;
17441761
case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32 ].pipeline; break;
17451762
case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32 ].pipeline; break;
1763+
case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32 ].pipeline; break;
17461764
case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32 ].pipeline; break;
17471765
case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32 ].pipeline; break;
17481766
default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
@@ -1900,6 +1918,12 @@ static enum ggml_status ggml_metal_graph_compute(
19001918
nth1 = 16;
19011919
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32].pipeline;
19021920
} break;
1921+
case GGML_TYPE_IQ1_M:
1922+
{
1923+
nth0 = 4;
1924+
nth1 = 16;
1925+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32].pipeline;
1926+
} break;
19031927
case GGML_TYPE_IQ4_NL:
19041928
{
19051929
nth0 = 4;
@@ -1960,9 +1984,9 @@ static enum ggml_status ggml_metal_graph_compute(
19601984
[encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:23 + j];
19611985
}
19621986

1963-
if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1 ||
1964-
src2t == GGML_TYPE_Q5_0 || src2t == GGML_TYPE_Q5_1 || src2t == GGML_TYPE_Q8_0 ||
1965-
src2t == GGML_TYPE_Q2_K || src2t == GGML_TYPE_IQ1_S || src2t == GGML_TYPE_IQ2_S) {
1987+
if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1 || src2t == GGML_TYPE_Q5_0 ||
1988+
src2t == GGML_TYPE_Q5_1 || src2t == GGML_TYPE_Q8_0 || src2t == GGML_TYPE_Q2_K ||
1989+
src2t == GGML_TYPE_IQ1_S || src2t == GGML_TYPE_IQ1_M || src2t == GGML_TYPE_IQ2_S) {
19661990
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
19671991
}
19681992
else if (src2t == GGML_TYPE_IQ2_XXS || src2t == GGML_TYPE_IQ2_XS) {
@@ -2024,6 +2048,7 @@ static enum ggml_status ggml_metal_graph_compute(
20242048
case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S ].pipeline; break;
20252049
case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S ].pipeline; break;
20262050
case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S ].pipeline; break;
2051+
case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M ].pipeline; break;
20272052
case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL ].pipeline; break;
20282053
case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS ].pipeline; break;
20292054
case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32 ].pipeline; break;

0 commit comments

Comments
 (0)