Skip to content

Commit f12f803

Browse files
committed
metal: For TQ2_0, Apply changes from: 'metal : refactor kernel args into structs (ggml-org#10238)'
1 parent 613a7c5 commit f12f803

File tree

1 file changed

+26
-46
lines changed

1 file changed

+26
-46
lines changed

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 26 additions & 46 deletions
Original file line number | Diff line number | Diff line change
@@ -5058,39 +5058,32 @@ kernel void kernel_mul_mv_q6_K_f32(
50585058

50595059
// ======================= Ternary
50605060

5061+
template<typename args_t>
50615062
void kernel_mul_mv_tq2_0_f32_impl(
5062-
device const void * src0,
5063-
device const float * src1,
5064-
device float * dst,
5065-
int64_t ne00,
5066-
int64_t ne01,
5067-
int64_t ne02,
5068-
int64_t ne10,
5069-
int64_t ne12,
5070-
int64_t ne0,
5071-
int64_t ne1,
5072-
uint r2,
5073-
uint r3,
5074-
threadgroup int8_t * shared_values,
5075-
uint3 tgpig,
5076-
uint tiisg,
5077-
uint sgitg) {
5078-
5079-
const int nb = ne00/QK_K;
5063+
args_t args,
5064+
device const char * src0,
5065+
device const char * src1,
5066+
device char * dst,
5067+
threadgroup char * shmem,
5068+
uint3 tgpig,
5069+
ushort tiisg,
5070+
ushort sgitg) {
5071+
5072+
const int nb = args.ne00/QK_K;
50805073
const int r0 = tgpig.x;
50815074
const int r1 = tgpig.y;
50825075
const int im = tgpig.z;
50835076

50845077
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
50855078
const int ib_row = first_row * nb;
50865079

5087-
const uint i12 = im%ne12;
5088-
const uint i13 = im/ne12;
5080+
const uint i12 = im%args.ne12;
5081+
const uint i13 = im/args.ne12;
50895082

5090-
const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
5083+
const uint offset0 = (i12/args.r2)*(nb*args.ne01) + (i13/args.r3)*(nb*args.ne01*args.ne02);
50915084

50925085
device const block_tq2_0 * x = (device const block_tq2_0 *) src0 + ib_row + offset0;
5093-
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
5086+
device const float * y = (device const float *) src1 + r1*args.ne10 + im*args.ne00*args.ne1;
50945087

50955088
float yl[32];
50965089
float sumf[N_DST]={0.f}, all_sum;
@@ -5144,40 +5137,27 @@ void kernel_mul_mv_tq2_0_f32_impl(
51445137
y4 += 4 * QK_K;
51455138
}
51465139

5140+
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
5141+
51475142
for (int row = 0; row < N_DST; ++row) {
51485143
all_sum = simd_sum(sumf[row]);
51495144
if (tiisg == 0) {
5150-
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
5145+
dst_f32[first_row + row] = all_sum;
51515146
}
51525147
}
51535148
}
51545149

51555150
[[host_name("kernel_mul_mv_tq2_0_f32")]]
51565151
kernel void kernel_mul_mv_tq2_0_f32(
5157-
device const void * src0,
5158-
device const float * src1,
5159-
device float * dst,
5160-
constant int64_t & ne00,
5161-
constant int64_t & ne01,
5162-
constant int64_t & ne02,
5163-
constant uint64_t & nb00,
5164-
constant uint64_t & nb01,
5165-
constant uint64_t & nb02,
5166-
constant int64_t & ne10,
5167-
constant int64_t & ne11,
5168-
constant int64_t & ne12,
5169-
constant uint64_t & nb10,
5170-
constant uint64_t & nb11,
5171-
constant uint64_t & nb12,
5172-
constant int64_t & ne0,
5173-
constant int64_t & ne1,
5174-
constant uint & r2,
5175-
constant uint & r3,
5176-
uint3 tgpig[[threadgroup_position_in_grid]],
5177-
uint tiisg[[thread_index_in_simdgroup]],
5178-
uint sgitg[[simdgroup_index_in_threadgroup]]) {
5152+
constant ggml_metal_kargs_mul_mv & args,
5153+
device const char * src0,
5154+
device const char * src1,
5155+
device char * dst,
5156+
uint3 tgpig[[threadgroup_position_in_grid]],
5157+
ushort tiisg[[thread_index_in_simdgroup]],
5158+
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
51795159

5180-
kernel_mul_mv_tq2_0_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
5160+
kernel_mul_mv_tq2_0_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
51815161
}
51825162

51835163
// ======================= "True" 2-bit

0 commit comments

Comments
 (0)