Skip to content

SYCL: add gelu_erf kernel #13749

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 27, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions ggml/src/ggml-sycl/element_wise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,15 @@ static void gelu_quick(const T *x, T *dst, int k,
dst[i] = x[i] * (static_cast<T>(1.0f) / (static_cast<T>(1.0f) + sycl::native::exp(GELU_QUICK_COEF * x[i])));
}

template<typename T>
static void gelu_erf(const T * x, T * dst, const int k, const sycl::nd_item<3> &item_ct1) {
const T SQRT_2_INV = static_cast<T>(0.70710678118654752440084436210484f);
for(auto i = item_ct1.get_global_id(2); i < (const size_t)k; i += item_ct1.get_global_range(2)) {
auto x_i = x[i];
dst[i] = static_cast<T>(0.5f) * x_i * (static_cast<T>(1.0f) + sycl::erf(x_i * SQRT_2_INV));
}
}

template<typename T>
static void tanh(const T *x, T *dst, int k,
const sycl::nd_item<3> &item_ct1) {
Expand Down Expand Up @@ -400,6 +409,20 @@ static void gelu_quick_sycl(const T *x, T *dst, const int k,
});
}


template<typename T>
static void gelu_erf_sycl(const T *x, T *dst, const int k,
queue_ptr stream) {
const int num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE);
stream->parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) {
gelu_erf(x, dst, k, item_ct1);
});
}

template<typename T>
static void tanh_sycl(const T *x, T *dst, const int k,
queue_ptr stream) {
Expand Down Expand Up @@ -816,6 +839,38 @@ inline void ggml_sycl_op_gelu_quick(ggml_backend_sycl_context & ctx, ggml_tensor
}
}

inline void ggml_sycl_op_gelu_erf(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
#if defined (GGML_SYCL_F16)
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
#else
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
#endif
GGML_ASSERT(dst->src[0]->type == dst->type);
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
switch (dst->type) {
#if defined (GGML_SYCL_F16)
case GGML_TYPE_F16:
{
auto data_pts = cast_data<sycl::half>(dst);
gelu_erf_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
break;
}
#endif
case GGML_TYPE_F32:
{
auto data_pts = cast_data<float>(dst);
gelu_erf_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
break;
}
default:
GGML_ABORT("GGML tensor type not supported!\n");
}
}


inline void ggml_sycl_op_tanh(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
#if defined (GGML_SYCL_F16)
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
Expand Down Expand Up @@ -1425,6 +1480,11 @@ void ggml_sycl_gelu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
ggml_sycl_op_gelu_quick(ctx, dst);
}

void ggml_sycl_gelu_erf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
ggml_sycl_op_gelu_erf(ctx, dst);
}

void ggml_sycl_tanh(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
ggml_sycl_op_tanh(ctx, dst);
Expand Down
2 changes: 2 additions & 0 deletions ggml/src/ggml-sycl/element_wise.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ void ggml_sycl_silu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);

void ggml_sycl_gelu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst);

void ggml_sycl_gelu_erf(ggml_backend_sycl_context & ctx, ggml_tensor * dst);

void ggml_sycl_tanh(ggml_backend_sycl_context & ctx, ggml_tensor * dst);

void ggml_sycl_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
Expand Down
4 changes: 4 additions & 0 deletions ggml/src/ggml-sycl/ggml-sycl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3543,6 +3543,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
case GGML_UNARY_OP_GELU_QUICK:
ggml_sycl_gelu_quick(ctx, dst);
break;
case GGML_UNARY_OP_GELU_ERF:
ggml_sycl_gelu_erf(ctx, dst);
break;
case GGML_UNARY_OP_TANH:
ggml_sycl_tanh(ctx, dst);
break;
Expand Down Expand Up @@ -4096,6 +4099,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_UNARY_OP_HARDSIGMOID:
case GGML_UNARY_OP_HARDSWISH:
case GGML_UNARY_OP_GELU_QUICK:
case GGML_UNARY_OP_GELU_ERF:
case GGML_UNARY_OP_TANH:
case GGML_UNARY_OP_EXP:
case GGML_UNARY_OP_SGN:
Expand Down
Loading