Skip to content

Metal TQ2_0 #12485

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 27 additions & 2 deletions ggml/src/ggml-metal/ggml-metal.m
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K,
GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K,
GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K,
GGML_METAL_KERNEL_TYPE_GET_ROWS_TQ2_0,
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS,
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS,
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS,
Expand Down Expand Up @@ -250,6 +251,7 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_TQ2_0_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32,
Expand All @@ -275,6 +277,7 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_TQ2_0_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32,
Expand All @@ -297,6 +300,7 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_TQ2_0_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32,
Expand All @@ -319,6 +323,7 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_TQ2_0_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32,
Expand Down Expand Up @@ -802,6 +807,7 @@ @implementation GGMLMetalClass
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K, get_rows_q4_K, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K, get_rows_q5_K, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K, get_rows_q6_K, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_TQ2_0, get_rows_tq2_0, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS, get_rows_iq2_xxs, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, get_rows_iq2_xs, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS, get_rows_iq3_xxs, true);
Expand Down Expand Up @@ -904,6 +910,7 @@ @implementation GGMLMetalClass
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32, mul_mv_id_q4_K_f32, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32, mul_mv_id_q5_K_f32, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32, mul_mv_id_q6_K_f32, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_TQ2_0_F32, mul_mv_tq2_0_f32, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, mul_mv_id_iq3_xxs_f32, has_simdgroup_reduction);
Expand All @@ -926,6 +933,7 @@ @implementation GGMLMetalClass
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, mul_mm_q4_K_f32, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32, mul_mm_q5_K_f32, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32, mul_mm_q6_K_f32, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_TQ2_0_F32, mul_mm_tq2_0_f32, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, mul_mm_iq3_xxs_f32, has_simdgroup_mm);
Expand Down Expand Up @@ -2523,6 +2531,7 @@ static void ggml_metal_encode_node(
case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32 ].pipeline; break;
case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32 ].pipeline; break;
case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32 ].pipeline; break;
case GGML_TYPE_TQ2_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_TQ2_0_F32 ].pipeline; break;
case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32].pipeline; break;
case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32 ].pipeline; break;
case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32].pipeline; break;
Expand Down Expand Up @@ -2674,6 +2683,12 @@ static void ggml_metal_encode_node(
nth1 = 32;
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32].pipeline;
} break;
case GGML_TYPE_TQ2_0:
{
nth0 = 4;
nth1 = 16;
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_TQ2_0_F32].pipeline;
} break;
case GGML_TYPE_IQ2_XXS:
{
nth0 = 4;
Expand Down Expand Up @@ -2764,7 +2779,8 @@ static void ggml_metal_encode_node(

if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 ||
src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K ||
src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) {
src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S||
src0t == GGML_TYPE_TQ2_0) {
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
Expand Down Expand Up @@ -2861,6 +2877,7 @@ static void ggml_metal_encode_node(
case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32 ].pipeline; break;
case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32 ].pipeline; break;
case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32 ].pipeline; break;
case GGML_TYPE_TQ2_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_TQ2_0_F32 ].pipeline; break;
case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32].pipeline; break;
case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32 ].pipeline; break;
case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32].pipeline; break;
Expand Down Expand Up @@ -2990,6 +3007,12 @@ static void ggml_metal_encode_node(
nth1 = 32;
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32].pipeline;
} break;
case GGML_TYPE_TQ2_0:
{
nth0 = 4;
nth1 = 16;
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_TQ2_0_F32].pipeline;
} break;
case GGML_TYPE_IQ2_XXS:
{
nth0 = 4;
Expand Down Expand Up @@ -3089,7 +3112,8 @@ static void ggml_metal_encode_node(

if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 ||
src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K ||
src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) {
src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S||
src0t == GGML_TYPE_TQ2_0) {
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
Expand Down Expand Up @@ -3142,6 +3166,7 @@ static void ggml_metal_encode_node(
case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K ].pipeline; break;
case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K ].pipeline; break;
case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K ].pipeline; break;
case GGML_TYPE_TQ2_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_TQ2_0 ].pipeline; break;
case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS].pipeline; break;
case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS ].pipeline; break;
case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS].pipeline; break;
Expand Down
123 changes: 123 additions & 0 deletions ggml/src/ggml-metal/ggml-metal.metal
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,21 @@ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg
}
}

template <typename type4x4>
void dequantize_tq2_0(device const block_tq2_0 * xb, short il, thread type4x4 & reg) {
const half d = xb->d;
device const uint8_t * q = (device const uint8_t *)xb->qs + 32*(il/8) + 16*(il&1);

il = (il/2)%4;

half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
uchar mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
const half dl = d * coef;
for (int i = 0; i < 16; ++i) {
reg[i/4][i%4] = dl * (q[i] & mask) - d;
}
}

template <typename type4x4>
void dequantize_iq2_xxs(device const block_iq2_xxs * xb, short il, thread type4x4 & reg) {
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
Expand Down Expand Up @@ -5041,6 +5056,110 @@ kernel void kernel_mul_mv_q6_K_f32(
kernel_mul_mv_q6_K_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
}

// ======================= Ternary

template<typename args_t>
void kernel_mul_mv_tq2_0_f32_impl(
args_t args,
device const char * src0,
device const char * src1,
device char * dst,
threadgroup char * shmem,
uint3 tgpig,
ushort tiisg,
ushort sgitg) {

const int nb = args.ne00/QK_K;
const int r0 = tgpig.x;
const int r1 = tgpig.y;
const int im = tgpig.z;

const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;

const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;

const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;

device const block_tq2_0 * x = (device const block_tq2_0 *) (src0 + offset0);
device const float * y = (device const float *) (src1 + offset1);

float yl[32];
float sumf[N_DST]={0.f}, all_sum;

const int step = sizeof(block_tq2_0) * nb / 2;

const int ix = tiisg/8; // 0...3
const int it = tiisg%8; // 0...7
const int iq = it/4; // 0 or 1
const int ir = it%4; // 0...3

device const float * y4 = y + ix * QK_K + 128 * iq + 8 * ir;

for (int ib = ix; ib < nb; ib += 4) {

float sumy = 0.f;
for (int i = 0; i < 8; ++i) {
yl[i+ 0] = y4[i+ 0]; sumy += yl[i+ 0];
yl[i+ 8] = y4[i+32]; sumy += yl[i+ 8];
yl[i+16] = y4[i+64]; sumy += yl[i+16];
yl[i+24] = y4[i+96]; sumy += yl[i+24];
}

device const half * dh = &x[ib].d;
device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir;

for (int row = 0; row < N_DST; row++) {

float4 acc1 = {0.f, 0.f, 0.f, 0.f};
float4 acc2 = {0.f, 0.f, 0.f, 0.f};
for (int i = 0; i < 8; i += 2) {
acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003);
acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300);
acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c);
acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00);
acc1[2] += yl[i+16] * (qs[i/2] & 0x0030);
acc2[2] += yl[i+17] * (qs[i/2] & 0x3000);
acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0);
acc2[3] += yl[i+25] * (qs[i/2] & 0xc000);
}
float dall = dh[0];
sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * 1.f/ 1.f +
(acc1[1] + 1.f/256.f * acc2[1]) * 1.f/ 4.f +
(acc1[2] + 1.f/256.f * acc2[2]) * 1.f/16.f +
(acc1[3] + 1.f/256.f * acc2[3]) * 1.f/64.f - sumy);

qs += step;
dh += step;
}

y4 += 4 * QK_K;
}

device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;

for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
all_sum = simd_sum(sumf[row]);
if (tiisg == 0) {
dst_f32[first_row + row] = all_sum;
}
}
}

[[host_name("kernel_mul_mv_tq2_0_f32")]]
kernel void kernel_mul_mv_tq2_0_f32(
constant ggml_metal_kargs_mul_mv & args,
device const char * src0,
device const char * src1,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {

kernel_mul_mv_tq2_0_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
}

// ======================= "True" 2-bit

template<typename args_t>
Expand Down Expand Up @@ -6466,6 +6585,7 @@ template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_q_t kernel_get
template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_K, QK_NL, dequantize_q4_K>;
template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_K, QK_NL, dequantize_q5_K>;
template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q6_K, QK_NL, dequantize_q6_K>;
template [[host_name("kernel_get_rows_tq2_0")]] kernel get_rows_q_t kernel_get_rows_q<block_tq2_0, QK_NL, dequantize_tq2_0>;
template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
template [[host_name("kernel_get_rows_iq2_xs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
Expand Down Expand Up @@ -6497,6 +6617,7 @@ template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_m
template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K>;
template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K>;
template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K>;
template [[host_name("kernel_mul_mm_tq2_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_tq2_0, QK_NL, dequantize_tq2_0>;
template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs>;
template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
Expand Down Expand Up @@ -6528,6 +6649,7 @@ template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel
template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_K, QK_NL, dequantize_q4_K>;
template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_K, QK_NL, dequantize_q5_K>;
template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q6_K, QK_NL, dequantize_q6_K>;
template [[host_name("kernel_mul_mm_id_tq2_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_tq2_0, QK_NL, dequantize_tq2_0>;
template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
Expand Down Expand Up @@ -6670,6 +6792,7 @@ template [[host_name("kernel_mul_mv_id_q3_K_f32")]] kernel kernel_mul_mv_id_t
template [[host_name("kernel_mul_mv_id_q4_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q4_K_f32_impl>>;
template [[host_name("kernel_mul_mv_id_q5_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q5_K_f32_impl>>;
template [[host_name("kernel_mul_mv_id_q6_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q6_K_f32_impl>>;
template [[host_name("kernel_mul_mv_id_tq2_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_tq2_0_f32_impl>>;
template [[host_name("kernel_mul_mv_id_iq1_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_s_f32_impl>>;
template [[host_name("kernel_mul_mv_id_iq1_m_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_m_f32_impl>>;
template [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xxs_f32_impl>>;
Expand Down