Skip to content

Commit c751e65

Browse files
committed
add engine map
1 parent 4dc5515 commit c751e65

File tree

2 files changed

+31
-7
lines changed

2 files changed

+31
-7
lines changed

ggml/src/ggml-sycl.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2495,10 +2495,15 @@ inline void ggml_sycl_op_mul_mat_sycl(
24952495
to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
24962496
#else
24972497
auto dnnl_stream = ctx.stream_dnnl(stream);
2498+
#if 0
24982499
DnnlGemmWrapper::row_gemm(dnnl_stream, false, true, src1_ncols, row_diff, ne10, src1_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
24992500
src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(), dst_f16.get(), DnnlGemmWrapper::to_dt<sycl::half>());
25002501
const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16);
25012502
to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff* src1_ncols, stream);
2503+
#else
2504+
DnnlGemmWrapper::row_gemm(dnnl_stream, false, true, src1_ncols, row_diff, ne10, src1_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
2505+
src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(), dst_dd_i, DnnlGemmWrapper::to_dt<float>());
2506+
#endif
25022507
#endif
25032508
}
25042509
else {

ggml/src/ggml-sycl/common.hpp

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -282,26 +282,45 @@ struct ggml_backend_sycl_context {
282282
}
283283

284284
#if GGML_SYCL_DNNL
285-
dnnl::stream make_stream(sycl::queue& q) {
285+
dnnl::engine make_engine(sycl::queue* q) {
286286
// Get the device associated with the queue
287-
sycl::device dev = q.get_device();
287+
sycl::device dev = q->get_device();
288288
// Get the context associated with the queue
289-
sycl::context ctx = q.get_context();
289+
sycl::context ctx = q->get_context();
290290
const dnnl::engine eng = dnnl::sycl_interop::make_engine(dev, ctx);
291-
dnnl::stream stream = dnnl::sycl_interop::make_stream(eng, q);
292-
return stream;
291+
return eng;
293292
}
293+
294294
std::unordered_map<sycl::queue*, dnnl::stream> stream_map;
295+
std::unordered_map<sycl::queue*, dnnl::engine> engine_map;
295296
dnnl::stream stream_dnnl(int device, int _stream) {
296297
auto q = stream(device, _stream);
297298
return stream_dnnl(q);
298299
}
300+
dnnl::engine engine_dnnl(sycl::queue* qptr) {
301+
auto it = engine_map.find(qptr);
302+
if (it == engine_map.end()) {
303+
auto eng = make_engine(qptr);
304+
engine_map[qptr] = eng;
305+
return eng;
306+
}
307+
else
308+
{
309+
return it->second;
310+
}
311+
}
299312
dnnl::stream stream_dnnl(sycl::queue* qptr) {
300313
auto it = stream_map.find(qptr);
301314
if (it == stream_map.end()) {
302-
stream_map[qptr] = make_stream(*qptr);
315+
auto eng = engine_dnnl(qptr);
316+
auto stream = dnnl::sycl_interop::make_stream(eng, *qptr);
317+
stream_map[qptr] = stream;
318+
return stream;
319+
}
320+
else
321+
{
322+
return it->second;
303323
}
304-
return it->second;
305324
}
306325
dnnl::stream stream_dnnl() {
307326
return stream_dnnl(device, 0);

0 commit comments

Comments
 (0)