@@ -1,6 +1,9 @@
 #include "common.hpp"
+#include "ggml-sycl/presets.hpp"
 #include "ggml.h"
 #include "element_wise.hpp"
+#include <cstddef>
+#include <cstdint>
 
 static void acc_f32(const float * x, const float * y, float * dst, const int ne,
                     const int ne10, const int ne11, const int ne12,
@@ -324,6 +327,34 @@ static void clamp(const T * x, T * dst, const float min, const float max, const
     dst[i] = x[i] < static_cast<T>(min) ? static_cast<T>(min) : (x[i] > static_cast<T>(max) ? static_cast<T>(max) : x[i]);
 }
 
+// Fused GLU kernels
+template<typename T>
+static void gated_op_fused_geglu(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o, const sycl::nd_item<1> &item_ct1) {
+    for (auto i = item_ct1.get_global_id(0); i < k; i += item_ct1.get_global_range(0)) {
+        const int64_t j = ((i / n) * o) + (i % n);
+        const T x_val = x[j];
+        const T gelu_val = x_val * (static_cast<T>(1.0f) / (static_cast<T>(1.0f) + sycl::native::exp(GELU_QUICK_COEF * x_val)));
+
+        dst[i] = gelu_val * g[j];
+    }
+}
+
+template<typename T>
+static void gated_op_fused_reglu(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o, const sycl::nd_item<1> &item_ct1) {
+    for (auto i = item_ct1.get_global_id(0); i < k; i += item_ct1.get_global_range(0)) {
+        const int64_t j = ((i / n) * o) + (i % n);
+        dst[i] = sycl::max((x[j]), static_cast<T>(0)) * g[j];
+    }
+}
+
+template<typename T>
+static void gated_op_fused_swiglu(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o, const sycl::nd_item<1> &item_ct1) {
+    for (auto i = item_ct1.get_global_id(0); i < k; i += item_ct1.get_global_range(0)) {
+        const int64_t j = ((i / n) * o) + (i % n);
+        dst[i] = (x[j] / (static_cast<T>(1) + sycl::native::exp(-x[j]))) * g[j];
+    }
+}
+
 static void acc_f32_sycl(const float *x, const float *y, float *dst,
                          const int n_elements, const int ne10, const int ne11,
                          const int ne12, const int nb1, const int nb2,
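
Note: the three kernels above share one indexing scheme. Each destination element i belongs to row i / n and column i % n of the gated half-width output; j re-derives the matching offset inside the wider source tensor, where n is the gated width (nc at the call sites later in this diff) and o is the source row stride in elements. A minimal CPU reference sketch of that mapping, using ReGLU as the example; reglu_reference is a hypothetical helper for illustration only, not part of the patch:

    // Hypothetical CPU reference for the indexing used by the fused kernels
    // above: dst index i maps to source row i / n and column i % n, and both
    // the value half (x) and the gate half (g) are read at that offset j.
    #include <algorithm>
    #include <cstdint>

    static void reglu_reference(const float * x, const float * g, float * dst,
                                uint64_t k, uint64_t n, uint64_t o) {
        for (uint64_t i = 0; i < k; ++i) {
            const uint64_t j = (i / n) * o + (i % n);   // skip the other half of each row
            dst[i] = std::max(x[j], 0.0f) * g[j];       // ReLU(value) * gate
        }
    }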
@@ -649,6 +680,33 @@ static void clamp_sycl(const T *x, T *dst, const float min,
         });
 }
 
+template<typename T>
+static void geglu_sycl(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o, queue_ptr main_stream) {
+    const uint32_t num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE);
+    main_stream->parallel_for(
+        sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
+            gated_op_fused_geglu(x, g, dst, k, n, o, item_ct1);
+        });
+}
+
+template<typename T>
+static void reglu_sycl(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o, queue_ptr main_stream) {
+    const uint32_t num_blocks = ceil_div(k, SYCL_RELU_BLOCK_SIZE);
+    main_stream->parallel_for(
+        sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
+            gated_op_fused_reglu(x, g, dst, k, n, o, item_ct1);
+        });
+}
+
+template<typename T>
+static void swiglu_sycl(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o, queue_ptr main_stream) {
+    const uint32_t num_blocks = ceil_div(k, SYCL_SILU_BLOCK_SIZE);
+    main_stream->parallel_for(
+        sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
+            gated_op_fused_swiglu(x, g, dst, k, n, o, item_ct1);
+        });
+}
+
 inline void ggml_sycl_op_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
 #if defined (GGML_SYCL_F16)
     GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
@@ -1444,6 +1502,152 @@ inline void ggml_sycl_op_acc(ggml_backend_sycl_context & ctx, ggml_tensor *dst)
     acc_f32_sycl(src0_dd, src1_dd, dst_dd, ggml_nelements(dst), dst->src[1]->ne[0], dst->src[1]->ne[1], dst->src[1]->ne[2], nb1, nb2, offset, main_stream);
 }
 
+inline void ggml_sycl_op_geglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+#if defined (GGML_SYCL_F16)
+    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
+
+#else
+    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+#endif
+    GGML_ASSERT(dst->src[0]->type == dst->type);
+    dpct::queue_ptr main_stream = ctx.stream();
+    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
+    const int64_t nc = dst->src[0]->ne[0] / 2;
+    GGML_ASSERT(dst->ne[0] == nc);
+    GGML_ASSERT(ggml_is_contiguous_1(dst->src[0]));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+    const int32_t swapped = ((const int32_t *) dst->op_params)[1];
+    const void * src0_d = dst->src[0]->data;
+    void * dst_d = dst->data;
+    switch (dst->type) {
+#if defined (GGML_SYCL_F16)
+        case GGML_TYPE_F16:
+            {
+                geglu_sycl((const sycl::half *) src0_d + (swapped ? nc : 0),
+                           (const sycl::half *) src0_d + (swapped ? 0 : nc),
+                           (sycl::half *) dst_d,
+                           ggml_nelements(dst),
+                           nc,
+                           dst->src[0]->nb[1] / sizeof(sycl::half),
+                           main_stream);
+                break;
+            }
+#endif
+        case GGML_TYPE_F32:
+            {
+                geglu_sycl((const float *) src0_d + (swapped ? nc : 0),
+                           (const float *) src0_d + (swapped ? 0 : nc),
+                           (float *) dst_d,
+                           ggml_nelements(dst),
+                           nc,
+                           dst->src[0]->nb[1] / sizeof(float),
+                           main_stream);
+                break;
+            }
+        default:
+            GGML_ABORT("GGML tensor type not supported!\n");
+    }
+}
+
+inline void ggml_sycl_op_reglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+#if defined (GGML_SYCL_F16)
+    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
+
+#else
+    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+#endif
+    GGML_ASSERT(dst->src[0]->type == dst->type);
+    dpct::queue_ptr main_stream = ctx.stream();
+    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
+    const int64_t nc = dst->src[0]->ne[0] / 2;
+    GGML_ASSERT(dst->ne[0] == nc);
+    GGML_ASSERT(ggml_is_contiguous_1(dst->src[0]));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+    const int32_t swapped = ((const int32_t *) dst->op_params)[1];
+    const void * src0_d = dst->src[0]->data;
+    void * dst_d = dst->data;
+    switch (dst->type) {
+#if defined (GGML_SYCL_F16)
+        case GGML_TYPE_F16:
+            {
+                reglu_sycl((const sycl::half *) src0_d + (swapped ? nc : 0),
+                           (const sycl::half *) src0_d + (swapped ? 0 : nc),
+                           (sycl::half *) dst_d,
+                           ggml_nelements(dst),
+                           nc,
+                           dst->src[0]->nb[1] / sizeof(sycl::half),
+                           main_stream);
+                break;
+            }
+#endif
+        case GGML_TYPE_F32:
+            {
+                reglu_sycl((const float *) src0_d + (swapped ? nc : 0),
+                           (const float *) src0_d + (swapped ? 0 : nc),
+                           (float *) dst_d,
+                           ggml_nelements(dst),
+                           nc,
+                           dst->src[0]->nb[1] / sizeof(float),
+                           main_stream);
+                break;
+            }
+        default:
+            GGML_ABORT("GGML tensor type not supported!\n");
+    }
+}
+
+inline void ggml_sycl_op_swiglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+#if defined (GGML_SYCL_F16)
+    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
+
+#else
+    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+#endif
+    GGML_ASSERT(dst->src[0]->type == dst->type);
+    dpct::queue_ptr main_stream = ctx.stream();
+    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
+    const int64_t nc = dst->src[0]->ne[0] / 2;
+    GGML_ASSERT(dst->ne[0] == nc);
+    GGML_ASSERT(ggml_is_contiguous_1(dst->src[0]));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+    const int32_t swapped = ((const int32_t *) dst->op_params)[1];
+    const void * src0_d = dst->src[0]->data;
+    void * dst_d = dst->data;
+    switch (dst->type) {
+#if defined (GGML_SYCL_F16)
+        case GGML_TYPE_F16:
+            {
+                swiglu_sycl((const sycl::half *) src0_d + (swapped ? nc : 0),
+                            (const sycl::half *) src0_d + (swapped ? 0 : nc),
+                            (sycl::half *) dst_d,
+                            ggml_nelements(dst),
+                            nc,
+                            dst->src[0]->nb[1] / sizeof(sycl::half),
+                            main_stream);
+                break;
+            }
+#endif
+        case GGML_TYPE_F32:
+            {
+                swiglu_sycl((const float *) src0_d + (swapped ? nc : 0),
+                            (const float *) src0_d + (swapped ? 0 : nc),
+                            (float *) dst_d,
+                            ggml_nelements(dst),
+                            nc,
+                            dst->src[0]->nb[1] / sizeof(float),
+                            main_stream);
+                break;
+            }
+        default:
+            GGML_ABORT("GGML tensor type not supported!\n");
+    }
+}
 
 void ggml_sycl_sqrt(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
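
Note: all three op wrappers above set up the launch the same way: src0 holds both halves of each row (ne[0] == 2 * nc), op_params[1] ("swapped") decides which half is the value and which the gate, and the row stride passed as o comes from nb[1]. A minimal host-side sketch of that setup for the F32 path; the glu_views struct and make_glu_views are illustrative names, not part of the patch:

    // Hypothetical helper mirroring the pointer/stride setup in the wrappers
    // above (F32 case only).
    #include <cstdint>

    struct glu_views {
        const float * x;   // value half
        const float * g;   // gate half
        uint64_t      n;   // gated width: dst->ne[0] == src0->ne[0] / 2
        uint64_t      o;   // src0 row stride in elements: nb[1] / sizeof(float)
    };

    static glu_views make_glu_views(const float * src0_data, int64_t src0_ne0,
                                    uint64_t src0_nb1, bool swapped) {
        const int64_t nc = src0_ne0 / 2;
        return {
            src0_data + (swapped ? nc : 0),   // op_params[1] picks which half is the value
            src0_data + (swapped ? 0 : nc),
            (uint64_t) nc,
            src0_nb1 / sizeof(float),
        };
    }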
@@ -1569,3 +1773,20 @@ void ggml_sycl_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
     ggml_sycl_op_elu(ctx, dst);
 }
+
+void ggml_sycl_geglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
+    ggml_sycl_op_geglu(ctx, dst);
+}
+
+void ggml_sycl_reglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
+    ggml_sycl_op_reglu(ctx, dst);
+}
+
+void ggml_sycl_swiglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
+    ggml_sycl_op_swiglu(ctx, dst);
+}
+
+