Skip to content

Commit 6d957b5

Browse files
committed
add engine map
1 parent 79d2005 commit 6d957b5

File tree

2 files changed

+31
-7
lines changed

2 files changed

+31
-7
lines changed

ggml/src/ggml-sycl.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2558,10 +2558,15 @@ inline void ggml_sycl_op_mul_mat_sycl(
25582558
to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
25592559
#else
25602560
auto dnnl_stream = ctx.stream_dnnl(stream);
2561+
#if 0
25612562
DnnlGemmWrapper::row_gemm(dnnl_stream, false, true, src1_ncols, row_diff, ne10, src1_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
25622563
src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(), dst_f16.get(), DnnlGemmWrapper::to_dt<sycl::half>());
25632564
const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16);
25642565
to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff* src1_ncols, stream);
2566+
#else
2567+
DnnlGemmWrapper::row_gemm(dnnl_stream, false, true, src1_ncols, row_diff, ne10, src1_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
2568+
src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(), dst_dd_i, DnnlGemmWrapper::to_dt<float>());
2569+
#endif
25652570
#endif
25662571
}
25672572
else {

ggml/src/ggml-sycl/common.hpp

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -281,26 +281,45 @@ struct ggml_backend_sycl_context {
281281
}
282282

283283
#if GGML_SYCL_DNNL
284-
dnnl::stream make_stream(sycl::queue& q) {
284+
dnnl::engine make_engine(sycl::queue* q) {
285285
// Get the device associated with the queue
286-
sycl::device dev = q.get_device();
286+
sycl::device dev = q->get_device();
287287
// Get the context associated with the queue
288-
sycl::context ctx = q.get_context();
288+
sycl::context ctx = q->get_context();
289289
const dnnl::engine eng = dnnl::sycl_interop::make_engine(dev, ctx);
290-
dnnl::stream stream = dnnl::sycl_interop::make_stream(eng, q);
291-
return stream;
290+
return eng;
292291
}
292+
293293
std::unordered_map<sycl::queue*, dnnl::stream> stream_map;
294+
std::unordered_map<sycl::queue*, dnnl::engine> engine_map;
294295
dnnl::stream stream_dnnl(int device, int _stream) {
295296
auto q = stream(device, _stream);
296297
return stream_dnnl(q);
297298
}
299+
dnnl::engine engine_dnnl(sycl::queue* qptr) {
300+
auto it = engine_map.find(qptr);
301+
if (it == engine_map.end()) {
302+
auto eng = make_engine(qptr);
303+
engine_map[qptr] = eng;
304+
return eng;
305+
}
306+
else
307+
{
308+
return it->second;
309+
}
310+
}
298311
dnnl::stream stream_dnnl(sycl::queue* qptr) {
299312
auto it = stream_map.find(qptr);
300313
if (it == stream_map.end()) {
301-
stream_map[qptr] = make_stream(*qptr);
314+
auto eng = engine_dnnl(qptr);
315+
auto stream = dnnl::sycl_interop::make_stream(eng, *qptr);
316+
stream_map[qptr] = stream;
317+
return stream;
318+
}
319+
else
320+
{
321+
return it->second;
302322
}
303-
return it->second;
304323
}
305324
dnnl::stream stream_dnnl() {
306325
return stream_dnnl(device, 0);

0 commit comments

Comments
 (0)