|
| 1 | +// |
| 2 | +// MIT license |
| 3 | +// Copyright (C) 2024 Intel Corporation |
| 4 | +// SPDX-License-Identifier: MIT |
| 5 | +// |
| 6 | + |
| 7 | +// |
| 8 | +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 9 | +// See https://llvm.org/LICENSE.txt for license information. |
| 10 | +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 11 | +// |
| 12 | + |
| 13 | +#ifndef GGML_SYCL_GEMM_HPP |
| 14 | +#define GGML_SYCL_GEMM_HPP |
| 15 | + |
| 16 | +#include <fstream> |
| 17 | +#include <iostream> |
| 18 | + |
| 19 | +#include "ggml-sycl.h" |
| 20 | +#include "dnnl.hpp" |
| 21 | +#include "dnnl_sycl.hpp" |
| 22 | + |
| 23 | +#if GGML_SYCL_DNNL |
| 24 | + |
| 25 | +class DnnlGemmWrapper { |
| 26 | +public: |
| 27 | + using dt = dnnl::memory::data_type; |
| 28 | + using tag = dnnl::memory::format_tag; |
| 29 | + |
| 30 | + template<typename T> |
| 31 | + static constexpr dt to_dt() { |
| 32 | + if constexpr (std::is_same_v<T, float>) return dt::f32; |
| 33 | + else if constexpr (std::is_same_v<T, sycl::half>) return dt::f16; |
| 34 | + else static_assert(0); |
| 35 | + } |
| 36 | + |
| 37 | + static inline void row_gemm(sycl::queue& q, bool a_trans, |
| 38 | + bool b_trans, int m, int n, int k, |
| 39 | + const void* a, dt at, const void* b, dt bt, void* c, dt ct) |
| 40 | + { |
| 41 | + // Get the device associated with the queue |
| 42 | + sycl::device dev = q.get_device(); |
| 43 | + // Get the context associated with the queue |
| 44 | + sycl::context ctx = q.get_context(); |
| 45 | + const dnnl::engine eng = dnnl::sycl_interop::make_engine(dev, ctx); |
| 46 | + const dnnl::stream stream = dnnl::sycl_interop::make_stream(eng, q); |
| 47 | + dnnl::memory::dims a_dims = { m, k }; |
| 48 | + dnnl::memory::dims b_dims = { k, n }; |
| 49 | + dnnl::memory::dims c_dims = { m, n }; |
| 50 | + const auto a_in_md = dnnl::memory::desc(a_dims, at, a_trans ? tag::ba : tag::ab); |
| 51 | + const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_trans ? tag::ba : tag::ab); |
| 52 | + const auto c_md = dnnl::memory::desc(c_dims, ct, tag::ab); |
| 53 | + auto a_mem = dnnl::memory(a_in_md, eng, (void*)a); |
| 54 | + auto b_mem = dnnl::memory(b_in_md, eng, (void*)b); |
| 55 | + auto matmul_pd = dnnl::matmul::primitive_desc(eng, a_in_md, b_in_md, c_md); |
| 56 | + auto c_mem = dnnl::memory(matmul_pd.dst_desc(), eng, c); |
| 57 | + |
| 58 | + // Create the primitive. |
| 59 | + auto matmul_prim = dnnl::matmul(matmul_pd); |
| 60 | + // Primitive arguments. |
| 61 | + std::unordered_map<int, dnnl::memory> matmul_args; |
| 62 | + matmul_args.insert({ DNNL_ARG_SRC, a_mem }); |
| 63 | + matmul_args.insert({ DNNL_ARG_WEIGHTS, b_mem }); |
| 64 | + matmul_args.insert({ DNNL_ARG_DST, c_mem }); |
| 65 | + |
| 66 | + matmul_prim.execute(stream, matmul_args); |
| 67 | + } |
| 68 | +}; |
| 69 | +#endif |
| 70 | + |
| 71 | +#endif // GGML_SYCL_GEMM_HPP |
0 commit comments