Skip to content

Commit 7c545a6

Browse files
committed
add onednn
1 parent 554b049 commit 7c545a6

File tree

3 files changed

+86
-2
lines changed

3 files changed

+86
-2
lines changed

ggml/src/CMakeLists.txt

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -549,10 +549,13 @@ if (GGML_SYCL)
549549
file(GLOB GGML_SOURCES_SYCL "ggml-sycl/*.cpp")
550550
list(APPEND GGML_SOURCES_SYCL "ggml-sycl.cpp")
551551

552+
find_package(DNNL)
553+
message("-- DNNL found:"${DNNL_FOUND})
554+
add_compile_definitions(GGML_SYCL_DNNL=${DNNL_FOUND})
552555
if (WIN32)
553556
find_package(IntelSYCL REQUIRED)
554557
find_package(MKL REQUIRED)
555-
set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} IntelSYCL::SYCL_CXX MKL::MKL MKL::MKL_SYCL)
558+
set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} IntelSYCL::SYCL_CXX MKL::MKL MKL::MKL_SYCL DNNL::dnnl)
556559
else()
557560
if (GGML_SYCL_TARGET STREQUAL "INTEL")
558561
set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} -fsycl OpenCL mkl_core pthread m dl mkl_sycl_blas mkl_intel_ilp64 mkl_tbb_thread)

ggml/src/ggml-sycl.cpp

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838

3939
#include "ggml-sycl/backend.hpp"
4040
#include "ggml-sycl/presets.hpp"
41+
#include "ggml-sycl/gemm.hpp"
4142

4243
bool ggml_sycl_loaded(void);
4344
void ggml_sycl_free_data(struct ggml_tensor * tensor);
@@ -2545,6 +2546,7 @@ inline void ggml_sycl_op_mul_mat_sycl(
25452546

25462547
const sycl::half alpha_f16 = 1.0f;
25472548
const sycl::half beta_f16 = 0.0f;
2549+
#if GGML_SYCL_DNNL
25482550
SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm(
25492551
*stream, oneapi::mkl::transpose::trans,
25502552
oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10,
@@ -2554,6 +2556,10 @@ inline void ggml_sycl_op_mul_mat_sycl(
25542556
dpct::library_data_t::real_half)));
25552557
const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16);
25562558
to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
2559+
#else
2560+
DnnlGemmWrapper::row_gemm(*stream, false, true, src1_ncols, row_diff, ne10, src1_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
2561+
src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(), dst_dd_i, DnnlGemmWrapper::to_dt<float>());
2562+
#endif
25572563
}
25582564
else {
25592565
// GGML_SYCL_DEBUG("ggml_sycl_op_mul_mat_sycl - fp32 path\n");
@@ -2576,13 +2582,17 @@ inline void ggml_sycl_op_mul_mat_sycl(
25762582

25772583
const float alpha = 1.0f;
25782584
const float beta = 0.0f;
2579-
2585+
#if GGML_SYCL_DNNL
25802586
SYCL_CHECK(CHECK_TRY_ERROR(oneapi::mkl::blas::column_major::gemm(
25812587
*stream, oneapi::mkl::transpose::trans,
25822588
oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10,
25832589
dpct::get_value(&alpha, *stream), src0_ddf_i, ne00,
25842590
src1_ddf1_i, ne10, dpct::get_value(&beta, *stream),
25852591
dst_dd_i, ldc)));
2592+
#else
2593+
DnnlGemmWrapper::row_gemm(*stream, false, true, src1_ncols, row_diff, ne10, src1_ddf1_i, DnnlGemmWrapper::to_dt<float>(),
2594+
src0_ddf_i, DnnlGemmWrapper::to_dt<float>(), dst_dd_i, DnnlGemmWrapper::to_dt<float>());
2595+
#endif
25862596
}
25872597
(void) dst;
25882598
(void) src1_ddq_i;

ggml/src/ggml-sycl/gemm.hpp

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
//
2+
// MIT license
3+
// Copyright (C) 2024 Intel Corporation
4+
// SPDX-License-Identifier: MIT
5+
//
6+
7+
//
8+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
9+
// See https://llvm.org/LICENSE.txt for license information.
10+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
11+
//
12+
13+
#ifndef GGML_SYCL_GEMM_HPP
14+
#define GGML_SYCL_GEMM_HPP
15+
16+
#include <fstream>
17+
#include <iostream>
18+
19+
#include "ggml-sycl.h"
20+
#include "dnnl.hpp"
21+
#include "dnnl_sycl.hpp"
22+
23+
#if GGML_SYCL_DNNL
24+
25+
class DnnlGemmWrapper {
26+
public:
27+
using dt = dnnl::memory::data_type;
28+
using tag = dnnl::memory::format_tag;
29+
30+
template<typename T>
31+
static constexpr dt to_dt() {
32+
if constexpr (std::is_same_v<T, float>) return dt::f32;
33+
else if constexpr (std::is_same_v<T, sycl::half>) return dt::f16;
34+
else static_assert(0);
35+
}
36+
37+
static inline void row_gemm(sycl::queue& q, bool a_trans,
38+
bool b_trans, int m, int n, int k,
39+
const void* a, dt at, const void* b, dt bt, void* c, dt ct)
40+
{
41+
// Get the device associated with the queue
42+
sycl::device dev = q.get_device();
43+
// Get the context associated with the queue
44+
sycl::context ctx = q.get_context();
45+
const dnnl::engine eng = dnnl::sycl_interop::make_engine(dev, ctx);
46+
const dnnl::stream stream = dnnl::sycl_interop::make_stream(eng, q);
47+
dnnl::memory::dims a_dims = { m, k };
48+
dnnl::memory::dims b_dims = { k, n };
49+
dnnl::memory::dims c_dims = { m, n };
50+
const auto a_in_md = dnnl::memory::desc(a_dims, at, a_trans ? tag::ba : tag::ab);
51+
const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_trans ? tag::ba : tag::ab);
52+
const auto c_md = dnnl::memory::desc(c_dims, ct, tag::ab);
53+
auto a_mem = dnnl::memory(a_in_md, eng, (void*)a);
54+
auto b_mem = dnnl::memory(b_in_md, eng, (void*)b);
55+
auto matmul_pd = dnnl::matmul::primitive_desc(eng, a_in_md, b_in_md, c_md);
56+
auto c_mem = dnnl::memory(matmul_pd.dst_desc(), eng, c);
57+
58+
// Create the primitive.
59+
auto matmul_prim = dnnl::matmul(matmul_pd);
60+
// Primitive arguments.
61+
std::unordered_map<int, dnnl::memory> matmul_args;
62+
matmul_args.insert({ DNNL_ARG_SRC, a_mem });
63+
matmul_args.insert({ DNNL_ARG_WEIGHTS, b_mem });
64+
matmul_args.insert({ DNNL_ARG_DST, c_mem });
65+
66+
matmul_prim.execute(stream, matmul_args);
67+
}
68+
};
69+
#endif
70+
71+
#endif // GGML_SYCL_GEMM_HPP

0 commit comments

Comments
 (0)