Skip to content

Commit 0304cdc

Browse files
Merge pull request #1076 from IntelPython/cached-device-store
Cached device store
2 parents 1754be7 + 0ed6e67 commit 0304cdc

File tree

7 files changed

+99
-11
lines changed

7 files changed

+99
-11
lines changed

dpctl/_sycl_queue_manager.pxd

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,12 @@
1717
# distutils: language = c++
1818
# cython: language_level=3
1919

20+
from ._sycl_device cimport SyclDevice
2021
from ._sycl_queue cimport SyclQueue
2122

2223

2324
cpdef SyclQueue get_current_queue()
2425
cpdef get_current_device_type ()
2526
cpdef get_current_backend()
27+
28+
cpdef object get_device_cached_queue(object)

dpctl/_sycl_queue_manager.pyx

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
import logging
2222
from contextlib import ExitStack, contextmanager
23+
from contextvars import ContextVar
2324

2425
from .enum_types import backend_type, device_type
2526

@@ -35,6 +36,7 @@ from ._backend cimport ( # noqa: E211
3536
_device_type,
3637
)
3738
from ._sycl_context cimport SyclContext
39+
from ._sycl_device cimport SyclDevice
3840

3941
__all__ = [
4042
"device_context",
@@ -44,6 +46,7 @@ __all__ = [
4446
"get_num_activated_queues",
4547
"is_in_device_context",
4648
"set_global_queue",
49+
"_global_device_queue_cache",
4750
]
4851

4952
_logger = logging.getLogger(__name__)
@@ -291,3 +294,45 @@ def device_context(arg):
291294
_mgr._remove_current_queue()
292295
else:
293296
_logger.debug("No queue was created so nothing to do")
297+
298+
299+
cdef class _DeviceDefaultQueueCache:
300+
cdef dict __device_queue_map__
301+
302+
def __cinit__(self):
303+
self.__device_queue_map__ = dict()
304+
305+
def get_or_create(self, key):
306+
"""Return instance of SyclQueue and indicator if cache has been modified"""
307+
if isinstance(key, tuple) and len(key) == 2 and isinstance(key[0], SyclContext) and isinstance(key[1], SyclDevice):
308+
ctx_dev = key
309+
q = None
310+
elif isinstance(key, SyclDevice):
311+
q = SyclQueue(key)
312+
ctx_dev = q.sycl_context, key
313+
else:
314+
raise TypeError
315+
if ctx_dev in self.__device_queue_map__:
316+
return self.__device_queue_map__[ctx_dev], False
317+
if q is None: q = SyclQueue(*ctx_dev)
318+
self.__device_queue_map__[ctx_dev] = q
319+
return q, True
320+
321+
cdef _update_map(self, dev_queue_map):
322+
self.__device_queue_map__.update(dev_queue_map)
323+
324+
def __copy__(self):
325+
cdef _DeviceDefaultQueueCache _copy = _DeviceDefaultQueueCache.__new__(_DeviceDefaultQueueCache)
326+
_copy._update_map(self.__device_queue_map__)
327+
return _copy
328+
329+
330+
_global_device_queue_cache = ContextVar('global_device_queue_cache', default=_DeviceDefaultQueueCache())
331+
332+
333+
cpdef object get_device_cached_queue(object key):
334+
"""Get cached queue associated with given device"""
335+
_cache = _global_device_queue_cache.get()
336+
q_, changed_ = _cache.get_or_create(key)
337+
if changed_: _global_device_queue_cache.set(_cache)
338+
return q_

dpctl/memory/_memory.pyx

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ from dpctl._backend cimport ( # noqa: E211
6161
from .._sycl_context cimport SyclContext
6262
from .._sycl_device cimport SyclDevice
6363
from .._sycl_queue cimport SyclQueue
64+
from .._sycl_queue_manager cimport get_device_cached_queue
6465

6566
import collections
6667
import numbers
@@ -150,7 +151,7 @@ cdef class _Memory:
150151

151152
if (nbytes > 0):
152153
if queue is None:
153-
queue = dpctl.SyclQueue()
154+
queue = get_device_cached_queue(dpctl.SyclDevice())
154155

155156
QRef = queue.get_queue_ref()
156157
if (ptr_type == b"shared"):

dpctl/tensor/_device.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
1616
import dpctl
17+
from dpctl._sycl_queue_manager import get_device_cached_queue
1718

1819
__doc__ = "Implementation of array API mandated Device class"
1920

@@ -60,9 +61,7 @@ def create_device(cls, dev):
6061
elif isinstance(dev, dpctl.SyclDevice):
6162
par = dev.parent_device
6263
if par is None:
63-
if dev not in cls.__device_queue_map__:
64-
cls.__device_queue_map__[dev] = dpctl.SyclQueue(dev)
65-
obj.sycl_queue_ = cls.__device_queue_map__[dev]
64+
obj.sycl_queue_ = get_device_cached_queue(dev)
6665
else:
6766
raise ValueError(
6867
f"Using non-root device {dev} to specify offloading "
@@ -74,9 +73,7 @@ def create_device(cls, dev):
7473
_dev = dpctl.SyclDevice()
7574
else:
7675
_dev = dpctl.SyclDevice(dev)
77-
if _dev not in cls.__device_queue_map__:
78-
cls.__device_queue_map__[_dev] = dpctl.SyclQueue(_dev)
79-
obj.sycl_queue_ = cls.__device_queue_map__[_dev]
76+
obj.sycl_queue_ = get_device_cached_queue(_dev)
8077
return obj
8178

8279
@property

dpctl/tensor/_dlpack.pyx

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ from libc.stdint cimport int32_t, int64_t, uint8_t, uint16_t, uint64_t
2424

2525
cimport dpctl as c_dpctl
2626
cimport dpctl.memory as c_dpmem
27+
from dpctl._sycl_queue_manager cimport get_device_cached_queue
2728

2829
from .._backend cimport (
2930
DPCTLDevice_Delete,
@@ -344,12 +345,12 @@ cpdef usm_ndarray from_dlpack_capsule(object py_caps) except +:
344345
if _IS_LINUX:
345346
default_context = root_device.sycl_platform.default_context
346347
else:
347-
default_context = dpctl.SyclQueue(root_device).sycl_context
348+
default_context = get_device_cached_queue(root_device).sycl_context
348349
except RuntimeError:
349-
default_context = dpctl.SyclQueue(root_device).sycl_context
350+
default_context = get_device_cached_queue(root_device).sycl_context
350351
if dlm_tensor.dl_tensor.data is NULL:
351352
usm_type = b"device"
352-
q = dpctl.SyclQueue(default_context, root_device)
353+
q = get_device_cached_queue((default_context, root_device,))
353354
else:
354355
usm_type = c_dpmem._Memory.get_pointer_type(
355356
<DPCTLSyclUSMRef> dlm_tensor.dl_tensor.data,
@@ -364,7 +365,7 @@ cpdef usm_ndarray from_dlpack_capsule(object py_caps) except +:
364365
<DPCTLSyclUSMRef> dlm_tensor.dl_tensor.data,
365366
<c_dpctl.SyclContext>default_context
366367
)
367-
q = dpctl.SyclQueue(default_context, alloc_device)
368+
q = get_device_cached_queue((default_context, alloc_device,))
368369
if dlm_tensor.dl_tensor.dtype.bits % 8:
369370
raise BufferError(
370371
"Can not import DLPack tensor whose element's "

dpctl/tests/test_sycl_queue_manager.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,3 +226,22 @@ def test_nested_context_factory_exception_if_wrong_factory(
226226
with _register_nested_context_factory(factory):
227227
with dpctl.device_context("opencl:cpu:0"):
228228
pass
229+
230+
231+
def test__DeviceDefaultQueueCache():
232+
import copy
233+
234+
from dpctl._sycl_queue_manager import _global_device_queue_cache as cache
235+
from dpctl._sycl_queue_manager import get_device_cached_queue
236+
237+
try:
238+
d = dpctl.SyclDevice()
239+
except dpctl.SyclDeviceCreationError:
240+
pytest.skip("Could not create default device")
241+
242+
q1 = get_device_cached_queue(d)
243+
cache_copy = copy.copy(cache.get())
244+
q2, changed = cache_copy.get_or_create(d)
245+
246+
assert not changed
247+
assert q1 == q2

dpctl/tests/test_usm_ndarray_dlpack.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,28 @@ def test_dlpack_exporter(typestr, usm_type):
8181
assert caps_fn(caps2, b"dltensor")
8282

8383

84+
def test_dlpack_exporter_empty(typestr, usm_type):
85+
caps_fn = ctypes.pythonapi.PyCapsule_IsValid
86+
caps_fn.restype = bool
87+
caps_fn.argtypes = [ctypes.py_object, ctypes.c_char_p]
88+
sycl_dev = dpctl.select_default_device()
89+
skip_if_dtype_not_supported(typestr, sycl_dev)
90+
X = dpt.empty((0,), dtype=typestr, usm_type=usm_type, device=sycl_dev)
91+
caps = X.__dlpack__()
92+
assert caps_fn(caps, b"dltensor")
93+
Y = dpt.empty(
94+
(
95+
1,
96+
0,
97+
),
98+
dtype=typestr,
99+
usm_type=usm_type,
100+
device=sycl_dev,
101+
)
102+
caps = Y.__dlpack__()
103+
assert caps_fn(caps, b"dltensor")
104+
105+
84106
def test_dlpack_exporter_stream():
85107
try:
86108
q1 = dpctl.SyclQueue()

0 commit comments

Comments
 (0)