Skip to content

Commit bed2c0c

Browse files
committed
add sycl_f16
1 parent 7c545a6 commit bed2c0c

File tree

3 files changed

+43
-9
lines changed

3 files changed

+43
-9
lines changed

CMakePresets.json

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
{ "name": "release", "hidden": true, "cacheVariables": { "CMAKE_BUILD_TYPE": "Release" } },
2929
{ "name": "reldbg", "hidden": true, "cacheVariables": { "CMAKE_BUILD_TYPE": "RelWithDebInfo" } },
3030
{ "name": "static", "hidden": true, "cacheVariables": { "GGML_STATIC": "ON" } },
31+
{ "name": "sycl_f16", "hidden": true, "cacheVariables": { "GGML_SYCL_F16": "ON" } },
3132

3233
{
3334
"name": "arm64-windows-msvc", "hidden": true,
@@ -60,6 +61,8 @@
6061
{ "name": "x64-windows-msvc+static-release", "inherits": [ "base", "reldbg", "static" ] },
6162

6263
{ "name": "x64-windows-sycl-debug" , "inherits": [ "sycl-base", "debug" ] },
63-
{ "name": "x64-windows-sycl-release", "inherits": [ "sycl-base", "release" ] }
64+
{ "name": "x64-windows-sycl-debug-f16", "inherits": [ "sycl-base", "debug", "sycl_f16" ] },
65+
{ "name": "x64-windows-sycl-release", "inherits": [ "sycl-base", "release" ] },
66+
{ "name": "x64-windows-sycl-release-f16", "inherits": [ "sycl-base", "release", "sycl_f16" ] }
6467
]
6568
}

ggml/src/ggml-sycl.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2546,7 +2546,7 @@ inline void ggml_sycl_op_mul_mat_sycl(
25462546

25472547
const sycl::half alpha_f16 = 1.0f;
25482548
const sycl::half beta_f16 = 0.0f;
2549-
#if GGML_SYCL_DNNL
2549+
#if !GGML_SYCL_DNNL
25502550
SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm(
25512551
*stream, oneapi::mkl::transpose::trans,
25522552
oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10,
@@ -2558,7 +2558,9 @@ inline void ggml_sycl_op_mul_mat_sycl(
25582558
to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
25592559
#else
25602560
DnnlGemmWrapper::row_gemm(*stream, false, true, src1_ncols, row_diff, ne10, src1_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
2561-
src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(), dst_dd_i, DnnlGemmWrapper::to_dt<float>());
2561+
src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(), dst_f16.get(), DnnlGemmWrapper::to_dt<sycl::half>());
2562+
const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16);
2563+
to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff* src1_ncols, stream);
25622564
#endif
25632565
}
25642566
else {
@@ -2582,7 +2584,7 @@ inline void ggml_sycl_op_mul_mat_sycl(
25822584

25832585
const float alpha = 1.0f;
25842586
const float beta = 0.0f;
2585-
#if GGML_SYCL_DNNL
2587+
#if !GGML_SYCL_DNNL
25862588
SYCL_CHECK(CHECK_TRY_ERROR(oneapi::mkl::blas::column_major::gemm(
25872589
*stream, oneapi::mkl::transpose::trans,
25882590
oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10,

ggml/src/ggml-sycl/gemm.hpp

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,39 @@ class DnnlGemmWrapper {
3939
const void* a, dt at, const void* b, dt bt, void* c, dt ct)
4040
{
4141
// Get the device associated with the queue
42-
sycl::device dev = q.get_device();
43-
// Get the context associated with the queue
44-
sycl::context ctx = q.get_context();
45-
const dnnl::engine eng = dnnl::sycl_interop::make_engine(dev, ctx);
46-
const dnnl::stream stream = dnnl::sycl_interop::make_stream(eng, q);
42+
sycl::device dev = q.get_device();
43+
// Get the context associated with the queue
44+
sycl::context ctx = q.get_context();
45+
const dnnl::engine eng = dnnl::sycl_interop::make_engine(dev, ctx);
46+
const dnnl::stream stream = dnnl::sycl_interop::make_stream(eng, q);
47+
dnnl::memory::dims a_dims = { m, k };
48+
dnnl::memory::dims b_dims = { k, n };
49+
dnnl::memory::dims c_dims = { m, n };
50+
const auto a_in_md = dnnl::memory::desc(a_dims, at, a_trans ? tag::ba : tag::ab);
51+
const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_trans ? tag::ba : tag::ab);
52+
const auto c_md = dnnl::memory::desc(c_dims, ct, tag::ab);
53+
auto a_mem = dnnl::memory(a_in_md, eng, (void*)a);
54+
auto b_mem = dnnl::memory(b_in_md, eng, (void*)b);
55+
auto matmul_pd = dnnl::matmul::primitive_desc(eng, a_in_md, b_in_md, c_md);
56+
auto c_mem = dnnl::memory(matmul_pd.dst_desc(), eng, c);
57+
58+
// Create the primitive.
59+
auto matmul_prim = dnnl::matmul(matmul_pd);
60+
// Primitive arguments.
61+
std::unordered_map<int, dnnl::memory> matmul_args;
62+
matmul_args.insert({ DNNL_ARG_SRC, a_mem });
63+
matmul_args.insert({ DNNL_ARG_WEIGHTS, b_mem });
64+
matmul_args.insert({ DNNL_ARG_DST, c_mem });
65+
66+
matmul_prim.execute(stream, matmul_args);
67+
}
68+
69+
70+
static inline void row_gemm(const dnnl::stream& stream, bool a_trans,
71+
bool b_trans, int m, int n, int k,
72+
const void* a, dt at, const void* b, dt bt, void* c, dt ct)
73+
{
74+
auto const eng = stream.get_engine();
4775
dnnl::memory::dims a_dims = { m, k };
4876
dnnl::memory::dims b_dims = { k, n };
4977
dnnl::memory::dims c_dims = { m, n };
@@ -66,6 +94,7 @@ class DnnlGemmWrapper {
6694
matmul_prim.execute(stream, matmul_args);
6795
}
6896
};
97+
6998
#endif
7099

71100
#endif // GGML_SYCL_GEMM_HPP

0 commit comments

Comments
 (0)