Skip to content

Commit 3855622

Browse files
committed
cont : mul mm id
ggml-ci
1 parent eea1f7e commit 3855622

File tree

3 files changed

+119
-95
lines changed

3 files changed

+119
-95
lines changed

ggml/src/ggml-common.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -509,6 +509,24 @@ typedef struct {
509509
int16_t r3;
510510
} ggml_metal_kargs_mul_mv;
511511

512+
typedef struct {
513+
int32_t nei0;
514+
int32_t nei1;
515+
uint64_t nbi1;
516+
int32_t ne00;
517+
int32_t ne02;
518+
uint64_t nb01;
519+
uint64_t nb02;
520+
int32_t ne11;
521+
int32_t ne12;
522+
int32_t ne13;
523+
uint64_t nb10;
524+
uint64_t nb11;
525+
uint64_t nb12;
526+
int32_t ne0;
527+
int32_t ne1;
528+
} ggml_metal_kargs_mul_mm_id;
529+
512530
typedef struct {
513531
int32_t nei0;
514532
int32_t nei1;

ggml/src/ggml-metal.m

Lines changed: 23 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2297,27 +2297,30 @@ static void ggml_metal_encode_node(
22972297
default: GGML_ABORT("MUL_MAT_ID not implemented");
22982298
}
22992299

2300+
ggml_metal_kargs_mul_mm_id args = {
2301+
/*.nei0 =*/ ne20,
2302+
/*.nei1 =*/ ne21,
2303+
/*.nbi1 =*/ nb21,
2304+
/*.ne00 =*/ ne00,
2305+
/*.ne02 =*/ ne02,
2306+
/*.nb01 =*/ nb01,
2307+
/*.nb02 =*/ nb02,
2308+
/*.ne11 =*/ ne11,
2309+
/*.ne12 =*/ ne12,
2310+
/*.ne13 =*/ ne13,
2311+
/*.nb10 =*/ nb10,
2312+
/*.nb11 =*/ nb11,
2313+
/*.nb12 =*/ nb12,
2314+
/*.ne0 =*/ ne0,
2315+
/*.ne1 =*/ ne1,
2316+
};
2317+
23002318
[encoder setComputePipelineState:pipeline];
2301-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2302-
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
2303-
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
2304-
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
2305-
[encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
2306-
[encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
2307-
[encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
2308-
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:7];
2309-
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:8];
2310-
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:9];
2311-
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:10];
2312-
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
2313-
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
2314-
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
2315-
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
2316-
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
2317-
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
2318-
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17];
2319-
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:18];
2320-
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:19];
2319+
[encoder setBytes:&args length:sizeof(args) atIndex:0];
2320+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
2321+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
2322+
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
2323+
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:4];
23212324

23222325
[encoder setThreadgroupMemoryLength:GGML_PAD(8192 + dst_rows*4/*sizeof(ushort2)*/, 16) atIndex:0];
23232326

ggml/src/ggml-metal.metal

Lines changed: 78 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -5755,31 +5755,32 @@ kernel void kernel_mul_mm(
57555755
}
57565756

57575757
// same as kernel_mul_mm_impl, but src1 and dst are accessed via indices stored in rowids
5758+
// TODO: this kernel needs to be reimplemented from scratch for better performance
57585759
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
57595760
void kernel_mul_mm_id_impl(
5760-
device const uchar * src0,
5761-
device const uchar * src1,
5761+
int32_t ne00,
5762+
int32_t ne02,
5763+
uint64_t nb01,
5764+
uint64_t nb02,
5765+
int32_t ne11,
5766+
int32_t ne12,
5767+
uint64_t nb10,
5768+
uint64_t nb11,
5769+
uint64_t nb12,
5770+
int32_t ne0,
5771+
int32_t ne1,
5772+
int64_t ne0ne1,
5773+
device const char * src0,
5774+
device const char * src1,
57625775
threadgroup ushort2 * rowids,
5763-
device float * dst,
5764-
constant int64_t & ne00,
5765-
constant int64_t & ne02,
5766-
constant uint64_t & nb01,
5767-
constant uint64_t & nb02,
5768-
constant int64_t & ne11,
5769-
constant int64_t & ne12,
5770-
constant uint64_t & nb10,
5771-
constant uint64_t & nb11,
5772-
constant uint64_t & nb12,
5773-
constant int64_t & ne0,
5774-
int64_t ne1,
5775-
int64_t ne0ne1,
5776-
threadgroup uchar * shared_memory,
5777-
uint3 tgpig[[threadgroup_position_in_grid]],
5778-
uint tiitg[[thread_index_in_threadgroup]],
5779-
uint sgitg[[simdgroup_index_in_threadgroup]]) {
5780-
5781-
threadgroup half * sa = (threadgroup half *)(shared_memory);
5782-
threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
5776+
device char * dst,
5777+
threadgroup char * shmem,
5778+
uint3 tgpig[[threadgroup_position_in_grid]],
5779+
ushort tiitg[[thread_index_in_threadgroup]],
5780+
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
5781+
5782+
threadgroup half * sa = (threadgroup half *)(shmem);
5783+
threadgroup float * sb = (threadgroup float *)(shmem + 4096);
57835784

57845785
const uint r0 = tgpig.y;
57855786
const uint r1 = tgpig.x;
@@ -5796,9 +5797,9 @@ void kernel_mul_mm_id_impl(
57965797

57975798
simdgroup_half8x8 ma[4];
57985799
simdgroup_float8x8 mb[2];
5799-
simdgroup_float8x8 c_res[8];
5800+
simdgroup_float8x8 mc[8];
58005801
for (int i = 0; i < 8; i++){
5801-
c_res[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
5802+
mc[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
58025803
}
58035804
short il = (tiitg % THREAD_PER_ROW);
58045805

@@ -5836,41 +5837,57 @@ void kernel_mul_mm_id_impl(
58365837
threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
58375838
threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
58385839

5840+
#pragma unroll(BLOCK_SIZE_K/8)
58395841
for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
5842+
#pragma unroll(4)
58405843
for (int i = 0; i < 4; i++) {
58415844
simdgroup_load(ma[i], lsma + SG_MAT_SIZE * i);
58425845
}
58435846
simdgroup_barrier(mem_flags::mem_none);
5847+
#pragma unroll(2)
58445848
for (int i = 0; i < 2; i++) {
58455849
simdgroup_load(mb[i], lsmb + SG_MAT_SIZE * i);
58465850
}
58475851

58485852
lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
58495853
lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;
58505854

5855+
#pragma unroll(8)
58515856
for (int i = 0; i < 8; i++){
5852-
simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]);
5857+
simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]);
58535858
}
58545859
}
58555860
}
58565861

58575862
{
58585863
threadgroup_barrier(mem_flags::mem_threadgroup);
5859-
threadgroup float * temp_str = ((threadgroup float *)shared_memory) \
5864+
threadgroup float * temp_str = ((threadgroup float *) shmem) \
58605865
+ 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
58615866
for (int i = 0; i < 8; i++) {
5862-
simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
5867+
simdgroup_store(mc[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
58635868
}
58645869

58655870
threadgroup_barrier(mem_flags::mem_threadgroup);
58665871

5867-
device float * C = dst + (BLOCK_SIZE_M * r0);
58685872
if (sgitg == 0) {
58695873
for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
58705874
threadgroup const auto & jid = rowids[r1 * BLOCK_SIZE_N + j];
5871-
int joff = jid[0] * ne0 + jid[1] * ne0ne1;
5872-
for (int i = 0; i < n_rows; i++) {
5873-
*(C + i + joff) = *(temp_str + i + j * BLOCK_SIZE_M);
5875+
int64_t joff = jid[0]*ne0 + jid[1]*ne0ne1;
5876+
5877+
device float * D = (device float *) dst + (r0*BLOCK_SIZE_M) + joff;
5878+
device float4 * D4 = (device float4 *) D;
5879+
5880+
threadgroup float * C = temp_str + (j*BLOCK_SIZE_M);
5881+
threadgroup float4 * C4 = (threadgroup float4 *) C;
5882+
5883+
int i = 0;
5884+
for (; i < n_rows/4; i++) {
5885+
*(D4 + i) = *(C4 + i);
5886+
}
5887+
5888+
i *= 4;
5889+
for (; i < n_rows; i++) {
5890+
*(D + i) = *(C + i);
58745891
}
58755892
}
58765893
}
@@ -5879,48 +5896,34 @@ void kernel_mul_mm_id_impl(
58795896

58805897
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
58815898
kernel void kernel_mul_mm_id(
5882-
device const uchar * src0s,
5883-
device const uchar * src1,
5884-
device float * dst,
5885-
device const uchar * ids,
5886-
constant int64_t & nei0,
5887-
constant int64_t & nei1,
5888-
constant uint64_t & nbi1,
5889-
constant int64_t & ne00,
5890-
constant int64_t & ne02,
5891-
constant uint64_t & nb01,
5892-
constant uint64_t & nb02,
5893-
constant int64_t & ne11,
5894-
constant int64_t & ne12,
5895-
constant int64_t & ne13,
5896-
constant uint64_t & nb10,
5897-
constant uint64_t & nb11,
5898-
constant uint64_t & nb12,
5899-
constant int64_t & ne0,
5900-
constant int64_t & ne1,
5901-
constant uint64_t & nb1,
5902-
threadgroup uchar * shared_memory [[threadgroup(0)]],
5903-
uint3 tgpig[[threadgroup_position_in_grid]],
5904-
uint tiitg[[thread_index_in_threadgroup]],
5905-
uint sgitg[[simdgroup_index_in_threadgroup]]) {
5899+
constant ggml_metal_kargs_mul_mm_id & args,
5900+
device const char * src0s,
5901+
device const char * src1,
5902+
device char * dst,
5903+
device const char * ids,
5904+
threadgroup char * shmem [[threadgroup(0)]],
5905+
uint3 tgpig[[threadgroup_position_in_grid]],
5906+
ushort tiitg[[thread_index_in_threadgroup]],
5907+
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
59065908

59075909
const int32_t i02 = tgpig.z;
5910+
59085911
tgpig.z = 0;
59095912

5910-
device const uchar * src0 = src0s + i02*nb02;
5913+
device const char * src0 = src0s + i02*args.nb02;
59115914

59125915
// row indices
5913-
threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shared_memory + 8192);
5916+
threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shmem + 8192);
59145917

59155918
// TODO: parallelize this loop
59165919
int64_t _ne1 = 0;
5917-
for (ushort ii1 = 0; ii1 < nei1; ii1++) {
5918-
for (ushort ii0 = 0; ii0 < nei0; ii0++) {
5919-
int32_t id = ((device int32_t *) (ids + ii1*nbi1))[ii0];
5920+
for (ushort ii1 = 0; ii1 < args.nei1; ii1++) {
5921+
for (ushort ii0 = 0; ii0 < args.nei0; ii0++) {
5922+
int32_t id = ((device int32_t *) (ids + ii1*args.nbi1))[ii0];
59205923
if (id == i02) {
5921-
//if (tiitg == 0) {
5924+
if (tiitg == 0) {
59225925
rowids[_ne1] = ushort2(ii0, ii1);
5923-
//}
5926+
}
59245927
_ne1++;
59255928
}
59265929
}
@@ -5929,23 +5932,23 @@ kernel void kernel_mul_mm_id(
59295932
threadgroup_barrier(mem_flags::mem_threadgroup);
59305933

59315934
kernel_mul_mm_id_impl<block_q, nl, dequantize_func>(
5935+
args.ne00,
5936+
args.ne02,
5937+
args.nb01,
5938+
args.nb02,
5939+
args.ne11,
5940+
args.ne12,
5941+
args.nb10,
5942+
args.nb11,
5943+
args.nb12,
5944+
args.ne0,
5945+
_ne1,
5946+
(int64_t)args.ne0*args.ne1,
59325947
src0,
59335948
src1,
59345949
rowids,
59355950
dst,
5936-
ne00,
5937-
ne02,
5938-
nb01,
5939-
nb02,
5940-
ne11,
5941-
ne12,
5942-
nb10,
5943-
nb11,
5944-
nb12,
5945-
ne0,
5946-
_ne1,
5947-
ne0*ne1,
5948-
shared_memory,
5951+
shmem,
59495952
tgpig,
59505953
tiitg,
59515954
sgitg);

0 commit comments

Comments
 (0)