Skip to content

Commit 0bfc6f8

Browse files
Implemented SyclPlatform.__eq__ and SyclPlatform.__hash__
1 parent 358fc5d commit 0bfc6f8

File tree

3 files changed

+45
-0
lines changed

3 files changed

+45
-0
lines changed

dpctl/_backend.pxd

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,7 @@ cdef extern from "syclinterface/dpctl_sycl_platform_manager.h":
299299

300300

301301
cdef extern from "syclinterface/dpctl_sycl_platform_interface.h":
302+
cdef bool DPCTLPlatform_AreEq(const DPCTLSyclPlatformRef, const DPCTLSyclPlatformRef)
302303
cdef DPCTLSyclPlatformRef DPCTLPlatform_Copy(const DPCTLSyclPlatformRef)
303304
cdef DPCTLSyclPlatformRef DPCTLPlatform_Create()
304305
cdef DPCTLSyclPlatformRef DPCTLPlatform_CreateFromSelector(
@@ -308,6 +309,7 @@ cdef extern from "syclinterface/dpctl_sycl_platform_interface.h":
308309
cdef const char *DPCTLPlatform_GetName(const DPCTLSyclPlatformRef)
309310
cdef const char *DPCTLPlatform_GetVendor(const DPCTLSyclPlatformRef)
310311
cdef const char *DPCTLPlatform_GetVersion(const DPCTLSyclPlatformRef)
312+
cdef size_t DPCTLPlatform_Hash(const DPCTLSyclPlatformRef)
311313
cdef DPCTLPlatformVectorRef DPCTLPlatform_GetPlatforms()
312314
cdef DPCTLSyclContextRef DPCTLPlatform_GetDefaultContext(
313315
const DPCTLSyclPlatformRef)

dpctl/_sycl_platform.pxd

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
SYCL platform-related helper functions.
2222
"""
2323

24+
from libcpp cimport bool
25+
2426
from ._backend cimport DPCTLSyclDeviceSelectorRef, DPCTLSyclPlatformRef
2527

2628

@@ -40,6 +42,7 @@ cdef class SyclPlatform(_SyclPlatform):
4042
cdef int _init_from_selector(self, DPCTLSyclDeviceSelectorRef DSRef)
4143
cdef int _init_from__SyclPlatform(self, _SyclPlatform other)
4244
cdef DPCTLSyclPlatformRef get_platform_ref(self)
45+
cdef bool equals(self, SyclPlatform)
4346

4447

4548
cpdef list get_platforms()

dpctl/_sycl_platform.pyx

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,13 @@
2121
""" Implements SyclPlatform Cython extension type.
2222
"""
2323

24+
from libcpp cimport bool
25+
2426
from ._backend cimport ( # noqa: E211
2527
DPCTLCString_Delete,
2628
DPCTLDeviceSelector_Delete,
2729
DPCTLFilterSelector_Create,
30+
DPCTLPlatform_AreEq,
2831
DPCTLPlatform_Copy,
2932
DPCTLPlatform_Create,
3033
DPCTLPlatform_CreateFromSelector,
@@ -35,6 +38,7 @@ from ._backend cimport ( # noqa: E211
3538
DPCTLPlatform_GetPlatforms,
3639
DPCTLPlatform_GetVendor,
3740
DPCTLPlatform_GetVersion,
41+
DPCTLPlatform_Hash,
3842
DPCTLPlatformMgr_GetInfo,
3943
DPCTLPlatformMgr_PrintInfo,
4044
DPCTLPlatformVector_Delete,
@@ -274,6 +278,42 @@ cdef class SyclPlatform(_SyclPlatform):
274278
else:
275279
return SyclContext._create(CRef)
276280

281+
cdef bool equals(self, SyclPlatform other):
282+
"""
283+
Returns true if the :class:`dpctl.SyclPlatform` argument has the
284+
same underlying ``DPCTLSyclPlatformRef`` object as this
285+
:class:`dpctl.SyclPlatform` instance.
286+
287+
Returns:
288+
:obj:`bool`: ``True`` if the two :class:`dpctl.SyclPlatform` objects
289+
point to the same ``DPCTLSyclPlatformRef`` object, otherwise
290+
``False``.
291+
"""
292+
return DPCTLPlatform_AreEq(self._platform_ref, other.get_platform_ref())
293+
294+
def __eq__(self, other):
295+
"""
296+
Returns True if the :class:`dpctl.SyclPlatform` argument has the
297+
same underlying ``DPCTLSyclPlatformRef`` object as this
298+
:class:`dpctl.SyclPlatform` instance.
299+
300+
Returns:
301+
:obj:`bool`: ``True`` if the two :class:`dpctl.SyclPlatform` objects
302+
point to the same ``DPCTLSyclPlatformRef`` object, otherwise
303+
``False``.
304+
"""
305+
if isinstance(other, SyclPlatform):
306+
return self.equals(<SyclPlatform> other)
307+
else:
308+
return False
309+
310+
def __hash__(self):
311+
"""
312+
Returns a hash value by hashing the underlying ``sycl::platform`` object.
313+
314+
"""
315+
return DPCTLPlatform_Hash(self._platform_ref)
316+
277317

278318
def lsplatform(verbosity=0):
279319
"""

0 commit comments

Comments
 (0)