Skip to content

Commit a8d97f3

Browse files
Merge pull request #1333 from IntelPython/sycl-platform-equal
SyclPlatform equality operator implemented
2 parents ed93e02 + e54aaa0 commit a8d97f3

File tree

8 files changed

+134
-1
lines changed

8 files changed

+134
-1
lines changed

dpctl/_backend.pxd

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,7 @@ cdef extern from "syclinterface/dpctl_sycl_platform_manager.h":
299299

300300

301301
cdef extern from "syclinterface/dpctl_sycl_platform_interface.h":
302+
cdef bool DPCTLPlatform_AreEq(const DPCTLSyclPlatformRef, const DPCTLSyclPlatformRef)
302303
cdef DPCTLSyclPlatformRef DPCTLPlatform_Copy(const DPCTLSyclPlatformRef)
303304
cdef DPCTLSyclPlatformRef DPCTLPlatform_Create()
304305
cdef DPCTLSyclPlatformRef DPCTLPlatform_CreateFromSelector(
@@ -308,6 +309,7 @@ cdef extern from "syclinterface/dpctl_sycl_platform_interface.h":
308309
cdef const char *DPCTLPlatform_GetName(const DPCTLSyclPlatformRef)
309310
cdef const char *DPCTLPlatform_GetVendor(const DPCTLSyclPlatformRef)
310311
cdef const char *DPCTLPlatform_GetVersion(const DPCTLSyclPlatformRef)
312+
cdef size_t DPCTLPlatform_Hash(const DPCTLSyclPlatformRef)
311313
cdef DPCTLPlatformVectorRef DPCTLPlatform_GetPlatforms()
312314
cdef DPCTLSyclContextRef DPCTLPlatform_GetDefaultContext(
313315
const DPCTLSyclPlatformRef)

dpctl/_sycl_platform.pxd

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
SYCL platform-related helper functions.
2222
"""
2323

24+
from libcpp cimport bool
25+
2426
from ._backend cimport DPCTLSyclDeviceSelectorRef, DPCTLSyclPlatformRef
2527

2628

@@ -40,6 +42,7 @@ cdef class SyclPlatform(_SyclPlatform):
4042
cdef int _init_from_selector(self, DPCTLSyclDeviceSelectorRef DSRef)
4143
cdef int _init_from__SyclPlatform(self, _SyclPlatform other)
4244
cdef DPCTLSyclPlatformRef get_platform_ref(self)
45+
cdef bool equals(self, SyclPlatform)
4346

4447

4548
cpdef list get_platforms()

dpctl/_sycl_platform.pyx

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,13 @@
2121
""" Implements SyclPlatform Cython extension type.
2222
"""
2323

24+
from libcpp cimport bool
25+
2426
from ._backend cimport ( # noqa: E211
2527
DPCTLCString_Delete,
2628
DPCTLDeviceSelector_Delete,
2729
DPCTLFilterSelector_Create,
30+
DPCTLPlatform_AreEq,
2831
DPCTLPlatform_Copy,
2932
DPCTLPlatform_Create,
3033
DPCTLPlatform_CreateFromSelector,
@@ -35,6 +38,7 @@ from ._backend cimport ( # noqa: E211
3538
DPCTLPlatform_GetPlatforms,
3639
DPCTLPlatform_GetVendor,
3740
DPCTLPlatform_GetVersion,
41+
DPCTLPlatform_Hash,
3842
DPCTLPlatformMgr_GetInfo,
3943
DPCTLPlatformMgr_PrintInfo,
4044
DPCTLPlatformVector_Delete,
@@ -274,6 +278,42 @@ cdef class SyclPlatform(_SyclPlatform):
274278
else:
275279
return SyclContext._create(CRef)
276280

281+
cdef bool equals(self, SyclPlatform other):
282+
"""
283+
Returns true if the :class:`dpctl.SyclPlatform` argument has the
284+
same underlying ``DPCTLSyclPlatformRef`` object as this
285+
:class:`dpctl.SyclPlatform` instance.
286+
287+
Returns:
288+
:obj:`bool`: ``True`` if the two :class:`dpctl.SyclPlatform` objects
289+
point to the same ``DPCTLSyclPlatformRef`` object, otherwise
290+
``False``.
291+
"""
292+
return DPCTLPlatform_AreEq(self._platform_ref, other.get_platform_ref())
293+
294+
def __eq__(self, other):
295+
"""
296+
Returns True if the :class:`dpctl.SyclPlatform` argument has the
297+
same underlying ``DPCTLSyclPlatformRef`` object as this
298+
:class:`dpctl.SyclPlatform` instance.
299+
300+
Returns:
301+
:obj:`bool`: ``True`` if the two :class:`dpctl.SyclPlatform` objects
302+
point to the same ``DPCTLSyclPlatformRef`` object, otherwise
303+
``False``.
304+
"""
305+
if isinstance(other, SyclPlatform):
306+
return self.equals(<SyclPlatform> other)
307+
else:
308+
return False
309+
310+
def __hash__(self):
311+
"""
312+
Returns a hash value by hashing the underlying ``sycl::platform`` object.
313+
314+
"""
315+
return DPCTLPlatform_Hash(self._platform_ref)
316+
277317

278318
def lsplatform(verbosity=0):
279319
"""

dpctl/tests/test_sycl_platform.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
"""Defines unit test cases for the SyclPlatform class.
1818
"""
1919

20+
import sys
21+
2022
import pytest
2123
from helper import has_sycl_platforms
2224

@@ -88,17 +90,37 @@ def check_repr(platform):
8890

8991

9092
def check_default_context(platform):
93+
if "linux" not in sys.platform:
94+
return
9195
r = platform.default_context
9296
assert type(r) is dpctl.SyclContext
9397

9498

99+
def check_equal_and_hash(platform):
100+
assert platform == platform
101+
if "linux" not in sys.platform:
102+
return
103+
default_ctx = platform.default_context
104+
for d in default_ctx.get_devices():
105+
assert platform == d.sycl_platform
106+
assert hash(platform) == hash(d.sycl_platform)
107+
108+
109+
def check_hash_in_dict(platform):
110+
map = {platform: 0}
111+
assert map[platform] == 0
112+
113+
95114
list_of_checks = [
96115
check_name,
97116
check_vendor,
98117
check_version,
99118
check_backend,
100119
check_print_info,
101120
check_repr,
121+
check_default_context,
122+
check_equal_and_hash,
123+
check_hash_in_dict,
102124
]
103125

104126

libsyclinterface/include/dpctl_sycl_context_interface.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,6 @@ void DPCTLContext_Delete(__dpctl_take DPCTLSyclContextRef CtxRef);
159159
* @ingroup ContextInterface
160160
*/
161161
DPCTL_API
162-
size_t DPCTLContext_Hash(__dpctl_take DPCTLSyclContextRef CtxRef);
162+
size_t DPCTLContext_Hash(__dpctl_keep DPCTLSyclContextRef CtxRef);
163163

164164
DPCTL_C_EXTERN_C_END

libsyclinterface/include/dpctl_sycl_platform_interface.h

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,19 @@ DPCTL_C_EXTERN_C_BEGIN
3939
* @defgroup PlatformInterface Platform class C wrapper
4040
*/
4141

42+
/*!
43+
* @brief Checks if two DPCTLSyclPlatformRef objects point to the same
44+
* sycl::platform.
45+
*
46+
* @param PRef1 First opaque pointer to a ``sycl::platform``.
47+
* @param PRef2 Second opaque pointer to a ``sycl::platform``.
48+
* @return True if the underlying sycl::platform are same, false otherwise.
49+
* @ingroup PlatformInterface
50+
*/
51+
DPCTL_API
52+
bool DPCTLPlatform_AreEq(__dpctl_keep const DPCTLSyclPlatformRef PRef1,
53+
__dpctl_keep const DPCTLSyclPlatformRef PRef2);
54+
4255
/*!
4356
* @brief Returns a copy of the DPCTLSyclPlatformRef object.
4457
*
@@ -155,4 +168,14 @@ DPCTL_API
155168
__dpctl_give DPCTLSyclContextRef
156169
DPCTLPlatform_GetDefaultContext(__dpctl_keep const DPCTLSyclPlatformRef PRef);
157170

171+
/*!
172+
* @brief Wrapper over std::hash<sycl::platform>'s operator()
173+
*
174+
* @param PRef The DPCTLSyclPlatformRef pointer.
175+
* @return Hash value of the underlying ``sycl::platform`` instance.
176+
* @ingroup PlatformInterface
177+
*/
178+
DPCTL_API
179+
size_t DPCTLPlatform_Hash(__dpctl_keep DPCTLSyclPlatformRef CtxRef);
180+
158181
DPCTL_C_EXTERN_C_END

libsyclinterface/source/dpctl_sycl_platform_interface.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,3 +234,27 @@ DPCTLPlatform_GetDefaultContext(__dpctl_keep const DPCTLSyclPlatformRef PRef)
234234
return nullptr;
235235
}
236236
}
237+
238+
bool DPCTLPlatform_AreEq(__dpctl_keep const DPCTLSyclPlatformRef PRef1,
239+
__dpctl_keep const DPCTLSyclPlatformRef PRef2)
240+
{
241+
auto P1 = unwrap<platform>(PRef1);
242+
auto P2 = unwrap<platform>(PRef2);
243+
if (P1 && P2)
244+
return *P1 == *P2;
245+
else
246+
return false;
247+
}
248+
249+
size_t DPCTLPlatform_Hash(__dpctl_keep const DPCTLSyclPlatformRef PRef)
250+
{
251+
if (PRef) {
252+
auto P = unwrap<platform>(PRef);
253+
std::hash<platform> hash_fn;
254+
return hash_fn(*P);
255+
}
256+
else {
257+
error_handler("Argument PRef is null.", __FILE__, __func__, __LINE__);
258+
return 0;
259+
}
260+
}

libsyclinterface/tests/test_sycl_platform_interface.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,25 @@ TEST_P(TestDPCTLSyclPlatformInterface, ChkPrintInfoNullArg)
264264
EXPECT_NO_FATAL_FAILURE(DPCTLPlatformMgr_PrintInfo(Null_PRef, 0));
265265
}
266266

267+
TEST_P(TestDPCTLSyclPlatformInterface, ChkAreEq)
268+
{
269+
DPCTLSyclPlatformRef PRef_Copy = nullptr;
270+
271+
EXPECT_NO_FATAL_FAILURE(PRef_Copy = DPCTLPlatform_Copy(PRef));
272+
273+
ASSERT_TRUE(DPCTLPlatform_AreEq(PRef, PRef_Copy));
274+
EXPECT_TRUE(DPCTLPlatform_Hash(PRef) == DPCTLPlatform_Hash(PRef_Copy));
275+
276+
EXPECT_NO_FATAL_FAILURE(DPCTLPlatform_Delete(PRef_Copy));
277+
}
278+
279+
TEST_P(TestDPCTLSyclPlatformInterface, ChkAreEqNullArg)
280+
{
281+
DPCTLSyclPlatformRef Null_PRef = nullptr;
282+
ASSERT_FALSE(DPCTLPlatform_AreEq(PRef, Null_PRef));
283+
ASSERT_TRUE(DPCTLPlatform_Hash(Null_PRef) == 0);
284+
}
285+
267286
TEST_F(TestDPCTLSyclDefaultPlatform, ChkGetName)
268287
{
269288
check_platform_name(PRef);

0 commit comments

Comments
 (0)