Skip to content

Commit dbcbc06

Browse files
Method equals changed from being cpdef to cdef
Method __eq__ has been added to SyclContext, SyclQueue to rely on equals for the correct type of the other object, and giving False otherwise. Several examples that relied on .equals method were modified accordingly.
1 parent 2709ee4 commit dbcbc06

File tree

8 files changed

+42
-11
lines changed

8 files changed

+42
-11
lines changed

dpctl/_sycl_context.pxd

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,5 +42,5 @@ cdef public class SyclContext(_SyclContext) [object PySyclContextObject, type Py
4242
cdef int _init_context_from_one_device(self, SyclDevice device, int props)
4343
cdef int _init_context_from_devices(self, object devices, int props)
4444
cdef int _init_context_from_capsule(self, object caps)
45-
cpdef bool equals (self, SyclContext ctxt)
45+
cdef bool equals (self, SyclContext ctxt)
4646
cdef DPCTLSyclContextRef get_context_ref (self)

dpctl/_sycl_context.pyx

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -299,7 +299,7 @@ cdef class SyclContext(_SyclContext):
299299
"Unrecognized error code ({}) encountered.".format(ret)
300300
)
301301

302-
cpdef bool equals(self, SyclContext ctxt):
302+
cdef bool equals(self, SyclContext ctxt):
303303
"""
304304
Returns true if the :class:`dpctl.SyclContext` argument has the
305305
same underlying ``DPCTLSyclContextRef`` object as this
@@ -312,6 +312,22 @@ cdef class SyclContext(_SyclContext):
312312
"""
313313
return DPCTLContext_AreEq(self._ctxt_ref, ctxt.get_context_ref())
314314

315+
def __eq__(self, other):
316+
"""
317+
Returns True if the :class:`dpctl.SyclContext` argument has the
318+
same underlying ``DPCTLSyclContextRef`` object as this
319+
:class:`dpctl.SyclContext` instance.
320+
321+
Returns:
322+
:obj:`bool`: ``True`` if the two :class:`dpctl.SyclContext` objects
323+
point to the same ``DPCTLSyclContextRef`` object, otherwise
324+
``False``.
325+
"""
326+
if isinstance(other, SyclContext):
327+
return self.equals(<SyclContext> other)
328+
else:
329+
return False
330+
315331
cdef DPCTLSyclContextRef get_context_ref(self):
316332
return self._ctxt_ref
317333

dpctl/_sycl_device.pxd

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,4 +49,4 @@ cdef public class SyclDevice(_SyclDevice) [object PySyclDeviceObject, type PySyc
4949
cdef list create_sub_devices_equally(self, size_t count)
5050
cdef list create_sub_devices_by_counts(self, object counts)
5151
cdef list create_sub_devices_by_affinity(self, _partition_affinity_domain_type domain)
52-
cpdef cpp_bool equals(self, SyclDevice q)
52+
cdef cpp_bool equals(self, SyclDevice q)

dpctl/_sycl_device.pyx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -774,7 +774,7 @@ cdef class SyclDevice(_SyclDevice):
774774
return None
775775
return SyclDevice._create(pDRef)
776776

777-
cpdef cpp_bool equals(self, SyclDevice other):
777+
cdef cpp_bool equals(self, SyclDevice other):
778778
""" Returns true if the SyclDevice argument has the same _device_ref
779779
as this SyclDevice.
780780
"""

dpctl/_sycl_queue.pxd

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ cdef public class SyclQueue (_SyclQueue) [object PySyclQueueObject, type PySyclQ
6666
cdef SyclQueue _create_from_context_and_device(
6767
SyclContext ctx, SyclDevice dev
6868
)
69-
cpdef cpp_bool equals(self, SyclQueue q)
69+
cdef cpp_bool equals(self, SyclQueue q)
7070
cpdef SyclContext get_sycl_context(self)
7171
cpdef SyclDevice get_sycl_device(self)
7272
cdef DPCTLSyclQueueRef get_queue_ref(self)

dpctl/_sycl_queue.pyx

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -519,12 +519,27 @@ cdef class SyclQueue(_SyclQueue):
519519

520520
return ret
521521

522-
cpdef cpp_bool equals(self, SyclQueue q):
523-
""" Returns true if the SyclQueue argument has the same _queue_ref
522+
cdef cpp_bool equals(self, SyclQueue q):
523+
""" Returns true if the SyclQueue argument `q` has the same _queue_ref
524524
as this SyclQueue.
525525
"""
526526
return DPCTLQueue_AreEq(self._queue_ref, q.get_queue_ref())
527527

528+
def __eq__(self, other):
529+
"""
530+
Returns True if two :class:`dpctl.SyclQueue` compared arguments have
531+
the same underlying ``DPCTLSyclQueueRef`` object.
532+
533+
Returns:
534+
:obj:`bool`: ``True`` if the two :class:`dpctl.SyclQueue` objects
535+
point to the same ``DPCTLSyclQueueRef`` object, otherwise
536+
``False``.
537+
"""
538+
if isinstance(other, SyclQueue):
539+
return self.equals(<SyclQueue> other)
540+
else:
541+
return False
542+
528543
def get_sycl_backend (self):
529544
""" Returns the Sycl backend associated with the queue.
530545
"""

dpctl/tests/test_sycl_context.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -349,7 +349,7 @@ def test_context_not_equals():
349349
ctx_cpu = dpctl.SyclContext("cpu")
350350
except ValueError:
351351
pytest.skip()
352-
assert not ctx_cpu.equals(ctx_gpu)
352+
assert ctx_cpu != ctx_gpu
353353

354354

355355
def test_context_equals():
@@ -358,7 +358,7 @@ def test_context_equals():
358358
ctx0 = dpctl.SyclContext("gpu")
359359
except ValueError:
360360
pytest.skip()
361-
assert ctx0.equals(ctx1)
361+
assert ctx0 == ctx1
362362

363363

364364
def test_context_can_be_used_in_queue(valid_filter):

dpctl/tests/test_sycl_queue.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -349,7 +349,7 @@ def test_context_not_equals():
349349
except dpctl.SyclQueueCreationError:
350350
pytest.skip()
351351
ctx_cpu = cpuQ.get_sycl_context()
352-
assert not ctx_cpu.equals(ctx_gpu)
352+
assert ctx_cpu != ctx_gpu
353353

354354

355355
def test_context_equals():
@@ -360,4 +360,4 @@ def test_context_equals():
360360
pytest.skip()
361361
ctx0 = gpuQ0.get_sycl_context()
362362
ctx1 = gpuQ1.get_sycl_context()
363-
assert ctx0.equals(ctx1)
363+
assert ctx0 == ctx1

0 commit comments

Comments
 (0)