Skip to content

Commit d880122

Browse files
committed
sycl: use DNN in the first part of ggml_sycl_mul_mat_batched_sycl
1 parent cb06a3c commit d880122

File tree

2 files changed

+33
-25
lines changed

2 files changed

+33
-25
lines changed

ggml/src/ggml-sycl/gemm.hpp

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -32,16 +32,17 @@ class DnnlGemmWrapper {
3232
else static_assert(0);
3333
}
3434

35-
static inline void row_gemm(ggml_backend_sycl_context & ctx, bool a_trans, bool b_trans, int m, int n, int k,
36-
const void * a, dt at, const void * b, dt bt, void * c, dt ct, const queue_ptr & q) {
35+
static void row_gemm(ggml_backend_sycl_context & ctx, bool a_trans, bool b_trans, int m, int n, int k,
36+
const void * a, dt at, const void * b, dt bt, void * c, dt ct, const queue_ptr & q,
37+
dnnl_dim_t batches = 1) {
3738
auto stream = ctx.stream_dnnl(q);
3839
auto eng = ctx.engine_dnnl(q);
39-
dnnl::memory::dims a_dims = { m, k };
40-
dnnl::memory::dims b_dims = { k, n };
41-
dnnl::memory::dims c_dims = { m, n };
42-
const auto a_in_md = dnnl::memory::desc(a_dims, at, a_trans ? tag::ba : tag::ab);
43-
const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_trans ? tag::ba : tag::ab);
44-
const auto c_md = dnnl::memory::desc(c_dims, ct, tag::ab);
40+
dnnl::memory::dims a_dims = { batches, m, k };
41+
dnnl::memory::dims b_dims = { batches, k, n };
42+
dnnl::memory::dims c_dims = { batches, m, n };
43+
const auto a_in_md = dnnl::memory::desc(a_dims, at, a_trans ? tag::acb : tag::abc);
44+
const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_trans ? tag::acb : tag::abc);
45+
const auto c_md = dnnl::memory::desc(c_dims, ct, tag::abc);
4546

4647
dnnl::primitive_attr primitive_attr;
4748
primitive_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);

ggml/src/ggml-sycl/ggml-sycl.cpp

Lines changed: 24 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1982,7 +1982,7 @@ inline void ggml_sycl_op_mul_mat_sycl(
19821982

19831983
const int64_t ne00 = src0->ne[0];
19841984
const int64_t ne10 = src1->ne[0];
1985-
1985+
GGML_ASSERT(ne00 == ne10);
19861986

19871987
const int64_t row_diff = row_high - row_low;
19881988

@@ -2727,10 +2727,10 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
27272727
GGML_ASSERT(!ggml_is_transposed(src1));
27282728
GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer));
27292729
GGML_ASSERT(src0->type == GGML_TYPE_F16);
2730+
GGML_ASSERT(dst->type == GGML_TYPE_F32);
27302731

27312732
GGML_TENSOR_BINARY_OP_LOCALS
27322733

2733-
27342734
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
27352735
queue_ptr main_stream = ctx.stream();;
27362736

@@ -2751,39 +2751,45 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
27512751
sycl::half *src1_f16 = src1->type == GGML_TYPE_F16 ? (sycl::half *)src1_ddf
27522752
: src1_f16_alloc.get();
27532753

2754-
char * dst_t;
2755-
2756-
dpct::library_data_t cu_compute_type = dpct::library_data_t::real_float;
2757-
dpct::library_data_t cu_data_type = dpct::library_data_t::real_float;
2758-
2759-
// dst strides
2760-
size_t nbd2 = dst->nb[2];
2761-
size_t nbd3 = dst->nb[3];
2754+
const dpct::library_data_t cu_compute_type = dpct::library_data_t::real_float;
2755+
const dpct::library_data_t cu_data_type = dpct::library_data_t::real_float;
27622756

27632757
const float alpha_f32 = 1.0f;
27642758
const float beta_f32 = 0.0f;
27652759

27662760
const void * alpha = &alpha_f32;
27672761
const void * beta = &beta_f32;
27682762

2769-
dst_t = (char *) dst_ddf;
2763+
char * dst_t = (char *) dst_ddf;
27702764

27712765
GGML_ASSERT(ne12 % ne02 == 0);
27722766
GGML_ASSERT(ne13 % ne03 == 0);
2767+
GGML_ASSERT(ne01 == static_cast<int64_t>(nb1/nb0));
2768+
GGML_ASSERT(ne10 == ne00);
27732769

27742770
// broadcast factors
2775-
const int64_t r2 = ne12/ne02;
2776-
const int64_t r3 = ne13/ne03;
2771+
const auto r2 = ne12/ne02;
2772+
const auto r3 = ne13/ne03;
2773+
const auto ne23 = ne12*ne13;
27772774

27782775
if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
27792776
// there is no broadcast and src0, src1 are contiguous across dims 2, 3
2777+
#ifdef GGML_SYCL_DNNL
2778+
// TODO: use strided dnnl::memory::desc ctor in row_gemm to relax below assertions
2779+
GGML_ASSERT(nb11/nb10 == ne10);
2780+
GGML_ASSERT(nb01/nb00 == ne00);
2781+
2782+
DnnlGemmWrapper::row_gemm(ctx, false, true, ne11, ne01, ne10, src1_f16,
2783+
DnnlGemmWrapper::to_dt<sycl::half>(), src0_as_f16, DnnlGemmWrapper::to_dt<sycl::half>(),
2784+
dst_t, DnnlGemmWrapper::to_dt<float>(), main_stream, ne23);
2785+
#else
27802786
SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
27812787
*main_stream, oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
27822788
(const char *) src0_as_f16, dpct::library_data_t::real_half, nb01 / nb00, nb02 / nb00,
27832789
(const char *) src1_f16, dpct::library_data_t::real_half, nb11 / nb10, nb12 / nb10, beta, (char *) dst_t,
2784-
cu_data_type, ne01, nb2 / nb0, ne12 * ne13, cu_compute_type)));
2790+
cu_data_type, ne01, nb2 / nb0, ne23, cu_compute_type)));
2791+
#endif
27852792
} else {
2786-
const int ne23 = ne12*ne13;
27872793

27882794
ggml_sycl_pool_alloc<const void *> ptrs_src(ctx.pool(), 2*ne23);
27892795
ggml_sycl_pool_alloc< void *> ptrs_dst(ctx.pool(), 1*ne23);
@@ -2811,7 +2817,7 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
28112817
dst_t, ptrs_src_get,
28122818
ptrs_dst_get, ne12, ne13, ne23,
28132819
nb02, nb03, nb12_scaled, nb13_scaled,
2814-
nbd2, nbd3, r2, r3, item_ct1);
2820+
nb2, nb3, r2, r3, item_ct1);
28152821
});
28162822
});
28172823
}
@@ -3651,7 +3657,8 @@ static ggml_status ggml_backend_sycl_graph_compute(ggml_backend_t backend, ggml_
36513657
return GGML_STATUS_SUCCESS;
36523658
}
36533659

3654-
sycl_ex::command_graph model_sycl_graph(*(sycl_ctx->stream()));
3660+
sycl_ex::command_graph model_sycl_graph(*(sycl_ctx->stream()), {sycl_ex::property::graph::assume_buffer_outlives_graph{}});
3661+
36553662
model_sycl_graph.begin_recording(*(sycl_ctx->stream()));
36563663
ggml_backend_sycl_graph_compute_impl(sycl_ctx, cgraph);
36573664
model_sycl_graph.end_recording();

0 commit comments

Comments
 (0)