@@ -1209,108 +1209,133 @@ kernel void kernel_mul_mat_q2_K_f32(
1209
1209
constant int64_t & ne00,
1210
1210
constant int64_t & ne10,
1211
1211
constant int64_t & ne0,
1212
- threadgroup float * sum [[threadgroup( 0 )]],
1212
+ constant int64_t & ne01[[buffer( 4 )]],
1213
1213
uint2 tgpig[[threadgroup_position_in_grid]],
1214
- uint2 tpitg[[thread_position_in_threadgroup ]],
1215
- uint2 tptg[[threads_per_threadgroup ]]) {
1214
+ uint tiisg[[thread_index_in_simdgroup ]],
1215
+ uint sgitg[[simdgroup_index_in_threadgroup ]]) {
1216
1216
1217
1217
const int nb = ne00/QK_K;
1218
+ const int r0 = tgpig.x ;
1219
+ const int r1 = tgpig.y ;
1218
1220
1219
- const int64_t r0 = tgpig.x ;
1220
- const int64_t r1 = tgpig.y ;
1221
-
1222
- device const block_q2_K * x = (device const block_q2_K *) src0 + r0*nb;
1223
- device const float * yy = (device const float *) src1 + r1*ne10;
1224
-
1225
- const int nth = tptg.x *tptg.y ;
1226
- const int ith = tptg.y *tpitg.x + tpitg.y ;
1221
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
1222
+ const int ib_row = first_row * nb;
1223
+ device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row;
1224
+ device const float * y = (device const float *) src1 + r1*ne10;
1225
+ float yl[32 ];
1226
+ float sumf[N_DST]={0 .f }, all_sum;
1227
1227
1228
- float sumf = 0 ;
1228
+ const int step = sizeof (block_q2_K) * nb ;
1229
1229
1230
1230
#if QK_K == 256
1231
- const int tid = tpitg.y ; // 0...16
1232
- const int il = tid/4 ; // 0...3
1233
- const int ir = tid%4 ; // 0...3
1234
- const int ip = il/2 ; // 0 or 1
1235
- const int shift1 = 4 *(il%2 );// 0 or 4
1236
- const int shift2 = shift1+2 ;// 2 or 6
1237
- const int n = 8 ;
1238
- const int is = 4 *il + (n*ir)/16 ;
1231
+ const int ix = tiisg/8 ; // 0...3
1232
+ const int it = tiisg%8 ; // 0...7
1233
+ const int im = it/4 ; // 0 or 1
1234
+ const int ir = it%4 ; // 0...3
1235
+ const int is = (8 *ir)/16 ;// 0 or 1
1239
1236
1240
- const int y_offset = 64 *il + n*ir;
1241
- const int q_offset = 32 *ip + n*ir;
1237
+ device const float * y4 = y + ix * QK_K + 128 * im + 8 * ir;
1242
1238
1243
- for (int i = tpitg. x ; i < nb; i += tptg. x ) {
1239
+ for (int ib = ix; ib < nb; ib += 4 ) {
1244
1240
1245
- device const uint8_t * q = x[i].qs + q_offset;
1246
- device const uint8_t * scales = x[i].scales + is;
1241
+ float4 sumy = {0 .f , 0 .f , 0 .f , 0 .f };
1242
+ for (int i = 0 ; i < 8 ; ++i) {
1243
+ yl[i+ 0 ] = y4[i+ 0 ]; sumy[0 ] += yl[i+ 0 ];
1244
+ yl[i+ 8 ] = y4[i+32 ]; sumy[1 ] += yl[i+ 8 ];
1245
+ yl[i+16 ] = y4[i+64 ]; sumy[2 ] += yl[i+16 ];
1246
+ yl[i+24 ] = y4[i+96 ]; sumy[3 ] += yl[i+24 ];
1247
+ }
1247
1248
1248
- uint8_t d1 = scales[0 ] & 0xF ;
1249
- uint8_t d2 = scales[2 ] & 0xF ;
1250
- uint8_t m1 = scales[0 ] >> 4 ;
1251
- uint8_t m2 = scales[2 ] >> 4 ;
1249
+ device const uint8_t * sc = (device const uint8_t *)x[ib].scales + 8 *im + is;
1250
+ device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * im + 4 * ir;
1251
+ device const half * dh = &x[ib].d ;
1252
1252
1253
- device const float * y = yy + i*QK_K + y_offset;
1253
+ for ( int row = 0 ; row < N_DST; row++) {
1254
1254
1255
- float2 s = {0 .f , 0 .f };
1256
- float smin = 0 ;
1257
- for (int l = 0 ; l < n; ++l) {
1258
- s[0 ] += y[l+ 0 ] * ((q[l] >> shift1) & 3 );
1259
- s[1 ] += y[l+32 ] * ((q[l] >> shift2) & 3 );
1260
- smin += y[l+ 0 ] * m1 + y[l+32 ] * m2;
1255
+ float4 acc1 = {0 .f , 0 .f , 0 .f , 0 .f };
1256
+ float4 acc2 = {0 .f , 0 .f , 0 .f , 0 .f };
1257
+ for (int i = 0 ; i < 8 ; i += 2 ) {
1258
+ acc1[0 ] += yl[i+ 0 ] * (qs[i/2 ] & 0x0003 );
1259
+ acc2[0 ] += yl[i+ 1 ] * (qs[i/2 ] & 0x0300 );
1260
+ acc1[1 ] += yl[i+ 8 ] * (qs[i/2 ] & 0x000c );
1261
+ acc2[1 ] += yl[i+ 9 ] * (qs[i/2 ] & 0x0c00 );
1262
+ acc1[2 ] += yl[i+16 ] * (qs[i/2 ] & 0x0030 );
1263
+ acc2[2 ] += yl[i+17 ] * (qs[i/2 ] & 0x3000 );
1264
+ acc1[3 ] += yl[i+24 ] * (qs[i/2 ] & 0x00c0 );
1265
+ acc2[3 ] += yl[i+25 ] * (qs[i/2 ] & 0xc000 );
1266
+ }
1267
+ float dall = dh[0 ];
1268
+ float dmin = dh[1 ] * 1 .f /16 .f ;
1269
+ sumf[row] += dall * ((acc1[0 ] + 1 .f /256 .f * acc2[0 ]) * (sc[0 ] & 0xF ) * 1 .f / 1 .f +
1270
+ (acc1[1 ] + 1 .f /256 .f * acc2[1 ]) * (sc[2 ] & 0xF ) * 1 .f / 4 .f +
1271
+ (acc1[2 ] + 1 .f /256 .f * acc2[2 ]) * (sc[4 ] & 0xF ) * 1 .f /16 .f +
1272
+ (acc1[3 ] + 1 .f /256 .f * acc2[3 ]) * (sc[6 ] & 0xF ) * 1 .f /64 .f ) -
1273
+ dmin * (sumy[0 ] * (sc[0 ] & 0xF0 ) + sumy[1 ] * (sc[2 ] & 0xF0 ) + sumy[2 ] * (sc[4 ] & 0xF0 ) + sumy[3 ] * (sc[6 ] & 0xF0 ));
1274
+
1275
+ qs += step/2 ;
1276
+ sc += step;
1277
+ dh += step/2 ;
1261
1278
}
1262
1279
1263
- const float dall = (float )x[i].d ;
1264
- const float dmin = (float )x[i].dmin ;
1265
-
1266
- sumf += dall * (s[0 ] * d1 + s[1 ] * d2) - dmin * smin;
1267
-
1280
+ y4 += 4 * QK_K;
1268
1281
}
1269
1282
#else
1270
- const int il = 4 * tpitg.x ;
1283
+ const int ix = tiisg/2 ; // 0...15
1284
+ const int it = tiisg%2 ; // 0...1
1271
1285
1272
- uint32_t aux[2 ];
1273
- thread const uint8_t * d = (thread const uint8_t *)aux;
1274
- thread const uint8_t * m = (thread const uint8_t *)aux + 4 ;
1286
+ device const float * y4 = y + ix * QK_K + 8 * it;
1275
1287
1276
- for (int i = tpitg. y ; i < nb; i += tptg. y ) {
1288
+ for (int ib = ix; ib < nb; ib += 16 ) {
1277
1289
1278
- device const uint8_t * q = x[i].qs + il;
1279
- device const float * y = yy + i*QK_K + il;
1290
+ float4 sumy = {0 .f , 0 .f , 0 .f , 0 .f };
1291
+ for (int i = 0 ; i < 8 ; ++i) {
1292
+ yl[i+ 0 ] = y4[i+ 0 ]; sumy[0 ] += yl[i+ 0 ];
1293
+ yl[i+ 8 ] = y4[i+16 ]; sumy[1 ] += yl[i+ 8 ];
1294
+ yl[i+16 ] = y4[i+32 ]; sumy[2 ] += yl[i+16 ];
1295
+ yl[i+24 ] = y4[i+48 ]; sumy[3 ] += yl[i+24 ];
1296
+ }
1280
1297
1281
- const float dall = (float )x[i].d ;
1282
- const float dmin = (float )x[i].dmin ;
1298
+ device const uint8_t * sc = (device const uint8_t *)x[ib].scales ;
1299
+ device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 4 * it;
1300
+ device const half * dh = &x[ib].d ;
1283
1301
1284
- device const uint32_t * a = (device const uint32_t *)x[i].scales ;
1285
- aux[0 ] = a[0 ] & 0x0f0f0f0f ;
1286
- aux[1 ] = (a[0 ] >> 4 ) & 0x0f0f0f0f ;
1302
+ for (int row = 0 ; row < N_DST; row++) {
1287
1303
1288
- for (int l = 0 ; l < 4 ; ++l) {
1289
- sumf += y[l+ 0 ] * (dall * d[0 ] * ((q[l] >> 0 ) & 3 ) - dmin * m[0 ])
1290
- + y[l+16 ] * (dall * d[1 ] * ((q[l] >> 2 ) & 3 ) - dmin * m[1 ])
1291
- + y[l+32 ] * (dall * d[2 ] * ((q[l] >> 4 ) & 3 ) - dmin * m[2 ])
1292
- + y[l+48 ] * (dall * d[3 ] * ((q[l] >> 6 ) & 3 ) - dmin * m[3 ]);
1304
+ float4 acc1 = {0 .f , 0 .f , 0 .f , 0 .f };
1305
+ float4 acc2 = {0 .f , 0 .f , 0 .f , 0 .f };
1306
+ for (int i = 0 ; i < 8 ; i += 2 ) {
1307
+ acc1[0 ] += yl[i+ 0 ] * (qs[i/2 ] & 0x0003 );
1308
+ acc2[0 ] += yl[i+ 1 ] * (qs[i/2 ] & 0x0300 );
1309
+ acc1[1 ] += yl[i+ 8 ] * (qs[i/2 ] & 0x000c );
1310
+ acc2[1 ] += yl[i+ 9 ] * (qs[i/2 ] & 0x0c00 );
1311
+ acc1[2 ] += yl[i+16 ] * (qs[i/2 ] & 0x0030 );
1312
+ acc2[2 ] += yl[i+17 ] * (qs[i/2 ] & 0x3000 );
1313
+ acc1[3 ] += yl[i+24 ] * (qs[i/2 ] & 0x00c0 );
1314
+ acc2[3 ] += yl[i+25 ] * (qs[i/2 ] & 0xc000 );
1315
+ }
1316
+
1317
+ float dall = dh[0 ];
1318
+ float dmin = dh[1 ];
1319
+ sumf[row] += dall * ((acc1[0 ] + 1 .f /256 .f * acc2[0 ]) * (sc[0 ] & 0xF ) * 1 .f / 1 .f +
1320
+ (acc1[1 ] + 1 .f /256 .f * acc2[1 ]) * (sc[1 ] & 0xF ) * 1 .f / 4 .f +
1321
+ (acc1[2 ] + 1 .f /256 .f * acc2[2 ]) * (sc[2 ] & 0xF ) * 1 .f /16 .f +
1322
+ (acc1[3 ] + 1 .f /256 .f * acc2[3 ]) * (sc[3 ] & 0xF ) * 1 .f /64 .f ) -
1323
+ dmin * (sumy[0 ] * (sc[0 ] >> 4 ) + sumy[1 ] * (sc[1 ] >> 4 ) + sumy[2 ] * (sc[2 ] >> 4 ) + sumy[3 ] * (sc[3 ] >> 4 ));
1324
+
1325
+ qs += step/2 ;
1326
+ sc += step;
1327
+ dh += step/2 ;
1293
1328
}
1329
+
1330
+ y4 += 16 * QK_K;
1294
1331
}
1295
1332
#endif
1296
1333
1297
- sum[ith] = sumf;
1298
-
1299
- //
1300
- // Accumulate the sum from all threads in the threadgroup
1301
- //
1302
- threadgroup_barrier (mem_flags::mem_threadgroup);
1303
- if (ith%4 == 0 ) {
1304
- for (int i = 1 ; i < 4 ; ++i) sum[ith] += sum[ith + i];
1305
- }
1306
- threadgroup_barrier (mem_flags::mem_threadgroup);
1307
- if (ith%16 == 0 ) {
1308
- for (int i = 4 ; i < 16 ; i += 4 ) sum[ith] += sum[ith + i];
1309
- }
1310
- threadgroup_barrier (mem_flags::mem_threadgroup);
1311
- if (ith == 0 ) {
1312
- for (int i = 16 ; i < nth; i += 16 ) sum[0 ] += sum[i];
1313
- dst[r1*ne0 + r0] = sum[0 ];
1334
+ for (int row = 0 ; row < N_DST; ++row) {
1335
+ all_sum = simd_sum (sumf[row]);
1336
+ if (tiisg == 0 ) {
1337
+ dst[r1*ne0 + first_row + row] = all_sum;
1338
+ }
1314
1339
}
1315
1340
}
1316
1341
0 commit comments