@@ -32,16 +32,30 @@ class DnnlGemmWrapper {
32
32
else static_assert (0 );
33
33
}
34
34
35
- static void row_gemm (ggml_backend_sycl_context & ctx, bool a_trans, bool b_trans, int m, int n, int k,
36
- const void * a, dt at, const void * b, dt bt, void * c, dt ct, const queue_ptr & q,
37
- dnnl_dim_t batches = 1 ) {
35
+ // matrix A has m rows, k columns
36
+ // matrix B has k rows, n columns
37
+ // nra - number of elements to skip when moving into next row in A
38
+ // nrb - number of elements to skip when moving into next row in B
39
+ // nca - number of elements to skip when moving into next column in A
40
+ // ncb - number of elements to skip when moving into next column in B
41
+ // stride_a - number of elements to skip when moving to next A matrix
42
+ // stride_b - number of elements to skip when moving to next B matrix
43
+ // batches - number of A matrices, equal to number of B matrices
44
+ static void gemm (ggml_backend_sycl_context & ctx, int m, int n, int k,
45
+ const void * a, dt at, dnnl_dim_t nra, dnnl_dim_t nca, dnnl_dim_t stride_a,
46
+ const void * b, dt bt, dnnl_dim_t nrb, dnnl_dim_t ncb, dnnl_dim_t stride_b,
47
+ void * c, dt ct, const queue_ptr & q, dnnl_dim_t batches) {
48
+
38
49
auto stream = ctx.stream_dnnl (q);
39
50
auto eng = ctx.engine_dnnl (q);
40
51
dnnl::memory::dims a_dims = { batches, m, k };
41
52
dnnl::memory::dims b_dims = { batches, k, n };
42
53
dnnl::memory::dims c_dims = { batches, m, n };
43
- const auto a_in_md = dnnl::memory::desc (a_dims, at, a_trans ? tag::acb : tag::abc);
44
- const auto b_in_md = dnnl::memory::desc (b_dims, bt, b_trans ? tag::acb : tag::abc);
54
+ dnnl::memory::dims a_strides = { stride_a, nra, nca };
55
+ dnnl::memory::dims b_strides = { stride_b, nrb, ncb };
56
+
57
+ const auto a_in_md = dnnl::memory::desc (a_dims, at, a_strides);
58
+ const auto b_in_md = dnnl::memory::desc (b_dims, bt, b_strides);
45
59
const auto c_md = dnnl::memory::desc (c_dims, ct, tag::abc);
46
60
47
61
dnnl::primitive_attr primitive_attr;
@@ -64,6 +78,15 @@ class DnnlGemmWrapper {
64
78
65
79
matmul_prim.execute (stream, matmul_args);
66
80
}
81
+
82
+ // matrices A and B are column major, both having k rows
83
+ // matrix A has m columns, matrix B has n columns
84
+ // output: column major matrix C = A transposed * B
85
+ static void row_gemm (ggml_backend_sycl_context & ctx, int m, int n, int k,
86
+ const void * a, dt at, const void * b, dt bt, void * c, dt ct, const queue_ptr & q) {
87
+ // Single-batch convenience wrapper: forwards to gemm() with explicit element strides for A and B.
88
+ gemm (ctx, m, n, k, a, at, k, 1 , k * m, b, bt, 1 , k, n * k, c, ct, q, 1 ); // A: row stride k, col stride 1 (column-major k x m buffer read as m x k); B: row stride 1, col stride k (column-major k x n); batch strides k*m and n*k with batches = 1. NOTE(review): gemm describes C with tag::abc over { batches, m, n } - confirm this matches the layout the caller expects for C.
89
+ }
67
90
};
68
91
69
92
#endif
0 commit comments