Skip to content

Commit 1202e06

Browse files
committed
metal : add Q8_0 mul_mm kernel
1 parent 61c8259 commit 1202e06

File tree

2 files changed

+5
-1
lines changed

2 files changed

+5
-1
lines changed

ggml-metal.m

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -83,6 +83,7 @@
83 83
GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
84 84
GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
85 85
GGML_METAL_DECL_KERNEL(mul_mm_q4_1_f32);
86 +
GGML_METAL_DECL_KERNEL(mul_mm_q8_0_f32);
86 87
GGML_METAL_DECL_KERNEL(mul_mm_q2_K_f32);
87 88
GGML_METAL_DECL_KERNEL(mul_mm_q3_K_f32);
88 89
GGML_METAL_DECL_KERNEL(mul_mm_q4_K_f32);
@@ -209,6 +210,7 @@ @implementation GGMLMetalClass
209 210
GGML_METAL_ADD_KERNEL(mul_mat_q6_K_f32);
210 211
GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
211 212
GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32);
213 +
GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32);
212 214
GGML_METAL_ADD_KERNEL(mul_mm_q4_1_f32);
213 215
GGML_METAL_ADD_KERNEL(mul_mm_q2_K_f32);
214 216
GGML_METAL_ADD_KERNEL(mul_mm_q3_K_f32);
@@ -751,9 +753,10 @@ void ggml_metal_graph_compute(
751 753
ne00%32 == 0 &&
752 754
ne11 > 1) {
753 755
switch (src0->type) {
754 -
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break;
756 +
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break;
755 757
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break;
756 758
case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break;
759 +
case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q8_0_f32]; break;
757 760
case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q2_K_f32]; break;
758 761
case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q3_K_f32]; break;
759 762
case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_K_f32]; break;

ggml-metal.metal

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2041,6 +2041,7 @@ typedef void (mat_mm_t)(device const uchar *, device const float *, device float
20412041
template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
20422042
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
20432043
template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
2044+
template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2, dequantize_q8_0>;
20442045
template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>;
20452046
template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>;
20462047
template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;

0 commit comments

Comments
 (0)