Skip to content

Commit 95e4be0

Browse files
committed
SYCL: Implement fused kernel GEGLU, SWIGLU and REGLU for single up+gate
1 parent 8dc1d9f commit 95e4be0

File tree

3 files changed

+254
-0
lines changed

3 files changed

+254
-0
lines changed

ggml/src/ggml-sycl/element_wise.cpp

Lines changed: 221 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
#include "common.hpp"
2+
#include "ggml-sycl/presets.hpp"
23
#include "ggml.h"
34
#include "element_wise.hpp"
5+
#include <cstddef>
6+
#include <cstdint>
47

58
static void acc_f32(const float * x, const float * y, float * dst, const int ne,
69
const int ne10, const int ne11, const int ne12,
@@ -324,6 +327,34 @@ static void clamp(const T * x, T * dst, const float min, const float max, const
324327
dst[i] = x[i] < static_cast<T>(min) ? static_cast<T>(min) : (x[i] > static_cast<T>(max) ? static_cast<T>(max) : x[i]);
325328
}
326329

330+
// Fused GLU kernels
331+
template<typename T>
332+
static void gated_op_fused_geglu(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o, const sycl::nd_item<1> &item_ct1) {
333+
for (auto i = item_ct1.get_global_id(0); i < k; i += item_ct1.get_global_range(0)) {
334+
const int64_t j = ((i / n) * o) + (i % n);
335+
const T x_val = x[j];
336+
const T gelu_val = x_val * (static_cast<T>(1.0f) / (static_cast<T>(1.0f) + sycl::native::exp(GELU_QUICK_COEF * x_val)));
337+
338+
dst[i] = gelu_val * g[j];
339+
}
340+
}
341+
342+
template<typename T>
343+
static void gated_op_fused_reglu(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o, const sycl::nd_item<1> &item_ct1) {
344+
for (auto i = item_ct1.get_global_id(0); i < k; i += item_ct1.get_global_range(0)) {
345+
const int64_t j = ((i / n) * o) + (i % n);
346+
dst[i] = sycl::max((x[j]), static_cast<T>(0)) * g[j];
347+
}
348+
}
349+
350+
template<typename T>
351+
static void gated_op_fused_swiglu(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o, const sycl::nd_item<1> &item_ct1) {
352+
for (auto i = item_ct1.get_global_id(0); i < k; i += item_ct1.get_global_range(0)) {
353+
const int64_t j = ((i / n) * o) + (i % n);
354+
dst[i] = (x[j] / (static_cast<T>(1) + sycl::native::exp(-x[j]))) * g[j];
355+
}
356+
}
357+
327358
static void acc_f32_sycl(const float *x, const float *y, float *dst,
328359
const int n_elements, const int ne10, const int ne11,
329360
const int ne12, const int nb1, const int nb2,
@@ -649,6 +680,33 @@ static void clamp_sycl(const T *x, T *dst, const float min,
649680
});
650681
}
651682

683+
template<typename T>
684+
static void geglu_sycl(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o, queue_ptr main_stream) {
685+
const uint32_t num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE);
686+
main_stream->parallel_for(
687+
sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
688+
gated_op_fused_geglu(x, g, dst, k, n, o, item_ct1);
689+
});
690+
}
691+
692+
template<typename T>
693+
static void reglu_sycl(const T * x, const T* g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o, queue_ptr main_stream) {
694+
const uint32_t num_blocks = ceil_div(k, SYCL_RELU_BLOCK_SIZE);
695+
main_stream->parallel_for(
696+
sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
697+
gated_op_fused_reglu(x, g, dst, k, n, o, item_ct1);
698+
});
699+
}
700+
701+
template<typename T>
702+
static void swiglu_sycl(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o, queue_ptr main_stream) {
703+
const uint32_t num_blocks = ceil_div(k, SYCL_SILU_BLOCK_SIZE);
704+
main_stream->parallel_for(
705+
sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
706+
gated_op_fused_swiglu(x, g, dst, k, n, o, item_ct1);
707+
});
708+
}
709+
652710
inline void ggml_sycl_op_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
653711
#if defined (GGML_SYCL_F16)
654712
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
@@ -1444,6 +1502,152 @@ inline void ggml_sycl_op_acc(ggml_backend_sycl_context & ctx, ggml_tensor *dst)
14441502
acc_f32_sycl(src0_dd, src1_dd, dst_dd, ggml_nelements(dst), dst->src[1]->ne[0], dst->src[1]->ne[1], dst->src[1]->ne[2], nb1, nb2, offset, main_stream);
14451503
}
14461504

1505+
inline void ggml_sycl_op_geglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1506+
#if defined (GGML_SYCL_F16)
1507+
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
1508+
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
1509+
1510+
#else
1511+
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
1512+
GGML_ASSERT(dst->type == GGML_TYPE_F32);
1513+
#endif
1514+
GGML_ASSERT(dst->src[0]->type == dst->type);
1515+
dpct::queue_ptr main_stream = ctx.stream();
1516+
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
1517+
const int64_t nc = dst->src[0]->ne[0] / 2;
1518+
GGML_ASSERT(dst->ne[0] == nc);
1519+
GGML_ASSERT(ggml_is_contiguous_1(dst->src[0]));
1520+
GGML_ASSERT(ggml_is_contiguous(dst));
1521+
const int32_t swapped = ((const int32_t *) dst->op_params)[1];
1522+
const void * src0_d = dst->src[0]->data;
1523+
void * dst_d = dst->data;
1524+
switch (dst->type) {
1525+
#if defined (GGML_SYCL_F16)
1526+
case GGML_TYPE_F16:
1527+
{
1528+
geglu_sycl((const sycl::half *) src0_d + (swapped ? nc : 0),
1529+
(const sycl::half *)src0_d + (swapped ? 0 : nc),
1530+
(sycl::half *) dst_d,
1531+
ggml_nelements(dst),
1532+
nc,
1533+
dst->src[0]->nb[1] / sizeof(sycl::half),
1534+
main_stream);
1535+
break;
1536+
}
1537+
#endif
1538+
case GGML_TYPE_F32:
1539+
{
1540+
geglu_sycl((const float *) src0_d + (swapped ? nc : 0),
1541+
(const float *)src0_d + (swapped ? 0 : nc),
1542+
(float *) dst_d,
1543+
ggml_nelements(dst),
1544+
nc,
1545+
dst->src[0]->nb[1] / sizeof(float),
1546+
main_stream);
1547+
break;
1548+
}
1549+
default:
1550+
GGML_ABORT("GGML tensor type not supported!\n");
1551+
}
1552+
}
1553+
1554+
inline void ggml_sycl_op_reglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1555+
#if defined (GGML_SYCL_F16)
1556+
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
1557+
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
1558+
1559+
#else
1560+
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
1561+
GGML_ASSERT(dst->type == GGML_TYPE_F32);
1562+
#endif
1563+
GGML_ASSERT(dst->src[0]->type == dst->type);
1564+
dpct::queue_ptr main_stream = ctx.stream();
1565+
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
1566+
const int64_t nc = dst->src[0]->ne[0] / 2;
1567+
GGML_ASSERT(dst->ne[0] == nc);
1568+
GGML_ASSERT(ggml_is_contiguous_1(dst->src[0]));
1569+
GGML_ASSERT(ggml_is_contiguous(dst));
1570+
const int32_t swapped = ((const int32_t *) dst->op_params)[1];
1571+
const void * src0_d = dst->src[0]->data;
1572+
void * dst_d = dst->data;
1573+
switch (dst->type) {
1574+
#if defined (GGML_SYCL_F16)
1575+
case GGML_TYPE_F16:
1576+
{
1577+
reglu_sycl((const sycl::half *) src0_d + (swapped ? nc : 0),
1578+
(const sycl::half *)src0_d + (swapped ? 0 : nc),
1579+
(sycl::half *) dst_d,
1580+
ggml_nelements(dst),
1581+
nc,
1582+
dst->src[0]->nb[1] / sizeof(sycl::half),
1583+
main_stream);
1584+
break;
1585+
}
1586+
#endif
1587+
case GGML_TYPE_F32:
1588+
{
1589+
reglu_sycl((const float *) src0_d + (swapped ? nc : 0),
1590+
(const float *)src0_d + (swapped ? 0 : nc),
1591+
(float *) dst_d,
1592+
ggml_nelements(dst),
1593+
nc,
1594+
dst->src[0]->nb[1] / sizeof(float),
1595+
main_stream);
1596+
break;
1597+
}
1598+
default:
1599+
GGML_ABORT("GGML tensor type not supported!\n");
1600+
}
1601+
}
1602+
1603+
inline void ggml_sycl_op_swiglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1604+
#if defined (GGML_SYCL_F16)
1605+
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
1606+
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
1607+
1608+
#else
1609+
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
1610+
GGML_ASSERT(dst->type == GGML_TYPE_F32);
1611+
#endif
1612+
GGML_ASSERT(dst->src[0]->type == dst->type);
1613+
dpct::queue_ptr main_stream = ctx.stream();
1614+
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
1615+
const int64_t nc = dst->src[0]->ne[0] / 2;
1616+
GGML_ASSERT(dst->ne[0] == nc);
1617+
GGML_ASSERT(ggml_is_contiguous_1(dst->src[0]));
1618+
GGML_ASSERT(ggml_is_contiguous(dst));
1619+
const int32_t swapped = ((const int32_t *) dst->op_params)[1];
1620+
const void * src0_d = dst->src[0]->data;
1621+
void * dst_d = dst->data;
1622+
switch (dst->type) {
1623+
#if defined (GGML_SYCL_F16)
1624+
case GGML_TYPE_F16:
1625+
{
1626+
swiglu_sycl((const sycl::half *) src0_d + (swapped ? nc : 0),
1627+
(const sycl::half *)src0_d + (swapped ? 0 : nc),
1628+
(sycl::half *) dst_d,
1629+
ggml_nelements(dst),
1630+
nc,
1631+
dst->src[0]->nb[1] / sizeof(sycl::half),
1632+
main_stream);
1633+
break;
1634+
}
1635+
#endif
1636+
case GGML_TYPE_F32:
1637+
{
1638+
swiglu_sycl((const float *) src0_d + (swapped ? nc : 0),
1639+
(const float *)src0_d + (swapped ? 0 : nc),
1640+
(float *) dst_d,
1641+
ggml_nelements(dst),
1642+
nc,
1643+
dst->src[0]->nb[1] / sizeof(float),
1644+
main_stream);
1645+
break;
1646+
}
1647+
default:
1648+
GGML_ABORT("GGML tensor type not supported!\n");
1649+
}
1650+
}
14471651

14481652
void ggml_sycl_sqrt(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
14491653
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
@@ -1569,3 +1773,20 @@ void ggml_sycl_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
15691773
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
15701774
ggml_sycl_op_elu(ctx, dst);
15711775
}
1776+
1777+
void ggml_sycl_geglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1778+
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
1779+
ggml_sycl_op_geglu(ctx, dst);
1780+
}
1781+
1782+
void ggml_sycl_reglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1783+
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
1784+
ggml_sycl_op_reglu(ctx, dst);
1785+
}
1786+
1787+
void ggml_sycl_swiglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1788+
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
1789+
ggml_sycl_op_swiglu(ctx, dst);
1790+
}
1791+
1792+

ggml/src/ggml-sycl/element_wise.hpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@ typed_data<T> cast_data(ggml_tensor * dst) {
2424
};
2525
}
2626

// Coefficient of the sigmoid-based "quick" GELU approximation:
// gelu_quick(x) = x * sigmoid(1.702 * x) = x / (1 + exp(-1.702 * x)).
// `inline constexpr` (C++17) gives one compile-time constant shared by every
// translation unit that includes this header, instead of a per-TU `const` copy.
inline constexpr float GELU_QUICK_COEF = -1.702f;
2730
void ggml_sycl_sqrt(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
2831

2932
void ggml_sycl_sin(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
@@ -73,5 +76,10 @@ void ggml_sycl_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
7376
void ggml_sycl_abs(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
7477

7578
void ggml_sycl_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
79+
80+
void ggml_sycl_geglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
81+
void ggml_sycl_reglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
82+
void ggml_sycl_swiglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
83+
7684
#endif // GGML_SYCL_ELEMENTWISE_HPP
7785

ggml/src/ggml-sycl/ggml-sycl.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3685,6 +3685,21 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
36853685
return false;
36863686
}
36873687
break;
3688+
case GGML_OP_GLU:
3689+
switch (ggml_get_glu_op(dst)) {
3690+
case GGML_GLU_OP_REGLU:
3691+
ggml_sycl_reglu(ctx, dst);
3692+
break;
3693+
case GGML_GLU_OP_GEGLU:
3694+
ggml_sycl_geglu(ctx, dst);
3695+
break;
3696+
case GGML_GLU_OP_SWIGLU:
3697+
ggml_sycl_swiglu(ctx, dst);
3698+
break;
3699+
default:
3700+
return false;
3701+
}
3702+
break;
36883703
case GGML_OP_NORM:
36893704
ggml_sycl_norm(ctx, dst);
36903705
break;
@@ -4221,6 +4236,16 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
42214236
default:
42224237
return false;
42234238
}
4239+
case GGML_OP_GLU:
4240+
switch (ggml_get_glu_op(op)) {
4241+
case GGML_GLU_OP_REGLU:
4242+
case GGML_GLU_OP_GEGLU:
4243+
case GGML_GLU_OP_SWIGLU:
4244+
return ggml_is_contiguous_1(op->src[0]);
4245+
default:
4246+
return false;
4247+
}
4248+
break;
42244249
case GGML_OP_MUL_MAT:
42254250
case GGML_OP_MUL_MAT_ID:
42264251
{

0 commit comments

Comments
 (0)