Skip to content

SYCL: Kernel function refactor #11515

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 40 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
2d72bd9
SYCL: remove ggml_sycl_op_flatten function
qnixsynapse Jan 30, 2025
957c11b
binbcast: use void pointer to prevent intermediate type conversions
qnixsynapse Jan 31, 2025
108be39
binbcast: move to a separate file
qnixsynapse Jan 31, 2025
e1326a7
binbcast: add try catch sycl::exception
qnixsynapse Jan 31, 2025
fa7c4d8
Fix GGML_SYCL_DEBUG in kernels in other files
qnixsynapse Jan 31, 2025
95a09ab
ARGMAX: move to a separate file
qnixsynapse Feb 1, 2025
5288bd5
Argsort: move to a separate file
qnixsynapse Feb 1, 2025
a153f19
ggml_sycl_compute_forward: fixup function calling names and remove co…
qnixsynapse Feb 1, 2025
51bedb8
argmax: move missing function to file and fix function name
qnixsynapse Feb 1, 2025
3a34659
argsort: add a space at the end of file
qnixsynapse Feb 1, 2025
aaf9ed0
Add spaces
qnixsynapse Feb 1, 2025
a16b6b7
eltwise: sort includes
qnixsynapse Feb 1, 2025
ecacff3
CPY: move to a separate file
qnixsynapse Feb 1, 2025
7d8d689
eltwise: add back split buffer type checks
qnixsynapse Feb 1, 2025
04d8b03
Add back split buffer type checks
qnixsynapse Feb 1, 2025
98f5fd2
getrows: move to a separate file
qnixsynapse Feb 1, 2025
8e86732
diagmask: move to a separate file
qnixsynapse Feb 1, 2025
7f2d24f
rope: add try catch sycl exception and debug log
qnixsynapse Feb 2, 2025
927925f
scale: move to a separate file
qnixsynapse Feb 2, 2025
0c319bf
DUP: move to cpy.cpp, set debug logs and adjust include
qnixsynapse Feb 2, 2025
ddc5e42
clamp: move to a separate file
qnixsynapse Feb 2, 2025
ba79258
Add spaces to end of files
qnixsynapse Feb 2, 2025
4db56d6
im2col: add try catch block and move wrapper function from ggml-sycl.cpp
qnixsynapse Feb 2, 2025
eb466d7
pool2d: move to a separate file
qnixsynapse Feb 2, 2025
5c05a3e
Move sum and sum rows to a separate file
qnixsynapse Feb 2, 2025
d31c62d
norm: add try catch sycl exception
qnixsynapse Feb 2, 2025
1ccfaae
Add sum to backend hpp
qnixsynapse Feb 2, 2025
bba4b66
concat: Handle SYCL exceptions
qnixsynapse Feb 2, 2025
6dbb7ac
softmax: handle SYCL exceptions and add debug logs
qnixsynapse Feb 2, 2025
a6a239c
norm: add a space at the end of file
qnixsynapse Feb 2, 2025
6eb30d9
Adjust EOF spaces and unused variable
qnixsynapse Feb 2, 2025
539b0c6
ggml-sycl: sort includes
qnixsynapse Feb 3, 2025
18d706a
gemm.hpp: remove unused include
qnixsynapse Feb 3, 2025
0ae9a07
ggml_sycl_op_argmax)Add debug logs to ggml_sycl_mul_ma0
qnixsynapse Feb 3, 2025
7369e54
Add back ggml_sycl_set_device to kernels
qnixsynapse Feb 3, 2025
e592637
Add remaining SYCL exception handler to kernel and refactor
qnixsynapse Feb 3, 2025
52b0652
conv: add space before eof
qnixsynapse Feb 3, 2025
0b602f0
Final touches
qnixsynapse Feb 3, 2025
efb5773
ggml-sycl: hide matrix engine info for now from print sycl devices
qnixsynapse Feb 5, 2025
cfa2cc1
Disable non-contiguous tensor support in norm kernels and add newline…
qnixsynapse Feb 5, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions ggml/include/ggml-sycl.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,6 @@
#include "ggml.h"
#include "ggml-backend.h"

#define GGML_SYCL_NAME "SYCL"
#define GGML_SYCL_MAX_DEVICES 48

#ifdef __cplusplus
extern "C" {
#endif
Expand Down
75 changes: 75 additions & 0 deletions ggml/src/ggml-sycl/argmax.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
#include "argmax.hpp"

static void argmax_f32_i32_sycl(const float * x, int * dst, const int ncols, const int nrows, queue_ptr stream) {
const sycl::range<3> block_dims(1, 1, SYCL_ARGMAX_BLOCK_SIZE);
const sycl::range<3> block_nums(1, nrows, 1);
const size_t shared_mem = 256 * sizeof(float);

stream->submit([&](sycl::handler & cgh) {
sycl::local_accessor<float, 1> shared_data(sycl::range<1>(shared_mem / sizeof(float)), cgh);
sycl::local_accessor<int, 1> shared_indices(sycl::range<1>(shared_mem / sizeof(float)), cgh);

cgh.parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
const int tid = item_ct1.get_local_id(2);
const int row = item_ct1.get_global_id(1);

float max_val = -INFINITY;
int max_idx = -1;

for (int col = tid; col < ncols; col += 256) {
float val = x[row * ncols + col];
if (val > max_val) {
max_val = val;
max_idx = col;
}
}

shared_data[tid] = max_val;
shared_indices[tid] = max_idx;
item_ct1.barrier(sycl::access::fence_space::local_space);

for (int stride = 256 / 2; stride > 0; stride >>= 1) {
if (tid < stride) {
float val1 = shared_data[tid];
float val2 = shared_data[tid + stride];
if (val2 > val1) {
shared_data[tid] = val2;
shared_indices[tid] = shared_indices[tid + stride];
}
}
item_ct1.barrier(sycl::access::fence_space::local_space);
}

if (tid == 0) {
dst[row] = shared_indices[0];
}
});
});
}

// Validates the operands of an argmax node and launches the kernel on the context's stream.
//
// ctx : backend context supplying the stream and the device id
// dst : output tensor; its first source (dst->src[0]) is the input
//
// Requires a contiguous F32 input, an I32 output, and an output buffer that is not a SYCL split
// buffer. A sycl::exception raised while doing so is printed to stderr and terminates the process
// with exit code 1.
static void ggml_sycl_op_argmax(ggml_backend_sycl_context & ctx, ggml_tensor * dst) try {
    GGML_ASSERT(ggml_is_contiguous(dst->src[0]));

    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type == GGML_TYPE_I32);
    GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(dst->buffer));

    ggml_tensor * const input = dst->src[0];

    // Row length is the innermost dimension; all remaining dimensions are counted as rows.
    const int64_t n_cols = input->ne[0];
    const int64_t n_rows = ggml_nrows(input);

    dpct::queue_ptr queue = ctx.stream();
    SYCL_CHECK(ggml_sycl_set_device(ctx.device));

    const float * in_values   = static_cast<const float *>(input->data);
    int32_t *     out_indices = static_cast<int32_t *>(dst->data);

    argmax_f32_i32_sycl(in_values, out_indices, n_cols, n_rows, queue);
} catch (const sycl::exception & exc) {
    std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl;
    std::exit(1);
}

// Public entry point for the argmax operation (declared in argmax.hpp).
//
// ctx : backend context forwarded to ggml_sycl_op_argmax
// dst : output tensor; the input is dst->src[0], which must be contiguous
//
// Emits a GGML_SYCL_DEBUG message before and after delegating to ggml_sycl_op_argmax.
void ggml_sycl_argmax(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
// Checked here before any logging; ggml_sycl_op_argmax repeats the same assertion.
GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_argmax(ctx, dst);
GGML_SYCL_DEBUG("call %s done\n", __func__);
}
8 changes: 8 additions & 0 deletions ggml/src/ggml-sycl/argmax.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#ifndef GGML_SYCL_ARGMAX_HPP
#define GGML_SYCL_ARGMAX_HPP

#include "common.hpp"

// Computes, for every row of dst->src[0], the column index of the largest value and stores it in dst.
// The implementation asserts that dst->src[0] is contiguous and of type F32 and that dst is of type I32.
void ggml_sycl_argmax(ggml_backend_sycl_context & ctx, ggml_tensor * dst);

#endif // GGML_SYCL_ARGMAX_HPP
130 changes: 130 additions & 0 deletions ggml/src/ggml-sycl/argsort.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
#include "argsort.hpp"

// Returns the smallest power of two that is >= x. Any x <= 1 (zero and negatives included) gives 1.
// The result has to fit in an int, so x must not exceed 2^30.
static int next_power_of_2(int x) {
    int pow2 = 1;
    for (; pow2 < x; pow2 <<= 1) {
        // keep doubling until the candidate reaches x
    }
    return pow2;
}

// Exchanges the values held by `a` and `b` using one temporary copy.
template <typename T>
static inline void ggml_sycl_swap(T & a, T & b) {
    T held = b;
    b = a;
    a = held;
}

// Device-side bitonic argsort of a single row.
//
// order      : compile-time sort direction (GGML_SORT_ORDER_ASC or the opposite comparison otherwise)
// x          : input values; the row handled by this work-group starts at x + row * ncols
// dst        : output indices, laid out like x; only the first ncols entries of each row are written
// ncols      : number of real columns per row
// ncols_pad  : number of sort slots; the k/j loops below assume it is a power of two >= ncols
// item_ct1   : work-item handle; local id (dim 2) selects the slot, group id (dim 1) selects the row
// dpct_local : scratch buffer reinterpreted as an int array and indexed up to ncols_pad - 1
//
// Slots whose stored index is >= ncols are padding: they are never used to read x_row, and the
// comparisons push them behind real entries so that the first ncols slots hold the sorted indices.
template <ggml_sort_order order>
__dpct_inline__ static void k_argsort_f32_i32(const float * x, int * dst, const int ncols, int ncols_pad,
                                              const sycl::nd_item<3> & item_ct1, uint8_t * dpct_local) {
    // bitonic sort
    int col = item_ct1.get_local_id(2);
    int row = item_ct1.get_group(1);

    // NOTE: this return precedes the barriers below. The launcher in this file uses exactly
    // ncols_pad work-items per group, so it is not taken there.
    if (col >= ncols_pad) {
        return;
    }

    // Compute the row offset in 64 bits so row * ncols cannot overflow int on large inputs.
    const size_t  row_offset = static_cast<size_t>(row) * ncols;
    const float * x_row      = x + row_offset;
    auto          dst_row    = reinterpret_cast<int *>(dpct_local);

    // initialize indices
    dst_row[col] = col;

    item_ct1.barrier(sycl::access::fence_space::local_space);

    for (int k = 2; k <= ncols_pad; k *= 2) {
        for (int j = k / 2; j > 0; j /= 2) {
            int ixj = col ^ j;
            // Only the lower-numbered work-item of each (col, ixj) pair performs the compare-and-swap.
            if (ixj > col) {
                if ((col & k) == 0) {
                    if (dst_row[col] >= ncols ||
                        (dst_row[ixj] < ncols &&
                         (order == GGML_SORT_ORDER_ASC ? x_row[dst_row[col]] > x_row[dst_row[ixj]] :
                                                         x_row[dst_row[col]] < x_row[dst_row[ixj]]))) {
                        ggml_sycl_swap(dst_row[col], dst_row[ixj]);
                    }
                } else {
                    if (dst_row[ixj] >= ncols ||
                        (dst_row[col] < ncols &&
                         (order == GGML_SORT_ORDER_ASC ? x_row[dst_row[col]] < x_row[dst_row[ixj]] :
                                                         x_row[dst_row[col]] > x_row[dst_row[ixj]]))) {
                        ggml_sycl_swap(dst_row[col], dst_row[ixj]);
                    }
                }
            }
            /*
            DPCT1118:1: SYCL group functions and algorithms must be encountered
            in converged control flow. You may need to adjust the code.
            */
            item_ct1.barrier(sycl::access::fence_space::local_space);
        }
    }

    // copy the result to dst without the padding
    if (col < ncols) {
        dst[row_offset + col] = dst_row[col];
    }
}

static void argsort_f32_i32_sycl(const float * x, int * dst, const int ncols, const int nrows, ggml_sort_order order,
queue_ptr stream) {
// bitonic sort requires ncols to be power of 2
const int ncols_pad = next_power_of_2(ncols);

const sycl::range<3> block_dims(1, 1, ncols_pad);
const sycl::range<3> block_nums(1, nrows, 1);
const size_t shared_mem = ncols_pad * sizeof(int);

if (order == GGML_SORT_ORDER_ASC) {
stream->submit([&](sycl::handler & cgh) {
sycl::local_accessor<uint8_t, 1> dpct_local_acc_ct1(sycl::range<1>(shared_mem), cgh);

cgh.parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
k_argsort_f32_i32<GGML_SORT_ORDER_ASC>(
x, dst, ncols, ncols_pad, item_ct1,
dpct_local_acc_ct1.get_multi_ptr<sycl::access::decorated::no>().get());
});
});
} else if (order == GGML_SORT_ORDER_DESC) {
stream->submit([&](sycl::handler & cgh) {
sycl::local_accessor<uint8_t, 1> dpct_local_acc_ct1(sycl::range<1>(shared_mem), cgh);

cgh.parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
k_argsort_f32_i32<GGML_SORT_ORDER_DESC>(
x, dst, ncols, ncols_pad, item_ct1,
dpct_local_acc_ct1.get_multi_ptr<sycl::access::decorated::no>().get());
});
});
} else {
GGML_ABORT("fatal error");
}
}

// Validates the operands of an argsort node and launches the kernel on the context's stream.
//
// ctx : backend context supplying the stream and the device id
// dst : output tensor; the input is dst->src[0] and the sort order is read from dst->op_params[0]
//
// Requires an F32 input, an I32 output, and an output buffer that is not a SYCL split buffer.
// A sycl::exception raised while doing so is printed to stderr and terminates the process with
// exit code 1.
inline void ggml_sycl_op_argsort(ggml_backend_sycl_context & ctx, ggml_tensor * dst) try {
    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type == GGML_TYPE_I32);
    GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(dst->buffer));

    ggml_tensor * const input = dst->src[0];

    // Row length is the innermost dimension; all remaining dimensions are counted as rows.
    const int64_t n_cols = input->ne[0];
    const int64_t n_rows = ggml_nrows(input);

    const auto sort_order = static_cast<ggml_sort_order>(dst->op_params[0]);

    dpct::queue_ptr queue = ctx.stream();
    SYCL_CHECK(ggml_sycl_set_device(ctx.device));

    const float * in_values   = static_cast<const float *>(input->data);
    int32_t *     out_indices = static_cast<int32_t *>(dst->data);

    argsort_f32_i32_sycl(in_values, out_indices, n_cols, n_rows, sort_order, queue);
} catch (const sycl::exception & exc) {
    std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl;
    std::exit(1);
}

// Public entry point for the argsort operation (declared in argsort.hpp).
//
// ctx : backend context forwarded to ggml_sycl_op_argsort
// dst : output tensor; the input is dst->src[0], which must be contiguous
//
// Emits a GGML_SYCL_DEBUG message before and after delegating to ggml_sycl_op_argsort.
void ggml_sycl_argsort(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
// The contiguity requirement is enforced here; ggml_sycl_op_argsort does not re-check it.
GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_argsort(ctx, dst);
GGML_SYCL_DEBUG("call %s done\n", __func__);
}
8 changes: 8 additions & 0 deletions ggml/src/ggml-sycl/argsort.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#ifndef GGML_SYCL_ARGSORT_HPP
#define GGML_SYCL_ARGSORT_HPP

#include "common.hpp"

// Sorts the column indices of every row of dst->src[0] by value and stores them in dst.
// The implementation asserts that dst->src[0] is contiguous and of type F32 and that dst is of type I32;
// the sort order is taken from dst->op_params[0].
void ggml_sycl_argsort(ggml_backend_sycl_context & ctx, ggml_tensor * dst);

#endif // GGML_SYCL_ARGSORT_HPP
10 changes: 10 additions & 0 deletions ggml/src/ggml-sycl/backend.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,16 @@
#include "wkv6.hpp"
#include "outprod.hpp"
#include "element_wise.hpp"
#include "binbcast.hpp"
#include "argmax.hpp"
#include "argsort.hpp"
#include "cpy.hpp"
#include "getrows.hpp"
#include "diagmask.hpp"
#include "scale.hpp"
#include "clamp.hpp"
#include "pool2d.hpp"
#include "sum.hpp"
#include "gla.hpp"

#endif // GGML_SYCL_BACKEND_HPP
Loading
Loading