@@ -1452,138 +1452,188 @@ kernel void kernel_mul_mat_q3_K_f32(
 }
 
+#if QK_K == 256
 kernel void kernel_mul_mat_q4_K_f32(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
         constant   int64_t & ne00,
         constant   int64_t & ne10,
         constant   int64_t & ne0,
-        threadgroup float  * sum [[threadgroup(0)]],
+        constant   int64_t & ne01[[buffer(4)]],
         uint2 tgpig[[threadgroup_position_in_grid]],
-        uint2 tpitg[[thread_position_in_threadgroup]],
-        uint2  tptg[[threads_per_threadgroup]]) {
-
-    const int nb = ne00/QK_K;
-
-    const int64_t r0 = tgpig.x;
-    const int64_t r1 = tgpig.y;
-
-    const int nth = tptg.x*tptg.y;
-    const int ith = tptg.y*tpitg.x + tpitg.y;
-
-    device const block_q4_K * x = (device const block_q4_K *) src0 + r0*nb;
-    device const float     * yy = (device const float      *) src1 + r1*ne10;
-
-    float sumf = 0;
-
-#if QK_K == 256
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]]) {
 
     const uint16_t kmask1 = 0x3f3f;
     const uint16_t kmask2 = 0x0f0f;
     const uint16_t kmask3 = 0xc0c0;
 
-    const int tid = tpitg.y;   // 0...16
-    const int il  = tid/4;     // 0...3
-    const int ir  = tid - 4*il; // 0...3
-    const int n   = 4;
+    const int ix = tiisg/8;  // 0...3
+    const int it = tiisg%8;  // 0...7
+    const int im = it/4;     // 0 or 1
+    const int ir = it%4;     // 0...3
 
-    const int im = il/2;  // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
-    const int in = il%2;
+    const int nb = ne00/QK_K;
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
+    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+    const int ib_row = first_row * nb;
+    device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row;
+    device const float      * y = (device const float      *) src1 + r1*ne10;
+    float yl[16];
+    float yh[16];
+    float sumf[N_DST]={0.f}, all_sum;
 
-    const int l0 = n*(2*ir + in);
-    const int q_offset = 32*im + l0;
-    const int y_offset = 64*im + l0;
+    const int step = sizeof(block_q4_K) * nb / 2;
 
-    uchar2 sc1, sc2, sc3, sc4;
+    device const float * y4 = y + ix * QK_K + 64 * im + 8 * ir;
 
-    for (int i = tpitg.x; i < nb; i += tptg.x) {
+    uint16_t sc16[4];
+    thread const uint8_t * sc8 = (thread const uint8_t *)sc16;
 
-        device const uint8_t * q1 = (x + i)->qs + q_offset;
-        device const uint8_t * q2 = q1 + 64;
-        device const float   * y1 = yy + i*QK_K + y_offset;
-        device const float   * y2 = y1 + 128;
+    for (int ib = ix; ib < nb; ib += 4) {
 
-        const float dall = (float)((x + i)->d);
-        const float dmin = (float)((x + i)->dmin);
+        float4 sumy = {0.f, 0.f, 0.f, 0.f};
+        for (int i = 0; i < 8; ++i) {
+            yl[i+0] = y4[i+  0]; sumy[0] += yl[i+0];
+            yl[i+8] = y4[i+ 32]; sumy[1] += yl[i+8];
+            yh[i+0] = y4[i+128]; sumy[2] += yh[i+0];
+            yh[i+8] = y4[i+160]; sumy[3] += yh[i+8];
+        }
 
-        device const uint16_t * a = (device const uint16_t *)(x + i)->scales;
-        sc1 = as_type<uchar2>((uint16_t)(a[im+0] & kmask1));
-        sc2 = as_type<uchar2>((uint16_t)(a[im+2] & kmask1));
-        sc3 = as_type<uchar2>((uint16_t)(((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2)));
-        sc4 = as_type<uchar2>((uint16_t)(((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2)));
+        device const uint16_t * sc = (device const uint16_t *)x[ib].scales + im;
+        device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 * im + 4 * ir;
+        device const half     * dh = &x[ib].d;
 
-        float4 s = {0.f, 0.f, 0.f, 0.f};
-        float smin = 0;
-        for (int l = 0; l < n; ++l) {
+        for (int row = 0; row < N_DST; row++) {
 
-            s[0] += y1[l] * (q1[l] & 0xF); s[1] += y1[l+32] * (q1[l] >> 4);
-            s[2] += y2[l] * (q2[l] & 0xF); s[3] += y2[l+32] * (q2[l] >> 4);
-            smin += y1[l] * sc2[0] + y1[l+32] * sc2[1] + y2[l] * sc4[0] + y2[l+32] * sc4[1];
+            sc16[0] = sc[0] & kmask1;
+            sc16[1] = sc[2] & kmask1;
+            sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2);
+            sc16[3] = ((sc[4] >> 4) & kmask2) | ((sc[2] & kmask3) >> 2);
+
+            device const uint16_t * q2 = q1 + 32;
+
+            float4 acc1 = {0.f, 0.f, 0.f, 0.f};
+            float4 acc2 = {0.f, 0.f, 0.f, 0.f};
+            for (int i = 0; i < 8; i += 2) {
+                acc1[0] += yl[i+0] * (q1[i/2] & 0x000F);
+                acc1[1] += yl[i+1] * (q1[i/2] & 0x0F00);
+                acc1[2] += yl[i+8] * (q1[i/2] & 0x00F0);
+                acc1[3] += yl[i+9] * (q1[i/2] & 0xF000);
+                acc2[0] += yh[i+0] * (q2[i/2] & 0x000F);
+                acc2[1] += yh[i+1] * (q2[i/2] & 0x0F00);
+                acc2[2] += yh[i+8] * (q2[i/2] & 0x00F0);
+                acc2[3] += yh[i+9] * (q2[i/2] & 0xF000);
+            }
 
+            float dall = dh[0];
+            float dmin = dh[1];
+            sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8[0] +
+                                 (acc1[2] + 1.f/256.f * acc1[3]) * sc8[1] * 1.f/16.f +
+                                 (acc2[0] + 1.f/256.f * acc2[1]) * sc8[4] +
+                                 (acc2[2] + 1.f/256.f * acc2[3]) * sc8[5] * 1.f/16.f) -
+                         dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
+
+            q1 += step;
+            sc += step;
+            dh += step;
         }
-        sumf += dall * (s[0] * sc1[0] + s[1] * sc1[1] + s[2] * sc3[0] + s[3] * sc3[1]) - dmin * smin;
 
+        y4 += 4 * QK_K;
     }
+
+    for (int row = 0; row < N_DST; ++row) {
+        all_sum = simd_sum(sumf[row]);
+        if (tiisg == 0) {
+            dst[r1*ne0 + first_row + row] = all_sum;
+        }
+    }
+}
 #else
-    uint16_t aux16[2];
-    thread const uint8_t * scales = (thread const uint8_t *)aux16;
+kernel void kernel_mul_mat_q4_K_f32(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne10,
+        constant   int64_t & ne0,
+        constant   int64_t & ne01[[buffer(4)]],
+        uint2 tgpig[[threadgroup_position_in_grid]],
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    const int il  = 4*tpitg.x;
+    const int ix = tiisg/4;  // 0...7
+    const int it = tiisg%4;  // 0...3
 
-    for (int i = tpitg.y; i < nb; i += tptg.y) {
+    const int nb = ne00/QK_K;
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
+    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+    const int ib_row = first_row * nb;
+    device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row;
+    device const float      * y = (device const float      *) src1 + r1*ne10;
+    float yl[8];
+    float yh[8];
+    float sumf[N_DST]={0.f}, all_sum;
 
-        device const uint8_t * q = x[i].qs + il;
-        device const float   * y = yy + i * QK_K + il;
+    const int step = sizeof(block_q4_K) * nb / 2;
 
-        const float d = (float)x[i].d[0];
-        const float m = (float)x[i].d[1];
+    device const float * y4 = y + ix * QK_K + 8 * it;
 
-        device const uint16_t * a = (device const uint16_t *)x[i].scales;
-        aux16[0] = a[0] & 0x0f0f;
-        aux16[1] = (a[0] >> 4) & 0x0f0f;
+    uint16_t sc16[4];
 
-        for (int l = 0; l < 4; ++l) {
-            sumf += d * scales[0] * (y[l+ 0] * (q[l] & 0xF) + y[l+16] * (q[l+16] & 0xF)) - m * scales[2] * (y[l+ 0] + y[l+16])
-                  + d * scales[1] * (y[l+32] * (q[l] >>  4) + y[l+48] * (q[l+16] >>  4)) - m * scales[3] * (y[l+32] + y[l+48]);
+    for (int ib = ix; ib < nb; ib += 8) {
+
+        float2 sumy = {0.f, 0.f};
+        for (int i = 0; i < 8; ++i) {
+            yl[i] = y4[i+ 0]; sumy[0] += yl[i];
+            yh[i] = y4[i+32]; sumy[1] += yh[i];
         }
-    }
-#endif
 
-    sum[ith] = sumf;
+        device const uint16_t * sc = (device const uint16_t *)x[ib].scales;
+        device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 4 * it;
+        device const half     * dh = x[ib].d;
 
-    //
-    // Accumulate the sum from all threads in the threadgroup
-    // This version is slightly faster than the commented out one below,
-    // which I copy-pasted from ggerganov's q4_0 dot product for metal.
-    //
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    if (ith%4 == 0) {
-        for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
-    }
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    if (ith%16 == 0) {
-        for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
-    }
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    if (ith == 0) {
-        for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
-        dst[r1*ne0 + r0] = sum[0];
-    }
+        for (int row = 0; row < N_DST; row++) {
+
+            sc16[0] = sc[0] & 0x000f;
+            sc16[1] = sc[0] & 0x0f00;
+            sc16[2] = sc[0] & 0x00f0;
+            sc16[3] = sc[0] & 0xf000;
+
+            float2 acc1 = {0.f, 0.f};
+            float2 acc2 = {0.f, 0.f};
+            for (int i = 0; i < 8; i += 2) {
+                acc1[0] += yl[i+0] * (qs[i/2] & 0x000F);
+                acc1[1] += yl[i+1] * (qs[i/2] & 0x0F00);
+                acc2[0] += yh[i+0] * (qs[i/2] & 0x00F0);
+                acc2[1] += yh[i+1] * (qs[i/2] & 0xF000);
+            }
+
+            float dall = dh[0];
+            float dmin = dh[1];
+            sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc16[0] +
+                                 (acc2[0] + 1.f/256.f * acc2[1]) * sc16[1] * 1.f/4096.f) -
+                         dmin * 1.f/16.f * (sumy[0] * sc16[2] + sumy[1] * sc16[3] * 1.f/256.f);
+
+            qs += step;
+            sc += step;
+            dh += step;
+        }
 
-    // // accumulate the sum from all threads in the threadgroup
-    // threadgroup_barrier(mem_flags::mem_threadgroup);
-    // for (uint i = nth/2; i > 0; i /= 2) {
-    //     if (ith < i) {
-    //         sum[ith] += sum[ith + i];
-    //     }
-    //     threadgroup_barrier(mem_flags::mem_threadgroup);
-    // }
+        y4 += 8 * QK_K;
+    }
 
-    // if (ith == 0) {
-    //     dst[r1*ne0 + r0] = sum[0];
-    // }
+    for (int row = 0; row < N_DST; ++row) {
+        all_sum = simd_sum(sumf[row]);
+        if (tiisg == 0) {
+            dst[r1*ne0 + first_row + row] = all_sum;
+        }
+    }
 }
+#endif
 
 kernel void kernel_mul_mat_q5_K_f32(
         device const  void * src0,