 #include "common.hpp"
+#include "ggml-sycl/presets.hpp"
 #include "ggml.h"
 #include "element_wise.hpp"
+#include <cstddef>
+#include <cstdint>
 
 static void acc_f32(const float * x, const float * y, float * dst, const int ne,
                     const int ne10, const int ne11, const int ne12,
@@ -324,6 +327,34 @@ static void clamp(const T * x, T * dst, const float min, const float max, const
     dst[i] = x[i] < static_cast<T>(min) ? static_cast<T>(min) : (x[i] > static_cast<T>(max) ? static_cast<T>(max) : x[i]);
 }
 
+// Fused GLU kernels
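+// Shared conventions for the fused kernels below: k is the total number of output
+// elements, n is the output row width (nc), and o is the row stride of src0 in
+// elements. j remaps the flat output index i to the same (row, column) position in
+// the source, with x and g pointing at the two packed halves selected by the caller.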
+template<typename T>
+static void gated_op_fused_geglu(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o, const sycl::nd_item<1> &item_ct1) {
+    for (auto i = item_ct1.get_global_id(0); i < k; i += item_ct1.get_global_range(0)) {
+        const int64_t j = ((i / n) * o) + (i % n);
+        const T x_val = x[j];
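+        // Quick (sigmoid-based) GELU approximation: x * sigmoid(-GELU_QUICK_COEF * x).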
+        const T gelu_val = x_val * (static_cast<T>(1.0f) / (static_cast<T>(1.0f) + sycl::native::exp(GELU_QUICK_COEF * x_val)));
+
+        dst[i] = gelu_val * g[j];
+    }
+}
+
+template<typename T>
+static void gated_op_fused_reglu(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o, const sycl::nd_item<1> &item_ct1) {
+    for (auto i = item_ct1.get_global_id(0); i < k; i += item_ct1.get_global_range(0)) {
+        const int64_t j = ((i / n) * o) + (i % n);
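+        // ReGLU: relu(x) * g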
+        dst[i] = sycl::max((x[j]), static_cast<T>(0)) * g[j];
+    }
+}
+
+template<typename T>
+static void gated_op_fused_swiglu(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o, const sycl::nd_item<1> &item_ct1) {
+    for (auto i = item_ct1.get_global_id(0); i < k; i += item_ct1.get_global_range(0)) {
+        const int64_t j = ((i / n) * o) + (i % n);
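+        // SwiGLU: silu(x) * g, with silu(x) = x * sigmoid(x)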
+        dst[i] = (x[j] / (static_cast<T>(1) + sycl::native::exp(-x[j]))) * g[j];
+    }
+}
+
 static void acc_f32_sycl(const float *x, const float *y, float *dst,
                          const int n_elements, const int ne10, const int ne11,
                          const int ne12, const int nb1, const int nb2,
@@ -589,6 +620,33 @@ static void clamp_sycl(const T *x, T *dst, const float min,
         [=](sycl::nd_item<3> item_ct1) { clamp(x, dst, min, max, k, item_ct1); });
 }
 
+template<typename T>
+static void geglu_sycl(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o, queue_ptr main_stream) {
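+    // Launch enough fixed-size work-groups to give every output element a work-item;
+    // the grid-stride loop in the kernel keeps this correct for any launch size.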
+    const uint32_t num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE);
+    main_stream->parallel_for(
+        sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
+            gated_op_fused_geglu(x, g, dst, k, n, o, item_ct1);
+        });
+}
+
+template<typename T>
+static void reglu_sycl(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o, queue_ptr main_stream) {
+    const uint32_t num_blocks = ceil_div(k, SYCL_RELU_BLOCK_SIZE);
+    main_stream->parallel_for(
+        sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_RELU_BLOCK_SIZE)), sycl::range<1>(SYCL_RELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
+            gated_op_fused_reglu(x, g, dst, k, n, o, item_ct1);
+        });
+}
+
+template<typename T>
+static void swiglu_sycl(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o, queue_ptr main_stream) {
+    const uint32_t num_blocks = ceil_div(k, SYCL_SILU_BLOCK_SIZE);
+    main_stream->parallel_for(
+        sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_SILU_BLOCK_SIZE)), sycl::range<1>(SYCL_SILU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
+            gated_op_fused_swiglu(x, g, dst, k, n, o, item_ct1);
+        });
+}
+
 inline void ggml_sycl_op_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
 #if defined (GGML_SYCL_F16)
     GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
@@ -1384,6 +1442,152 @@ inline void ggml_sycl_op_acc(ggml_backend_sycl_context & ctx, ggml_tensor *dst)
     acc_f32_sycl(src0_dd, src1_dd, dst_dd, ggml_nelements(dst), dst->src[1]->ne[0], dst->src[1]->ne[1], dst->src[1]->ne[2], nb1, nb2, offset, main_stream);
 }
 
+inline void ggml_sycl_op_geglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+#if defined (GGML_SYCL_F16)
+    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
+
+#else
+    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+#endif
+    GGML_ASSERT(dst->src[0]->type == dst->type);
+    dpct::queue_ptr main_stream = ctx.stream();
+    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
+    const int64_t nc = dst->src[0]->ne[0] / 2;
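+    // src0 packs both GLU halves along ne[0]; each half (and the output row) is nc columns wide.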
+    GGML_ASSERT(dst->ne[0] == nc);
+    GGML_ASSERT(ggml_is_contiguous_1(dst->src[0]));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+    const int32_t swapped = ((const int32_t *) dst->op_params)[1];
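+    // op_params[1] != 0 swaps the two halves: the gate then comes first and the activation second.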
+    const void * src0_d = dst->src[0]->data;
+    void * dst_d = dst->data;
+    switch (dst->type) {
+#if defined (GGML_SYCL_F16)
+        case GGML_TYPE_F16:
+            {
+                geglu_sycl((const sycl::half *) src0_d + (swapped ? nc : 0),
+                           (const sycl::half *) src0_d + (swapped ? 0 : nc),
+                           (sycl::half *) dst_d,
+                           ggml_nelements(dst),
+                           nc,
+                           dst->src[0]->nb[1] / sizeof(sycl::half),
+                           main_stream);
+                break;
+            }
+#endif
+        case GGML_TYPE_F32:
+            {
+                geglu_sycl((const float *) src0_d + (swapped ? nc : 0),
+                           (const float *) src0_d + (swapped ? 0 : nc),
+                           (float *) dst_d,
+                           ggml_nelements(dst),
+                           nc,
+                           dst->src[0]->nb[1] / sizeof(float),
+                           main_stream);
+                break;
+            }
+        default:
+            GGML_ABORT("GGML tensor type not supported!\n");
+    }
+}
+
+inline void ggml_sycl_op_reglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+#if defined (GGML_SYCL_F16)
+    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
+
+#else
+    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+#endif
+    GGML_ASSERT(dst->src[0]->type == dst->type);
+    dpct::queue_ptr main_stream = ctx.stream();
+    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
+    const int64_t nc = dst->src[0]->ne[0] / 2;
+    GGML_ASSERT(dst->ne[0] == nc);
+    GGML_ASSERT(ggml_is_contiguous_1(dst->src[0]));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+    const int32_t swapped = ((const int32_t *) dst->op_params)[1];
+    const void * src0_d = dst->src[0]->data;
+    void * dst_d = dst->data;
+    switch (dst->type) {
+#if defined (GGML_SYCL_F16)
+        case GGML_TYPE_F16:
+            {
+                reglu_sycl((const sycl::half *) src0_d + (swapped ? nc : 0),
+                           (const sycl::half *) src0_d + (swapped ? 0 : nc),
+                           (sycl::half *) dst_d,
+                           ggml_nelements(dst),
+                           nc,
+                           dst->src[0]->nb[1] / sizeof(sycl::half),
+                           main_stream);
+                break;
+            }
+#endif
+        case GGML_TYPE_F32:
+            {
+                reglu_sycl((const float *) src0_d + (swapped ? nc : 0),
+                           (const float *) src0_d + (swapped ? 0 : nc),
+                           (float *) dst_d,
+                           ggml_nelements(dst),
+                           nc,
+                           dst->src[0]->nb[1] / sizeof(float),
+                           main_stream);
+                break;
+            }
+        default:
+            GGML_ABORT("GGML tensor type not supported!\n");
+    }
+}
+
+inline void ggml_sycl_op_swiglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+#if defined (GGML_SYCL_F16)
+    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
+
+#else
+    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+#endif
+    GGML_ASSERT(dst->src[0]->type == dst->type);
+    dpct::queue_ptr main_stream = ctx.stream();
+    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
+    const int64_t nc = dst->src[0]->ne[0] / 2;
+    GGML_ASSERT(dst->ne[0] == nc);
+    GGML_ASSERT(ggml_is_contiguous_1(dst->src[0]));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+    const int32_t swapped = ((const int32_t *) dst->op_params)[1];
+    const void * src0_d = dst->src[0]->data;
+    void * dst_d = dst->data;
+    switch (dst->type) {
+#if defined (GGML_SYCL_F16)
+        case GGML_TYPE_F16:
+            {
+                swiglu_sycl((const sycl::half *) src0_d + (swapped ? nc : 0),
+                            (const sycl::half *) src0_d + (swapped ? 0 : nc),
+                            (sycl::half *) dst_d,
+                            ggml_nelements(dst),
+                            nc,
+                            dst->src[0]->nb[1] / sizeof(sycl::half),
+                            main_stream);
+                break;
+            }
+#endif
+        case GGML_TYPE_F32:
+            {
+                swiglu_sycl((const float *) src0_d + (swapped ? nc : 0),
+                            (const float *) src0_d + (swapped ? 0 : nc),
+                            (float *) dst_d,
+                            ggml_nelements(dst),
+                            nc,
+                            dst->src[0]->nb[1] / sizeof(float),
+                            main_stream);
+                break;
+            }
+        default:
+            GGML_ABORT("GGML tensor type not supported!\n");
+    }
+}
 
 void ggml_sycl_sqrt(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
@@ -1509,3 +1713,20 @@ void ggml_sycl_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
     ggml_sycl_op_elu(ctx, dst);
 }
+
+void ggml_sycl_geglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
+    ggml_sycl_op_geglu(ctx, dst);
+}
+
+void ggml_sycl_reglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
+    ggml_sycl_op_reglu(ctx, dst);
+}
+
+void ggml_sycl_swiglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
+    ggml_sycl_op_swiglu(ctx, dst);
+}
+
+