@@ -5755,31 +5755,32 @@ kernel void kernel_mul_mm(
5755
5755
}
5756
5756
5757
5757
// same as kernel_mul_mm_impl, but src1 and dst are accessed via indices stored in rowids
5758
+ // TODO: this kernel needs to be reimplemented from scratch for better performance
5758
5759
template <typename block_q, short nl, void (*dequantize_func)(device const block_q *, short , thread half4x4 &)>
5759
5760
void kernel_mul_mm_id_impl (
5760
- device const uchar * src0,
5761
- device const uchar * src1,
5761
+ int32_t ne00,
5762
+ int32_t ne02,
5763
+ uint64_t nb01,
5764
+ uint64_t nb02,
5765
+ int32_t ne11,
5766
+ int32_t ne12,
5767
+ uint64_t nb10,
5768
+ uint64_t nb11,
5769
+ uint64_t nb12,
5770
+ int32_t ne0,
5771
+ int32_t ne1,
5772
+ int64_t ne0ne1,
5773
+ device const char * src0,
5774
+ device const char * src1,
5762
5775
threadgroup ushort2 * rowids,
5763
- device float * dst,
5764
- constant int64_t & ne00,
5765
- constant int64_t & ne02,
5766
- constant uint64_t & nb01,
5767
- constant uint64_t & nb02,
5768
- constant int64_t & ne11,
5769
- constant int64_t & ne12,
5770
- constant uint64_t & nb10,
5771
- constant uint64_t & nb11,
5772
- constant uint64_t & nb12,
5773
- constant int64_t & ne0,
5774
- int64_t ne1,
5775
- int64_t ne0ne1,
5776
- threadgroup uchar * shared_memory,
5777
- uint3 tgpig[[threadgroup_position_in_grid]],
5778
- uint tiitg[[thread_index_in_threadgroup]],
5779
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
5780
-
5781
- threadgroup half * sa = (threadgroup half *)(shared_memory);
5782
- threadgroup float * sb = (threadgroup float *)(shared_memory + 4096 );
5776
+ device char * dst,
5777
+ threadgroup char * shmem,
5778
+ uint3 tgpig[[threadgroup_position_in_grid]],
5779
+ ushort tiitg[[thread_index_in_threadgroup]],
5780
+ ushort sgitg[[simdgroup_index_in_threadgroup]]) {
5781
+
5782
+ threadgroup half * sa = (threadgroup half *)(shmem);
5783
+ threadgroup float * sb = (threadgroup float *)(shmem + 4096 );
5783
5784
5784
5785
const uint r0 = tgpig.y ;
5785
5786
const uint r1 = tgpig.x ;
@@ -5796,9 +5797,9 @@ void kernel_mul_mm_id_impl(
5796
5797
5797
5798
simdgroup_half8x8 ma[4 ];
5798
5799
simdgroup_float8x8 mb[2 ];
5799
- simdgroup_float8x8 c_res [8 ];
5800
+ simdgroup_float8x8 mc [8 ];
5800
5801
for (int i = 0 ; i < 8 ; i++){
5801
- c_res [i] = make_filled_simdgroup_matrix<float , 8 >(0 .f );
5802
+ mc [i] = make_filled_simdgroup_matrix<float , 8 >(0 .f );
5802
5803
}
5803
5804
short il = (tiitg % THREAD_PER_ROW);
5804
5805
@@ -5836,41 +5837,57 @@ void kernel_mul_mm_id_impl(
5836
5837
threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2 ));
5837
5838
threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2 ));
5838
5839
5840
+ #pragma unroll(BLOCK_SIZE_K/8)
5839
5841
for (int ik = 0 ; ik < BLOCK_SIZE_K / 8 ; ik++) {
5842
+ #pragma unroll(4)
5840
5843
for (int i = 0 ; i < 4 ; i++) {
5841
5844
simdgroup_load (ma[i], lsma + SG_MAT_SIZE * i);
5842
5845
}
5843
5846
simdgroup_barrier (mem_flags::mem_none);
5847
+ #pragma unroll(2)
5844
5848
for (int i = 0 ; i < 2 ; i++) {
5845
5849
simdgroup_load (mb[i], lsmb + SG_MAT_SIZE * i);
5846
5850
}
5847
5851
5848
5852
lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
5849
5853
lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;
5850
5854
5855
+ #pragma unroll(8)
5851
5856
for (int i = 0 ; i < 8 ; i++){
5852
- simdgroup_multiply_accumulate (c_res [i], mb[i/4 ], ma[i%4 ], c_res [i]);
5857
+ simdgroup_multiply_accumulate (mc [i], mb[i/4 ], ma[i%4 ], mc [i]);
5853
5858
}
5854
5859
}
5855
5860
}
5856
5861
5857
5862
{
5858
5863
threadgroup_barrier (mem_flags::mem_threadgroup);
5859
- threadgroup float * temp_str = ((threadgroup float *)shared_memory ) \
5864
+ threadgroup float * temp_str = ((threadgroup float *) shmem ) \
5860
5865
+ 32 * (sgitg&1 ) + (16 * (sgitg>>1 )) * BLOCK_SIZE_M;
5861
5866
for (int i = 0 ; i < 8 ; i++) {
5862
- simdgroup_store (c_res [i], temp_str + 8 * (i%4 ) + 8 * BLOCK_SIZE_M * (i/4 ), BLOCK_SIZE_M);
5867
+ simdgroup_store (mc [i], temp_str + 8 * (i%4 ) + 8 * BLOCK_SIZE_M * (i/4 ), BLOCK_SIZE_M);
5863
5868
}
5864
5869
5865
5870
threadgroup_barrier (mem_flags::mem_threadgroup);
5866
5871
5867
- device float * C = dst + (BLOCK_SIZE_M * r0);
5868
5872
if (sgitg == 0 ) {
5869
5873
for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
5870
5874
threadgroup const auto & jid = rowids[r1 * BLOCK_SIZE_N + j];
5871
- int joff = jid[0 ] * ne0 + jid[1 ] * ne0ne1;
5872
- for (int i = 0 ; i < n_rows; i++) {
5873
- *(C + i + joff) = *(temp_str + i + j * BLOCK_SIZE_M);
5875
+ int64_t joff = jid[0 ]*ne0 + jid[1 ]*ne0ne1;
5876
+
5877
+ device float * D = (device float *) dst + (r0*BLOCK_SIZE_M) + joff;
5878
+ device float4 * D4 = (device float4 *) D;
5879
+
5880
+ threadgroup float * C = temp_str + (j*BLOCK_SIZE_M);
5881
+ threadgroup float4 * C4 = (threadgroup float4 *) C;
5882
+
5883
+ int i = 0 ;
5884
+ for (; i < n_rows/4 ; i++) {
5885
+ *(D4 + i) = *(C4 + i);
5886
+ }
5887
+
5888
+ i *= 4 ;
5889
+ for (; i < n_rows; i++) {
5890
+ *(D + i) = *(C + i);
5874
5891
}
5875
5892
}
5876
5893
}
@@ -5879,48 +5896,34 @@ void kernel_mul_mm_id_impl(
5879
5896
5880
5897
template <typename block_q, short nl, void (*dequantize_func)(device const block_q *, short , thread half4x4 &)>
5881
5898
kernel void kernel_mul_mm_id (
5882
- device const uchar * src0s,
5883
- device const uchar * src1,
5884
- device float * dst,
5885
- device const uchar * ids,
5886
- constant int64_t & nei0,
5887
- constant int64_t & nei1,
5888
- constant uint64_t & nbi1,
5889
- constant int64_t & ne00,
5890
- constant int64_t & ne02,
5891
- constant uint64_t & nb01,
5892
- constant uint64_t & nb02,
5893
- constant int64_t & ne11,
5894
- constant int64_t & ne12,
5895
- constant int64_t & ne13,
5896
- constant uint64_t & nb10,
5897
- constant uint64_t & nb11,
5898
- constant uint64_t & nb12,
5899
- constant int64_t & ne0,
5900
- constant int64_t & ne1,
5901
- constant uint64_t & nb1,
5902
- threadgroup uchar * shared_memory [[threadgroup(0 )]],
5903
- uint3 tgpig[[threadgroup_position_in_grid]],
5904
- uint tiitg[[thread_index_in_threadgroup]],
5905
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
5899
+ constant ggml_metal_kargs_mul_mm_id & args,
5900
+ device const char * src0s,
5901
+ device const char * src1,
5902
+ device char * dst,
5903
+ device const char * ids,
5904
+ threadgroup char * shmem [[threadgroup(0 )]],
5905
+ uint3 tgpig[[threadgroup_position_in_grid]],
5906
+ ushort tiitg[[thread_index_in_threadgroup]],
5907
+ ushort sgitg[[simdgroup_index_in_threadgroup]]) {
5906
5908
5907
5909
const int32_t i02 = tgpig.z ;
5910
+
5908
5911
tgpig.z = 0 ;
5909
5912
5910
- device const uchar * src0 = src0s + i02*nb02;
5913
+ device const char * src0 = src0s + i02*args. nb02 ;
5911
5914
5912
5915
// row indices
5913
- threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shared_memory + 8192 );
5916
+ threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shmem + 8192 );
5914
5917
5915
5918
// TODO: parallelize this loop
5916
5919
int64_t _ne1 = 0 ;
5917
- for (ushort ii1 = 0 ; ii1 < nei1; ii1++) {
5918
- for (ushort ii0 = 0 ; ii0 < nei0; ii0++) {
5919
- int32_t id = ((device int32_t *) (ids + ii1*nbi1))[ii0];
5920
+ for (ushort ii1 = 0 ; ii1 < args. nei1 ; ii1++) {
5921
+ for (ushort ii0 = 0 ; ii0 < args. nei0 ; ii0++) {
5922
+ int32_t id = ((device int32_t *) (ids + ii1*args. nbi1 ))[ii0];
5920
5923
if (id == i02) {
5921
- // if (tiitg == 0) {
5924
+ if (tiitg == 0 ) {
5922
5925
rowids[_ne1] = ushort2 (ii0, ii1);
5923
- // }
5926
+ }
5924
5927
_ne1++;
5925
5928
}
5926
5929
}
@@ -5929,23 +5932,23 @@ kernel void kernel_mul_mm_id(
5929
5932
threadgroup_barrier (mem_flags::mem_threadgroup);
5930
5933
5931
5934
kernel_mul_mm_id_impl<block_q, nl, dequantize_func>(
5935
+ args.ne00 ,
5936
+ args.ne02 ,
5937
+ args.nb01 ,
5938
+ args.nb02 ,
5939
+ args.ne11 ,
5940
+ args.ne12 ,
5941
+ args.nb10 ,
5942
+ args.nb11 ,
5943
+ args.nb12 ,
5944
+ args.ne0 ,
5945
+ _ne1,
5946
+ (int64_t )args.ne0 *args.ne1 ,
5932
5947
src0,
5933
5948
src1,
5934
5949
rowids,
5935
5950
dst,
5936
- ne00,
5937
- ne02,
5938
- nb01,
5939
- nb02,
5940
- ne11,
5941
- ne12,
5942
- nb10,
5943
- nb11,
5944
- nb12,
5945
- ne0,
5946
- _ne1,
5947
- ne0*ne1,
5948
- shared_memory,
5951
+ shmem,
5949
5952
tgpig,
5950
5953
tiitg,
5951
5954
sgitg);
0 commit comments