|
3 | 3 | using namespace metal;
|
4 | 4 |
|
// Function-like macro yielding the larger of x and y. The selected argument is
// expanded (and therefore evaluated) twice, so avoid operands with side effects.
5 | 5 | #define MAX(x, y) ((x) > (y) ? (x) : (y))
|
// Function-like macro yielding the smaller of x and y. As a textual macro the
// selected argument is evaluated twice, so avoid operands with side effects.
| 6 | +#define MIN(x, y) ((x) < (y) ? (x) : (y)) |
6 | 7 |
|
// kernel_cpy_f32_q4_0 consumes QK4_0 source floats for every block_q4_0 it writes.
7 | 8 | #define QK4_0 32
|
// NOTE(review): not referenced in the visible part of this file; presumably the
// number of quantized values packed per byte for Q4_0 -- confirm against its users.
8 | 9 | #define QR4_0 2
|
@@ -1518,6 +1519,73 @@ kernel void kernel_cpy_f32_q8_0(
|
1518 | 1519 | }
|
1519 | 1520 | }
|
1520 | 1521 |
|
// Copies float data from src0 into dst as block_q4_0 blocks.
//
// Work layout:
//   - the threadgroup position selects the source row: x -> i01, y -> i02, z -> i03
//   - the threads of a threadgroup walk that row in steps of QK4_0 elements,
//     thread t starting at element t*QK4_0 and advancing by ntg.x*QK4_0
//
// Addressing:
//   - src0 is addressed with the byte strides nb00..nb03
//   - the first element of the source row is flattened using ne00..ne02 and then
//     re-expressed as destination coordinates using ne0..ne2; the destination
//     byte offset is built from nb0..nb3
//   - ne03 and ne3 are accepted but not read by this kernel
//
// Per-block encoding:
//   - the element with the largest magnitude is located (its sign is kept)
//   - the block scale is d = that element / -8 and is stored in the block's d field
//   - every element is multiplied by 1/d (or by 0 when d == 0), offset by 8.5,
//     truncated through int8_t and limited to at most 15
//   - byte j of qs holds element j in the low nibble and element j + QK4_0/2
//     in the high nibble
//
// NOTE(review): each iteration reads a full QK4_0 floats, so this presumably relies
// on ne00 being a multiple of QK4_0 -- confirm against the host-side dispatch.
// NOTE(review): the block index within the row is scaled by nb0, so nb0 is presumably
// the byte size of one block_q4_0 -- confirm.
kernel void kernel_cpy_f32_q4_0(
        device const float * src0,
        device        void * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant   int64_t & ne02,
        constant   int64_t & ne03,
        constant  uint64_t & nb00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb02,
        constant  uint64_t & nb03,
        constant   int64_t & ne0,
        constant   int64_t & ne1,
        constant   int64_t & ne2,
        constant   int64_t & ne3,
        constant  uint64_t & nb0,
        constant  uint64_t & nb1,
        constant  uint64_t & nb2,
        constant  uint64_t & nb3,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3   ntg[[threads_per_threadgroup]]) {
    // source row coordinates, taken from the threadgroup position
    const int64_t src_i3 = tgpig.z;
    const int64_t src_i2 = tgpig.y;
    const int64_t src_i1 = tgpig.x;

    // flat index of the first element of the source row
    const int64_t flat = src_i3*ne02*ne01*ne00 + src_i2*ne01*ne00 + src_i1*ne00;

    // peel the flat index apart into destination coordinates
    const int64_t span3 = ne2*ne1*ne0;
    const int64_t span2 = ne1*ne0;

    const int64_t dst_i3 = flat / span3;
    const int64_t rest3  = flat - dst_i3*span3;
    const int64_t dst_i2 = rest3 / span2;
    const int64_t rest2  = rest3 - dst_i2*span2;
    const int64_t dst_i1 = rest2 / ne0;
    const int64_t dst_i0 = (rest2 - dst_i1*ne0)/QK4_0; // counted in blocks, not elements

    const uint64_t dst_offset = dst_i3*nb3 + dst_i2*nb2 + dst_i1*nb1 + dst_i0*nb0;

    device block_q4_0 * dst_blocks = (device block_q4_0 *) ((device char *) dst + dst_offset);

    // byte address of the source row; only the in-row offset varies inside the loop
    device const char * src_row = (device const char *) src0 + src_i3*nb03 + src_i2*nb02 + src_i1*nb01;

    const int64_t first = tpitg.x*QK4_0;
    const int64_t step  = ntg.x*QK4_0;

    for (int64_t elem = first; elem < ne00; elem += step) {
        device const float * vals = (device const float *) (src_row + elem*nb00);
        device block_q4_0  * blk  = dst_blocks + elem/QK4_0;

        // find the value with the largest magnitude, remembering its sign
        float best_abs = 0.0f;
        float best_val = 0.0f;

        for (int k = 0; k < QK4_0; ++k) {
            const float cur = vals[k];
            const float mag = fabs(cur);
            if (mag > best_abs) {
                best_abs = mag;
                best_val = cur;
            }
        }

        const float scale     = best_val / -8;
        const float inv_scale = scale ? 1.0f/scale : 0.0f; // an all-zero block keeps a zero scale

        blk->d = scale;

        const int half_block = QK4_0/2;

        for (int k = 0; k < half_block; ++k) {
            const float lo_scaled = vals[k]*inv_scale;
            const float hi_scaled = vals[half_block + k]*inv_scale;

            // +8.5 shifts into the unsigned range and rounds; MIN caps the result at 15
            const uint8_t lo_q = MIN(15, (int8_t)(lo_scaled + 8.5f));
            const uint8_t hi_q = MIN(15, (int8_t)(hi_scaled + 8.5f));

            // low nibble: element k, high nibble: element k + QK4_0/2
            blk->qs[k] = lo_q | (hi_q << 4);
        }
    }
}
| 1588 | + |
1521 | 1589 | kernel void kernel_concat(
|
1522 | 1590 | device const char * src0,
|
1523 | 1591 | device const char * src1,
|
|
0 commit comments