@@ -39,11 +39,39 @@ class DnnlGemmWrapper {
39
39
const void * a, dt at, const void * b, dt bt, void * c, dt ct)
40
40
{
41
41
// Get the device associated with the queue
42
- sycl::device dev = q.get_device ();
43
- // Get the context associated with the queue
44
- sycl::context ctx = q.get_context ();
45
- const dnnl::engine eng = dnnl::sycl_interop::make_engine (dev, ctx);
46
- const dnnl::stream stream = dnnl::sycl_interop::make_stream (eng, q);
42
+ sycl::device dev = q.get_device ();
43
+ // Get the context associated with the queue
44
+ sycl::context ctx = q.get_context ();
45
+ const dnnl::engine eng = dnnl::sycl_interop::make_engine (dev, ctx);
46
+ const dnnl::stream stream = dnnl::sycl_interop::make_stream (eng, q);
47
+ dnnl::memory::dims a_dims = { m, k };
48
+ dnnl::memory::dims b_dims = { k, n };
49
+ dnnl::memory::dims c_dims = { m, n };
50
+ const auto a_in_md = dnnl::memory::desc (a_dims, at, a_trans ? tag::ba : tag::ab);
51
+ const auto b_in_md = dnnl::memory::desc (b_dims, bt, b_trans ? tag::ba : tag::ab);
52
+ const auto c_md = dnnl::memory::desc (c_dims, ct, tag::ab);
53
+ auto a_mem = dnnl::memory (a_in_md, eng, (void *)a);
54
+ auto b_mem = dnnl::memory (b_in_md, eng, (void *)b);
55
+ auto matmul_pd = dnnl::matmul::primitive_desc (eng, a_in_md, b_in_md, c_md);
56
+ auto c_mem = dnnl::memory (matmul_pd.dst_desc (), eng, c);
57
+
58
+ // Create the primitive.
59
+ auto matmul_prim = dnnl::matmul (matmul_pd);
60
+ // Primitive arguments.
61
+ std::unordered_map<int , dnnl::memory> matmul_args;
62
+ matmul_args.insert ({ DNNL_ARG_SRC, a_mem });
63
+ matmul_args.insert ({ DNNL_ARG_WEIGHTS, b_mem });
64
+ matmul_args.insert ({ DNNL_ARG_DST, c_mem });
65
+
66
+ matmul_prim.execute (stream, matmul_args);
67
+ }
68
+
69
+
70
+ static inline void row_gemm (const dnnl::stream& stream, bool a_trans,
71
+ bool b_trans, int m, int n, int k,
72
+ const void * a, dt at, const void * b, dt bt, void * c, dt ct)
73
+ {
74
+ auto const eng = stream.get_engine ();
47
75
dnnl::memory::dims a_dims = { m, k };
48
76
dnnl::memory::dims b_dims = { k, n };
49
77
dnnl::memory::dims c_dims = { m, n };
@@ -66,6 +94,7 @@ class DnnlGemmWrapper {
66
94
matmul_prim.execute (stream, matmul_args);
67
95
}
68
96
};
97
+
69
98
#endif
70
99
71
100
#endif // GGML_SYCL_GEMM_HPP
0 commit comments