@@ -281,26 +281,45 @@ struct ggml_backend_sycl_context {
281
281
}
282
282
283
283
#if GGML_SYCL_DNNL
284
- dnnl::stream make_stream (sycl::queue& q) {
284
+ dnnl::engine make_engine (sycl::queue* q) {
285
285
// Get the device associated with the queue
286
- sycl::device dev = q. get_device ();
286
+ sycl::device dev = q-> get_device ();
287
287
// Get the context associated with the queue
288
- sycl::context ctx = q. get_context ();
288
+ sycl::context ctx = q-> get_context ();
289
289
const dnnl::engine eng = dnnl::sycl_interop::make_engine (dev, ctx);
290
- dnnl::stream stream = dnnl::sycl_interop::make_stream (eng, q);
291
- return stream;
290
+ return eng;
292
291
}
292
+
293
293
std::unordered_map<sycl::queue*, dnnl::stream> stream_map;
294
+ std::unordered_map<sycl::queue*, dnnl::engine> engine_map;
294
295
dnnl::stream stream_dnnl (int device, int _stream) {
295
296
auto q = stream (device, _stream);
296
297
return stream_dnnl (q);
297
298
}
299
+ dnnl::engine engine_dnnl (sycl::queue* qptr) {
300
+ auto it = engine_map.find (qptr);
301
+ if (it == engine_map.end ()) {
302
+ auto eng = make_engine (qptr);
303
+ engine_map[qptr] = eng;
304
+ return eng;
305
+ }
306
+ else
307
+ {
308
+ return it->second ;
309
+ }
310
+ }
298
311
dnnl::stream stream_dnnl (sycl::queue* qptr) {
299
312
auto it = stream_map.find (qptr);
300
313
if (it == stream_map.end ()) {
301
- stream_map[qptr] = make_stream (*qptr);
314
+ auto eng = engine_dnnl (qptr);
315
+ auto stream = dnnl::sycl_interop::make_stream (eng, *qptr);
316
+ stream_map[qptr] = stream;
317
+ return stream;
318
+ }
319
+ else
320
+ {
321
+ return it->second ;
302
322
}
303
- return it->second ;
304
323
}
305
324
dnnl::stream stream_dnnl () {
306
325
return stream_dnnl (device, 0 );
0 commit comments