Skip to content

Commit 7767b4f

Browse files
Merge pull request #389 from IntelPython/equals_changes_and_misc
Equals changes and misc
2 parents 48dbf1f + dbcbc06 commit 7767b4f

File tree

8 files changed

+43
-14
lines changed

8 files changed

+43
-14
lines changed

dpctl/_sycl_context.pxd

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,5 +42,5 @@ cdef public class SyclContext(_SyclContext) [object PySyclContextObject, type Py
4242
cdef int _init_context_from_one_device(self, SyclDevice device, int props)
4343
cdef int _init_context_from_devices(self, object devices, int props)
4444
cdef int _init_context_from_capsule(self, object caps)
45-
cpdef bool equals (self, SyclContext ctxt)
45+
cdef bool equals (self, SyclContext ctxt)
4646
cdef DPCTLSyclContextRef get_context_ref (self)

dpctl/_sycl_context.pyx

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,7 @@ cdef class SyclContext(_SyclContext):
301301
"Unrecognized error code ({}) encountered.".format(ret)
302302
)
303303

304-
cpdef bool equals(self, SyclContext ctxt):
304+
cdef bool equals(self, SyclContext ctxt):
305305
"""
306306
Returns true if the :class:`dpctl.SyclContext` argument has the
307307
same underlying ``DPCTLSyclContextRef`` object as this
@@ -314,6 +314,22 @@ cdef class SyclContext(_SyclContext):
314314
"""
315315
return DPCTLContext_AreEq(self._ctxt_ref, ctxt.get_context_ref())
316316

317+
def __eq__(self, other):
318+
"""
319+
Returns True if the :class:`dpctl.SyclContext` argument has the
320+
same underlying ``DPCTLSyclContextRef`` object as this
321+
:class:`dpctl.SyclContext` instance.
322+
323+
Returns:
324+
:obj:`bool`: ``True`` if the two :class:`dpctl.SyclContext` objects
325+
point to the same ``DPCTLSyclContextRef`` object, otherwise
326+
``False``.
327+
"""
328+
if isinstance(other, SyclContext):
329+
return self.equals(<SyclContext> other)
330+
else:
331+
return False
332+
317333
cdef DPCTLSyclContextRef get_context_ref(self):
318334
return self._ctxt_ref
319335

dpctl/_sycl_device.pxd

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,4 +49,4 @@ cdef public class SyclDevice(_SyclDevice) [object PySyclDeviceObject, type PySyc
4949
cdef list create_sub_devices_equally(self, size_t count)
5050
cdef list create_sub_devices_by_counts(self, object counts)
5151
cdef list create_sub_devices_by_affinity(self, _partition_affinity_domain_type domain)
52-
cpdef cpp_bool equals(self, SyclDevice q)
52+
cdef cpp_bool equals(self, SyclDevice q)

dpctl/_sycl_device.pyx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -774,7 +774,7 @@ cdef class SyclDevice(_SyclDevice):
774774
return None
775775
return SyclDevice._create(pDRef)
776776

777-
cpdef cpp_bool equals(self, SyclDevice other):
777+
cdef cpp_bool equals(self, SyclDevice other):
778778
""" Returns true if the SyclDevice argument has the same _device_ref
779779
as this SyclDevice.
780780
"""

dpctl/_sycl_queue.pxd

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ cdef public class SyclQueue (_SyclQueue) [object PySyclQueueObject, type PySyclQ
6666
cdef SyclQueue _create_from_context_and_device(
6767
SyclContext ctx, SyclDevice dev
6868
)
69-
cpdef cpp_bool equals(self, SyclQueue q)
69+
cdef cpp_bool equals(self, SyclQueue q)
7070
cpdef SyclContext get_sycl_context(self)
7171
cpdef SyclDevice get_sycl_device(self)
7272
cdef DPCTLSyclQueueRef get_queue_ref(self)

dpctl/_sycl_queue.pyx

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -319,16 +319,14 @@ cdef class SyclQueue(_SyclQueue):
319319
raise SyclQueueCreationError(
320320
"SYCL Context could not be created from '{}'.".format(arg)
321321
)
322-
elif status == -4:
322+
elif status == -4 or status == -6:
323323
if len_args == 2:
324324
arg = args
325325
raise SyclQueueCreationError(
326326
"SYCL Queue failed to be created from '{}'.".format(arg)
327327
)
328328
elif status == -5:
329329
raise TypeError("Input capsule {} contains a null pointer or could not be renamed".format(arg))
330-
elif status == -6:
331-
raise "SYCL Queue failed to be created from '{}'.".format(arg)
332330

333331
cdef int _init_queue_from__SyclQueue(self, _SyclQueue other):
334332
""" Copy data container _SyclQueue fields over.
@@ -601,12 +599,27 @@ cdef class SyclQueue(_SyclQueue):
601599

602600
return ret
603601

604-
cpdef cpp_bool equals(self, SyclQueue q):
605-
""" Returns true if the SyclQueue argument has the same _queue_ref
602+
cdef cpp_bool equals(self, SyclQueue q):
603+
""" Returns true if the SyclQueue argument `q` has the same _queue_ref
606604
as this SyclQueue.
607605
"""
608606
return DPCTLQueue_AreEq(self._queue_ref, q.get_queue_ref())
609607

608+
def __eq__(self, other):
609+
"""
610+
Returns True if two :class:`dpctl.SyclQueue` compared arguments have
611+
the same underlying ``DPCTLSyclQueueRef`` object.
612+
613+
Returns:
614+
:obj:`bool`: ``True`` if the two :class:`dpctl.SyclQueue` objects
615+
point to the same ``DPCTLSyclQueueRef`` object, otherwise
616+
``False``.
617+
"""
618+
if isinstance(other, SyclQueue):
619+
return self.equals(<SyclQueue> other)
620+
else:
621+
return False
622+
610623
def get_sycl_backend (self):
611624
""" Returns the Sycl backend associated with the queue.
612625
"""

dpctl/tests/test_sycl_context.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -349,7 +349,7 @@ def test_context_not_equals():
349349
ctx_cpu = dpctl.SyclContext("cpu")
350350
except ValueError:
351351
pytest.skip()
352-
assert not ctx_cpu.equals(ctx_gpu)
352+
assert ctx_cpu != ctx_gpu
353353

354354

355355
def test_context_equals():
@@ -358,7 +358,7 @@ def test_context_equals():
358358
ctx0 = dpctl.SyclContext("gpu")
359359
except ValueError:
360360
pytest.skip()
361-
assert ctx0.equals(ctx1)
361+
assert ctx0 == ctx1
362362

363363

364364
def test_context_can_be_used_in_queue(valid_filter):

dpctl/tests/test_sycl_queue.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -349,7 +349,7 @@ def test_context_not_equals():
349349
except dpctl.SyclQueueCreationError:
350350
pytest.skip()
351351
ctx_cpu = cpuQ.get_sycl_context()
352-
assert not ctx_cpu.equals(ctx_gpu)
352+
assert ctx_cpu != ctx_gpu
353353

354354

355355
def test_context_equals():
@@ -360,4 +360,4 @@ def test_context_equals():
360360
pytest.skip()
361361
ctx0 = gpuQ0.get_sycl_context()
362362
ctx1 = gpuQ1.get_sycl_context()
363-
assert ctx0.equals(ctx1)
363+
assert ctx0 == ctx1

0 commit comments

Comments
 (0)