Skip to content

Commit 2893137

Browse files
committed
metal : add kernel_get_rows_i32
ggml-ci
1 parent ab62fc3 commit 2893137

File tree

2 files changed

+33
-0
lines changed

2 files changed

+33
-0
lines changed

ggml-metal.m

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@
8787
GGML_METAL_DECL_KERNEL(get_rows_q4_K);
8888
GGML_METAL_DECL_KERNEL(get_rows_q5_K);
8989
GGML_METAL_DECL_KERNEL(get_rows_q6_K);
90+
GGML_METAL_DECL_KERNEL(get_rows_i32);
9091
GGML_METAL_DECL_KERNEL(rms_norm);
9192
GGML_METAL_DECL_KERNEL(group_norm);
9293
GGML_METAL_DECL_KERNEL(norm);
@@ -377,6 +378,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
377378
GGML_METAL_ADD_KERNEL(get_rows_q4_K);
378379
GGML_METAL_ADD_KERNEL(get_rows_q5_K);
379380
GGML_METAL_ADD_KERNEL(get_rows_q6_K);
381+
GGML_METAL_ADD_KERNEL(get_rows_i32);
380382
GGML_METAL_ADD_KERNEL(rms_norm);
381383
GGML_METAL_ADD_KERNEL(group_norm);
382384
GGML_METAL_ADD_KERNEL(norm);
@@ -499,6 +501,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
499501
GGML_METAL_DEL_KERNEL(get_rows_q4_K);
500502
GGML_METAL_DEL_KERNEL(get_rows_q5_K);
501503
GGML_METAL_DEL_KERNEL(get_rows_q6_K);
504+
GGML_METAL_DEL_KERNEL(get_rows_i32);
502505
GGML_METAL_DEL_KERNEL(rms_norm);
503506
GGML_METAL_DEL_KERNEL(group_norm);
504507
GGML_METAL_DEL_KERNEL(norm);
@@ -1978,6 +1981,7 @@ void ggml_metal_graph_compute(
19781981
case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_K]; break;
19791982
case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_K]; break;
19801983
case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_K]; break;
1984+
case GGML_TYPE_I32: [encoder setComputePipelineState:ctx->pipeline_get_rows_i32]; break;
19811985
default: GGML_ASSERT(false && "not implemented");
19821986
}
19831987

ggml-metal.metal

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3829,6 +3829,35 @@ kernel void kernel_get_rows_f16(
38293829
}
38303830
}
38313831

3832+
kernel void kernel_get_rows_i32(
3833+
device const void * src0,
3834+
device const char * src1,
3835+
device int32_t * dst,
3836+
constant int64_t & ne00,
3837+
constant uint64_t & nb01,
3838+
constant uint64_t & nb02,
3839+
constant int64_t & ne10,
3840+
constant uint64_t & nb10,
3841+
constant uint64_t & nb11,
3842+
constant uint64_t & nb1,
3843+
constant uint64_t & nb2,
3844+
uint3 tgpig[[threadgroup_position_in_grid]],
3845+
uint tiitg[[thread_index_in_threadgroup]],
3846+
uint3 tptg [[threads_per_threadgroup]]) {
3847+
const int64_t i10 = tgpig.x;
3848+
const int64_t i11 = tgpig.y;
3849+
3850+
const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
3851+
3852+
const int64_t i02 = i11;
3853+
3854+
for (int ind = tiitg; ind < ne00; ind += tptg.x) {
3855+
((device int32_t *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
3856+
((device int32_t *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
3857+
}
3858+
}
3859+
3860+
38323861
#define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
38333862
#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B
38343863
#define BLOCK_SIZE_K 32

0 commit comments

Comments
 (0)