@@ -282,26 +282,45 @@ struct ggml_backend_sycl_context {
282
282
}
283
283
284
284
#if GGML_SYCL_DNNL
285
- dnnl::stream make_stream (sycl::queue& q) {
285
+ dnnl::engine make_engine (sycl::queue* q) {
286
286
// Get the device associated with the queue
287
- sycl::device dev = q. get_device ();
287
+ sycl::device dev = q-> get_device ();
288
288
// Get the context associated with the queue
289
- sycl::context ctx = q. get_context ();
289
+ sycl::context ctx = q-> get_context ();
290
290
const dnnl::engine eng = dnnl::sycl_interop::make_engine (dev, ctx);
291
- dnnl::stream stream = dnnl::sycl_interop::make_stream (eng, q);
292
- return stream;
291
+ return eng;
293
292
}
293
+
294
294
std::unordered_map<sycl::queue*, dnnl::stream> stream_map;
295
+ std::unordered_map<sycl::queue*, dnnl::engine> engine_map;
295
296
dnnl::stream stream_dnnl (int device, int _stream) {
296
297
auto q = stream (device, _stream);
297
298
return stream_dnnl (q);
298
299
}
300
+ dnnl::engine engine_dnnl (sycl::queue* qptr) {
301
+ auto it = engine_map.find (qptr);
302
+ if (it == engine_map.end ()) {
303
+ auto eng = make_engine (qptr);
304
+ engine_map[qptr] = eng;
305
+ return eng;
306
+ }
307
+ else
308
+ {
309
+ return it->second ;
310
+ }
311
+ }
299
312
dnnl::stream stream_dnnl (sycl::queue* qptr) {
300
313
auto it = stream_map.find (qptr);
301
314
if (it == stream_map.end ()) {
302
- stream_map[qptr] = make_stream (*qptr);
315
+ auto eng = engine_dnnl (qptr);
316
+ auto stream = dnnl::sycl_interop::make_stream (eng, *qptr);
317
+ stream_map[qptr] = stream;
318
+ return stream;
319
+ }
320
+ else
321
+ {
322
+ return it->second ;
303
323
}
304
- return it->second ;
305
324
}
306
325
dnnl::stream stream_dnnl () {
307
326
return stream_dnnl (device, 0 );
0 commit comments