@@ -1586,6 +1586,72 @@ kernel void kernel_cpy_f32_q4_0(
1586
1586
}
1587
1587
}
1588
1588
1589
// Copy an f32 tensor into a q4_1-quantized destination.
//
// Launch layout (as read by the code below):
//   - tgpig selects the source row: i01 = tgpig[0], i02 = tgpig[1], i03 = tgpig[2].
//   - Threads of one threadgroup split that row along x: thread tpitg.x handles the
//     QK4_1-element groups starting at tpitg.x*QK4_1, stepping by ntg.x*QK4_1.
//
// Parameters:
//   src0          source floats, addressed with byte strides nb00..nb03
//   dst           destination buffer, addressed with byte strides nb0..nb3
//   ne00..ne03    source extents (ne03 is not read by this kernel)
//   ne0..ne3      destination extents (ne3 is not read by this kernel)
//
// Per group of QK4_1 floats the kernel stores:
//   d  = (max - min) / 15          scale
//   m  = min                       offset
//   qs = 4-bit codes MIN(15, (int8_t)((v - m)/d + 0.5f)); byte j holds element j in
//        its low nibble and element j + QK4_1/2 in its high nibble.
//
// NOTE(review): every group reads a full QK4_1 floats, so this appears to assume
// ne00 is a multiple of QK4_1 - confirm against the host-side dispatch.
// NOTE(review): dst_data is offset by i0*nb0 and then indexed in block_q4_1 units;
// presumably nb0 is the byte size of one destination block - confirm.
kernel void kernel_cpy_f32_q4_1(
        device const float * src0,
        device        void * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant   int64_t & ne02,
        constant   int64_t & ne03,
        constant  uint64_t & nb00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb02,
        constant  uint64_t & nb03,
        constant   int64_t & ne0,
        constant   int64_t & ne1,
        constant   int64_t & ne2,
        constant   int64_t & ne3,
        constant  uint64_t & nb0,
        constant  uint64_t & nb1,
        constant  uint64_t & nb2,
        constant  uint64_t & nb3,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3   ntg[[threads_per_threadgroup]]) {
    // source row handled by this threadgroup
    const int64_t i03 = tgpig[2];
    const int64_t i02 = tgpig[1];
    const int64_t i01 = tgpig[0];

    // flat element index of the first element of the source row
    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;

    // re-express that flat index in destination coordinates (i0 is in units of blocks)
    const int64_t i3 = n / (ne2*ne1*ne0);
    const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
    const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_1;

    device block_q4_1 * dst_data = (device block_q4_1 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);

    for (int64_t i00 = tpitg.x*QK4_1; i00 < ne00; i00 += ntg.x*QK4_1) {
        // keep the const qualifier of src0 instead of casting it away
        device const float * src = (device const float *)((device const char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);

        // destination block for this group, computed once
        device block_q4_1 * blk = dst_data + i00/QK4_1;

        // value range of the group
        float vmin =  FLT_MAX;
        float vmax = -FLT_MAX;

        for (int j = 0; j < QK4_1; j++) {
            const float v = src[j];
            if (vmin > v) vmin = v;
            if (vmax < v) vmax = v;
        }

        // 15 == (1 << 4) - 1: the largest 4-bit code
        const float d  = (vmax - vmin) / ((1 << 4) - 1);
        // a constant group has d == 0; use id = 0 so every code becomes 0
        const float id = d ? 1.0f/d : 0.0f;

        blk->d = d;
        blk->m = vmin;

        for (int j = 0; j < QK4_1/2; ++j) {
            const float x0 = (src[0       + j] - vmin)*id;
            const float x1 = (src[QK4_1/2 + j] - vmin)*id;

            // round to nearest and clamp to the 4-bit maximum
            const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f));
            const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f));

            // pack both nibbles in a register and store the byte once
            blk->qs[j] = xi0 | (xi1 << 4);
        }
    }
}
1654
+
1589
1655
kernel void kernel_concat (
1590
1656
device const char * src0,
1591
1657
device const char * src1,
0 commit comments