@@ -5058,39 +5058,32 @@ kernel void kernel_mul_mv_q6_K_f32(
5058
5058
5059
5059
// ======================= Ternary
5060
5060
5061
+ template <typename args_t >
5061
5062
void kernel_mul_mv_tq2_0_f32_impl (
5062
- device const void * src0,
5063
- device const float * src1,
5064
- device float * dst,
5065
- int64_t ne00,
5066
- int64_t ne01,
5067
- int64_t ne02,
5068
- int64_t ne10,
5069
- int64_t ne12,
5070
- int64_t ne0,
5071
- int64_t ne1,
5072
- uint r2,
5073
- uint r3,
5074
- threadgroup int8_t * shared_values,
5075
- uint3 tgpig,
5076
- uint tiisg,
5077
- uint sgitg) {
5078
-
5079
- const int nb = ne00/QK_K;
5063
+ args_t args,
5064
+ device const char * src0,
5065
+ device const char * src1,
5066
+ device char * dst,
5067
+ threadgroup char * shmem,
5068
+ uint3 tgpig,
5069
+ ushort tiisg,
5070
+ ushort sgitg) {
5071
+
5072
+ const int nb = args.ne00 /QK_K;
5080
5073
const int r0 = tgpig.x ;
5081
5074
const int r1 = tgpig.y ;
5082
5075
const int im = tgpig.z ;
5083
5076
5084
5077
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
5085
5078
const int ib_row = first_row * nb;
5086
5079
5087
- const uint i12 = im%ne12;
5088
- const uint i13 = im/ne12;
5080
+ const uint i12 = im%args. ne12 ;
5081
+ const uint i13 = im/args. ne12 ;
5089
5082
5090
- const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
5083
+ const uint offset0 = (i12/args. r2 )*(nb*args. ne01 ) + (i13/args. r3 )*(nb*args. ne01 *args. ne02 );
5091
5084
5092
5085
device const block_tq2_0 * x = (device const block_tq2_0 *) src0 + ib_row + offset0;
5093
- device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
5086
+ device const float * y = (device const float *) src1 + r1*args. ne10 + im*args. ne00 *args. ne1 ;
5094
5087
5095
5088
float yl[32 ];
5096
5089
float sumf[N_DST]={0 .f }, all_sum;
@@ -5144,40 +5137,27 @@ void kernel_mul_mv_tq2_0_f32_impl(
5144
5137
y4 += 4 * QK_K;
5145
5138
}
5146
5139
5140
+ device float * dst_f32 = (device float *) dst + (uint64_t )im*args.ne0 *args.ne1 + (uint64_t )r1*args.ne0 ;
5141
+
5147
5142
for (int row = 0 ; row < N_DST; ++row) {
5148
5143
all_sum = simd_sum (sumf[row]);
5149
5144
if (tiisg == 0 ) {
5150
- dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
5145
+ dst_f32[ first_row + row] = all_sum;
5151
5146
}
5152
5147
}
5153
5148
}
5154
5149
5155
5150
[[host_name(" kernel_mul_mv_tq2_0_f32" )]]
5156
5151
kernel void kernel_mul_mv_tq2_0_f32 (
5157
- device const void * src0,
5158
- device const float * src1,
5159
- device float * dst,
5160
- constant int64_t & ne00,
5161
- constant int64_t & ne01,
5162
- constant int64_t & ne02,
5163
- constant uint64_t & nb00,
5164
- constant uint64_t & nb01,
5165
- constant uint64_t & nb02,
5166
- constant int64_t & ne10,
5167
- constant int64_t & ne11,
5168
- constant int64_t & ne12,
5169
- constant uint64_t & nb10,
5170
- constant uint64_t & nb11,
5171
- constant uint64_t & nb12,
5172
- constant int64_t & ne0,
5173
- constant int64_t & ne1,
5174
- constant uint & r2,
5175
- constant uint & r3,
5176
- uint3 tgpig[[threadgroup_position_in_grid]],
5177
- uint tiisg[[thread_index_in_simdgroup]],
5178
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
5152
+ constant ggml_metal_kargs_mul_mv & args,
5153
+ device const char * src0,
5154
+ device const char * src1,
5155
+ device char * dst,
5156
+ uint3 tgpig[[threadgroup_position_in_grid]],
5157
+ ushort tiisg[[thread_index_in_simdgroup]],
5158
+ ushort sgitg[[simdgroup_index_in_threadgroup]]) {
5179
5159
5180
- kernel_mul_mv_tq2_0_f32_impl (src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3 , nullptr , tgpig, tiisg, sgitg);
5160
+ kernel_mul_mv_tq2_0_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst , nullptr , tgpig, tiisg, sgitg);
5181
5161
}
5182
5162
5183
5163
// ======================= "True" 2-bit
0 commit comments