
Commit a9aedf4

SYCL: Implement fused kernel GEGLU, SWIGLU and REGLU for single up+gate
1 parent: 34d1aed

File tree

3 files changed (+254, −0 lines)

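For reference (not part of the diff): all three fused kernels read a single src0 tensor whose rows hold the activation half x and the gate half g side by side, and write one output element per pair:

\[
\mathrm{REGLU}(x, g) = \max(x, 0)\,g, \qquad
\mathrm{SWIGLU}(x, g) = \frac{x}{1 + e^{-x}}\,g, \qquad
\mathrm{GEGLU}(x, g) \approx x\,\sigma(1.702\,x)\,g,
\]

where \(\sigma\) is the logistic sigmoid. The GEGLU kernel uses the "quick" GELU approximation \(x\,\sigma(1.702x)\) (this is what the GELU_QUICK_COEF = -1.702f constant added to element_wise.hpp encodes), rather than the exact erf-based GELU.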

ggml/src/ggml-sycl/element_wise.cpp

Lines changed: 221 additions & 0 deletions
@@ -1,6 +1,9 @@
 #include "common.hpp"
+#include "ggml-sycl/presets.hpp"
 #include "ggml.h"
 #include "element_wise.hpp"
+#include <cstddef>
+#include <cstdint>
 
 static void acc_f32(const float * x, const float * y, float * dst, const int ne,
                     const int ne10, const int ne11, const int ne12,
@@ -324,6 +327,34 @@ static void clamp(const T * x, T * dst, const float min, const float max, const
     dst[i] = x[i] < static_cast<T>(min) ? static_cast<T>(min) : (x[i] > static_cast<T>(max) ? static_cast<T>(max) : x[i]);
 }
 
+// Fused GLU kernels
+template<typename T>
+static void gated_op_fused_geglu(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o, const sycl::nd_item<1> &item_ct1) {
+    for (auto i = item_ct1.get_global_id(0); i < k; i += item_ct1.get_global_range(0)) {
+        const int64_t j = ((i / n) * o) + (i % n);
+        const T x_val = x[j];
+        const T gelu_val = x_val * (static_cast<T>(1.0f) / (static_cast<T>(1.0f) + sycl::native::exp(GELU_QUICK_COEF * x_val)));
+
+        dst[i] = gelu_val * g[j];
+    }
+}
+
+template<typename T>
+static void gated_op_fused_reglu(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o, const sycl::nd_item<1> &item_ct1) {
+    for (auto i = item_ct1.get_global_id(0); i < k; i += item_ct1.get_global_range(0)) {
+        const int64_t j = ((i / n) * o) + (i % n);
+        dst[i] = sycl::max((x[j]), static_cast<T>(0)) * g[j];
+    }
+}
+
+template<typename T>
+static void gated_op_fused_swiglu(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o, const sycl::nd_item<1> &item_ct1) {
+    for (auto i = item_ct1.get_global_id(0); i < k; i += item_ct1.get_global_range(0)) {
+        const int64_t j = ((i / n) * o) + (i % n);
+        dst[i] = (x[j] / (static_cast<T>(1) + sycl::native::exp(-x[j]))) * g[j];
+    }
+}
+
 static void acc_f32_sycl(const float *x, const float *y, float *dst,
                          const int n_elements, const int ne10, const int ne11,
                          const int ne12, const int nb1, const int nb2,
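The index arithmetic above is the heart of the fusion: dst is n columns wide while each src0 row is o elements long (2·n for a contiguous tensor), so a flat output index i is remapped to j = (i / n) * o + (i % n) before reading both halves. A minimal CPU sketch of the same mapping for the REGLU case, assuming the non-swapped layout (illustrative only; the helper name is not part of the commit):

// CPU reference of the fused-GLU index mapping (illustrative sketch, not part of the commit).
// Each src0 row holds [x_0 .. x_{n-1} | g_0 .. g_{n-1}], so o == 2*n when src0 is contiguous.
#include <algorithm>
#include <cstdint>

static void reglu_reference(const float * src0, float * dst,
                            uint64_t k,   // total number of output elements
                            uint64_t n,   // output row width (nc)
                            uint64_t o) { // src0 row stride in elements
    const float * x = src0;      // activation half (swapped == 0)
    const float * g = src0 + n;  // gate half
    for (uint64_t i = 0; i < k; ++i) {
        const uint64_t j = (i / n) * o + (i % n); // flat dst index -> src0 row/column
        dst[i] = std::max(x[j], 0.0f) * g[j];
    }
}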
@@ -589,6 +620,33 @@ static void clamp_sycl(const T *x, T *dst, const float min,
         [=](sycl::nd_item<3> item_ct1) { clamp(x, dst, min, max, k, item_ct1); });
 }
 
+template<typename T>
+static void geglu_sycl(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o, queue_ptr main_stream) {
+    const uint32_t num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE);
+    main_stream->parallel_for(
+            sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
+                gated_op_fused_geglu(x, g, dst, k, n, o, item_ct1);
+            });
+}
+
+template<typename T>
+static void reglu_sycl(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o, queue_ptr main_stream) {
+    const uint32_t num_blocks = ceil_div(k, SYCL_RELU_BLOCK_SIZE);
+    main_stream->parallel_for(
+            sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_RELU_BLOCK_SIZE)), sycl::range<1>(SYCL_RELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
+                gated_op_fused_reglu(x, g, dst, k, n, o, item_ct1);
+            });
+}
+
+template<typename T>
+static void swiglu_sycl(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o, queue_ptr main_stream) {
+    const uint32_t num_blocks = ceil_div(k, SYCL_SILU_BLOCK_SIZE);
+    main_stream->parallel_for(
+            sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_SILU_BLOCK_SIZE)), sycl::range<1>(SYCL_SILU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
+                gated_op_fused_swiglu(x, g, dst, k, n, o, item_ct1);
+            });
+}
+
 inline void ggml_sycl_op_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
 #if defined (GGML_SYCL_F16)
     GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
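Each launcher rounds the element count up to a whole number of work-groups and launches a 1-D nd_range; because the kernels iterate with a global-range stride, work-items beyond k simply perform no iterations. The equivalent host-side arithmetic, spelled out as a sketch (the 256 block size is an assumption; the real values come from the SYCL_*_BLOCK_SIZE constants in presets.hpp):

// Launch-geometry sketch (illustrative): the global size is the smallest multiple
// of the work-group size that covers k elements.
#include <cstdint>

constexpr uint64_t kBlockSize = 256; // assumed; the backend uses SYCL_*_BLOCK_SIZE

static uint64_t glu_global_size(uint64_t k) {
    const uint64_t num_blocks = (k + kBlockSize - 1) / kBlockSize; // same as ceil_div(k, kBlockSize)
    return num_blocks * kBlockSize; // total work-items in the 1-D nd_range
}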
@@ -1384,6 +1442,152 @@ inline void ggml_sycl_op_acc(ggml_backend_sycl_context & ctx, ggml_tensor *dst)
     acc_f32_sycl(src0_dd, src1_dd, dst_dd, ggml_nelements(dst), dst->src[1]->ne[0], dst->src[1]->ne[1], dst->src[1]->ne[2], nb1, nb2, offset, main_stream);
 }
 
+inline void ggml_sycl_op_geglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+#if defined (GGML_SYCL_F16)
+    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
+
+#else
+    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+#endif
+    GGML_ASSERT(dst->src[0]->type == dst->type);
+    dpct::queue_ptr main_stream = ctx.stream();
+    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
+    const int64_t nc = dst->src[0]->ne[0] / 2;
+    GGML_ASSERT(dst->ne[0] == nc);
+    GGML_ASSERT(ggml_is_contiguous_1(dst->src[0]));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+    const int32_t swapped = ((const int32_t *) dst->op_params)[1];
+    const void * src0_d = dst->src[0]->data;
+    void * dst_d = dst->data;
+    switch (dst->type) {
+#if defined (GGML_SYCL_F16)
+        case GGML_TYPE_F16:
+            {
+                geglu_sycl((const sycl::half *) src0_d + (swapped ? nc : 0),
+                           (const sycl::half *) src0_d + (swapped ? 0 : nc),
+                           (sycl::half *) dst_d,
+                           ggml_nelements(dst),
+                           nc,
+                           dst->src[0]->nb[1] / sizeof(sycl::half),
+                           main_stream);
+                break;
+            }
+#endif
+        case GGML_TYPE_F32:
+            {
+                geglu_sycl((const float *) src0_d + (swapped ? nc : 0),
+                           (const float *) src0_d + (swapped ? 0 : nc),
+                           (float *) dst_d,
+                           ggml_nelements(dst),
+                           nc,
+                           dst->src[0]->nb[1] / sizeof(float),
+                           main_stream);
+                break;
+            }
+        default:
+            GGML_ABORT("GGML tensor type not supported!\n");
+    }
+}
+
+inline void ggml_sycl_op_reglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+#if defined (GGML_SYCL_F16)
+    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
+
+#else
+    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+#endif
+    GGML_ASSERT(dst->src[0]->type == dst->type);
+    dpct::queue_ptr main_stream = ctx.stream();
+    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
+    const int64_t nc = dst->src[0]->ne[0] / 2;
+    GGML_ASSERT(dst->ne[0] == nc);
+    GGML_ASSERT(ggml_is_contiguous_1(dst->src[0]));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+    const int32_t swapped = ((const int32_t *) dst->op_params)[1];
+    const void * src0_d = dst->src[0]->data;
+    void * dst_d = dst->data;
+    switch (dst->type) {
+#if defined (GGML_SYCL_F16)
+        case GGML_TYPE_F16:
+            {
+                reglu_sycl((const sycl::half *) src0_d + (swapped ? nc : 0),
+                           (const sycl::half *) src0_d + (swapped ? 0 : nc),
+                           (sycl::half *) dst_d,
+                           ggml_nelements(dst),
+                           nc,
+                           dst->src[0]->nb[1] / sizeof(sycl::half),
+                           main_stream);
+                break;
+            }
+#endif
+        case GGML_TYPE_F32:
+            {
+                reglu_sycl((const float *) src0_d + (swapped ? nc : 0),
+                           (const float *) src0_d + (swapped ? 0 : nc),
+                           (float *) dst_d,
+                           ggml_nelements(dst),
+                           nc,
+                           dst->src[0]->nb[1] / sizeof(float),
+                           main_stream);
+                break;
+            }
+        default:
+            GGML_ABORT("GGML tensor type not supported!\n");
+    }
+}
+
+inline void ggml_sycl_op_swiglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+#if defined (GGML_SYCL_F16)
+    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
+
+#else
+    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+#endif
+    GGML_ASSERT(dst->src[0]->type == dst->type);
+    dpct::queue_ptr main_stream = ctx.stream();
+    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
+    const int64_t nc = dst->src[0]->ne[0] / 2;
+    GGML_ASSERT(dst->ne[0] == nc);
+    GGML_ASSERT(ggml_is_contiguous_1(dst->src[0]));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+    const int32_t swapped = ((const int32_t *) dst->op_params)[1];
+    const void * src0_d = dst->src[0]->data;
+    void * dst_d = dst->data;
+    switch (dst->type) {
+#if defined (GGML_SYCL_F16)
+        case GGML_TYPE_F16:
+            {
+                swiglu_sycl((const sycl::half *) src0_d + (swapped ? nc : 0),
+                            (const sycl::half *) src0_d + (swapped ? 0 : nc),
+                            (sycl::half *) dst_d,
+                            ggml_nelements(dst),
+                            nc,
+                            dst->src[0]->nb[1] / sizeof(sycl::half),
+                            main_stream);
+                break;
+            }
+#endif
+        case GGML_TYPE_F32:
+            {
+                swiglu_sycl((const float *) src0_d + (swapped ? nc : 0),
+                            (const float *) src0_d + (swapped ? 0 : nc),
+                            (float *) dst_d,
+                            ggml_nelements(dst),
+                            nc,
+                            dst->src[0]->nb[1] / sizeof(float),
+                            main_stream);
+                break;
+            }
+        default:
+            GGML_ABORT("GGML tensor type not supported!\n");
+    }
+}
 
 void ggml_sycl_sqrt(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
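The three ggml_sycl_op_* wrappers are identical apart from the launcher they call: each asserts matching F32/F16 types and contiguity, takes nc = ne[0] / 2 output columns, reads the swapped flag from op_params[1], and offsets the two halves of src0 accordingly. A condensed sketch of that pointer selection (illustrative only; the helper is not part of the commit):

// Illustrative: deriving the activation/gate pointers from the single fused src0 buffer.
#include <cstdint>
#include <utility>

template <typename T>
static std::pair<const T *, const T *> glu_split(const T * src0, int64_t nc, bool swapped) {
    const T * x = src0 + (swapped ? nc : 0);  // activation half
    const T * g = src0 + (swapped ? 0  : nc); // gate half
    return { x, g };
}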
@@ -1509,3 +1713,20 @@ void ggml_sycl_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
     ggml_sycl_op_elu(ctx, dst);
 }
+
+void ggml_sycl_geglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
+    ggml_sycl_op_geglu(ctx, dst);
+}
+
+void ggml_sycl_reglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
+    ggml_sycl_op_reglu(ctx, dst);
+}
+
+void ggml_sycl_swiglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
+    ggml_sycl_op_swiglu(ctx, dst);
+}
+

ggml/src/ggml-sycl/element_wise.hpp

Lines changed: 8 additions & 0 deletions
@@ -24,6 +24,9 @@ typed_data<T> cast_data(ggml_tensor * dst) {
     };
 }
 
+const float GELU_QUICK_COEF = -1.702f;
+
+
 void ggml_sycl_sqrt(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
 
 void ggml_sycl_sin(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
@@ -73,5 +76,10 @@ void ggml_sycl_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
 void ggml_sycl_abs(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
 
 void ggml_sycl_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
+
+void ggml_sycl_geglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
+void ggml_sycl_reglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
+void ggml_sycl_swiglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
+
 #endif // GGML_SYCL_ELEMENTWISE_HPP

ggml/src/ggml-sycl/ggml-sycl.cpp

Lines changed: 25 additions & 0 deletions
@@ -3678,6 +3678,21 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
                 return false;
             }
             break;
+        case GGML_OP_GLU:
+            switch (ggml_get_glu_op(dst)) {
+                case GGML_GLU_OP_REGLU:
+                    ggml_sycl_reglu(ctx, dst);
+                    break;
+                case GGML_GLU_OP_GEGLU:
+                    ggml_sycl_geglu(ctx, dst);
+                    break;
+                case GGML_GLU_OP_SWIGLU:
+                    ggml_sycl_swiglu(ctx, dst);
+                    break;
+                default:
+                    return false;
+            }
+            break;
         case GGML_OP_NORM:
             ggml_sycl_norm(ctx, dst);
             break;
@@ -4214,6 +4229,16 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
             default:
                 return false;
         }
+        case GGML_OP_GLU:
+            switch (ggml_get_glu_op(op)) {
+                case GGML_GLU_OP_REGLU:
+                case GGML_GLU_OP_GEGLU:
+                case GGML_GLU_OP_SWIGLU:
+                    return ggml_is_contiguous_1(op->src[0]);
+                default:
+                    return false;
+            }
+            break;
         case GGML_OP_MUL_MAT:
         case GGML_OP_MUL_MAT_ID:
         {
