Skip to content

Commit f73d0ab

Browse files
committed
handling the case when nb11/nb10 != ne10
1 parent d880122 commit f73d0ab

File tree

3 files changed

+38
-16
lines changed

3 files changed

+38
-16
lines changed

ggml/src/ggml-sycl/gemm.hpp

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,16 +32,30 @@ class DnnlGemmWrapper {
3232
else static_assert(0);
3333
}
3434

35-
static void row_gemm(ggml_backend_sycl_context & ctx, bool a_trans, bool b_trans, int m, int n, int k,
36-
const void * a, dt at, const void * b, dt bt, void * c, dt ct, const queue_ptr & q,
37-
dnnl_dim_t batches = 1) {
35+
// matrix A has m rows, k columns
36+
// matrix B has k rows, n columns
37+
// nra - number of elements to skip when moving into next row in A
38+
// nrb - number of elements to skip when moving into next row in B
39+
// nca - number of elements to skip when moving into next column in A
40+
// ncb - number of elements to skip when moving into next column in B
41+
// stride_a - number of elements to skip when moving to next A matrix
42+
// stride_b - number of elements to skip when moving to next B matrix
43+
// batches - number of A matrices, equal to number of B matrices
44+
static void gemm(ggml_backend_sycl_context & ctx, int m, int n, int k,
45+
const void * a, dt at, dnnl_dim_t nra, dnnl_dim_t nca, dnnl_dim_t stride_a,
46+
const void * b, dt bt, dnnl_dim_t nrb, dnnl_dim_t ncb, dnnl_dim_t stride_b,
47+
void * c, dt ct, const queue_ptr & q, dnnl_dim_t batches) {
48+
3849
auto stream = ctx.stream_dnnl(q);
3950
auto eng = ctx.engine_dnnl(q);
4051
dnnl::memory::dims a_dims = { batches, m, k };
4152
dnnl::memory::dims b_dims = { batches, k, n };
4253
dnnl::memory::dims c_dims = { batches, m, n };
43-
const auto a_in_md = dnnl::memory::desc(a_dims, at, a_trans ? tag::acb : tag::abc);
44-
const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_trans ? tag::acb : tag::abc);
54+
dnnl::memory::dims a_strides = { stride_a, nra, nca };
55+
dnnl::memory::dims b_strides = { stride_b, nrb, ncb };
56+
57+
const auto a_in_md = dnnl::memory::desc(a_dims, at, a_strides);
58+
const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_strides);
4559
const auto c_md = dnnl::memory::desc(c_dims, ct, tag::abc);
4660

4761
dnnl::primitive_attr primitive_attr;
@@ -64,6 +78,15 @@ class DnnlGemmWrapper {
6478

6579
matmul_prim.execute(stream, matmul_args);
6680
}
81+
82+
// matrices A and B are column major, both having k rows
83+
// matrix A has m column, matrix B has n columns
84+
// output: column major matrix C = A transposed * B
85+
static void row_gemm(ggml_backend_sycl_context & ctx, int m, int n, int k,
86+
const void * a, dt at, const void * b, dt bt, void * c, dt ct, const queue_ptr & q) {
87+
88+
gemm(ctx, m, n, k, a, at, k, 1, k * m, b, bt, 1, k, n * k, c, ct, q, 1);
89+
}
6790
};
6891

6992
#endif

ggml/src/ggml-sycl/ggml-sycl.cpp

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2043,7 +2043,7 @@ inline void ggml_sycl_op_mul_mat_sycl(
20432043
const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst);
20442044
to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
20452045
#else
2046-
DnnlGemmWrapper::row_gemm(ctx, false, true, src1_ncols, row_diff, ne10, src1_ptr,
2046+
DnnlGemmWrapper::row_gemm(ctx, src1_ncols, row_diff, ne10, src1_ptr,
20472047
DnnlGemmWrapper::to_dt<sycl::half>(), src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
20482048
dst_f16.get(), DnnlGemmWrapper::to_dt<sycl::half>(), stream);
20492049
const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst);
@@ -2077,7 +2077,7 @@ inline void ggml_sycl_op_mul_mat_sycl(
20772077
src1_ncols, ne10, dpct::get_value(&alpha, *stream), src0_ddf_i, ne00, src1_ddf1_i, ne10,
20782078
dpct::get_value(&beta, *stream), dst_dd_i, ldc)));
20792079
#else
2080-
DnnlGemmWrapper::row_gemm(ctx, false, true, src1_ncols, row_diff, ne10, src1_ddf1_i,
2080+
DnnlGemmWrapper::row_gemm(ctx, src1_ncols, row_diff, ne10, src1_ddf1_i,
20812081
DnnlGemmWrapper::to_dt<float>(), src0_ddf_i, DnnlGemmWrapper::to_dt<float>(),
20822082
dst_dd_i, DnnlGemmWrapper::to_dt<float>(), stream);
20832083
#endif
@@ -2774,14 +2774,11 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
27742774

27752775
if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
27762776
// there is no broadcast and src0, src1 are contiguous across dims 2, 3
2777-
#ifdef GGML_SYCL_DNNL
2778-
// TODO: use strided dnnl::memory::desc ctor in row_gemm to relax below assertions
2779-
GGML_ASSERT(nb11/nb10 == ne10);
2780-
GGML_ASSERT(nb01/nb00 == ne00);
2781-
2782-
DnnlGemmWrapper::row_gemm(ctx, false, true, ne11, ne01, ne10, src1_f16,
2783-
DnnlGemmWrapper::to_dt<sycl::half>(), src0_as_f16, DnnlGemmWrapper::to_dt<sycl::half>(),
2784-
dst_t, DnnlGemmWrapper::to_dt<float>(), main_stream, ne23);
2777+
#if GGML_SYCL_DNNL
2778+
DnnlGemmWrapper::gemm(ctx, ne11, ne01, ne10,
2779+
src1_f16, DnnlGemmWrapper::to_dt<sycl::half>(), nb11/nb10, 1, nb12/nb10,
2780+
src0_as_f16, DnnlGemmWrapper::to_dt<sycl::half>(), 1, nb01/nb00, nb02/nb00,
2781+
dst_t, DnnlGemmWrapper::to_dt<float>(), main_stream, ne23);
27852782
#else
27862783
SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
27872784
*main_stream, oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,

tests/test-backend-ops.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3870,7 +3870,7 @@ static const ggml_type other_types[] = {
38703870
// Test cases for evaluation: should try to cover edge cases while using small input sizes to keep the runtime low
38713871
static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
38723872
std::vector<std::unique_ptr<test_case>> test_cases;
3873-
std::default_random_engine rng(0);
3873+
[[maybe_unused]] std::default_random_engine rng(0);
38743874

38753875
// unary ops
38763876
for (ggml_type type : {GGML_TYPE_F16, GGML_TYPE_F32}) {
@@ -4188,6 +4188,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
41884188
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {2, 3}, {1, 1}, {0, 1, 3, 2}));
41894189
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {2, 3}, {1, 1}, {0, 3, 2, 1}));
41904190

4191+
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {2, 1}, {1, 1}, {0, 2, 1, 3}));
4192+
41914193
// test cases with large ne00/ne10 to cover stream-k fixup
41924194
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 1024, {3, 2}, {1, 1}));
41934195
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 8, 1024, {3, 2}, {1, 1}));

0 commit comments

Comments
 (0)