Skip to content

Equals changes and misc #389

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 14, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion dpctl/_sycl_context.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -42,5 +42,5 @@ cdef public class SyclContext(_SyclContext) [object PySyclContextObject, type Py
cdef int _init_context_from_one_device(self, SyclDevice device, int props)
cdef int _init_context_from_devices(self, object devices, int props)
cdef int _init_context_from_capsule(self, object caps)
cpdef bool equals (self, SyclContext ctxt)
cdef bool equals (self, SyclContext ctxt)
cdef DPCTLSyclContextRef get_context_ref (self)
18 changes: 17 additions & 1 deletion dpctl/_sycl_context.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ cdef class SyclContext(_SyclContext):
"Unrecognized error code ({}) encountered.".format(ret)
)

cpdef bool equals(self, SyclContext ctxt):
cdef bool equals(self, SyclContext ctxt):
"""
Returns true if the :class:`dpctl.SyclContext` argument has the
same underlying ``DPCTLSyclContextRef`` object as this
Expand All @@ -312,6 +312,22 @@ cdef class SyclContext(_SyclContext):
"""
return DPCTLContext_AreEq(self._ctxt_ref, ctxt.get_context_ref())

def __eq__(self, other):
"""
Returns True if the :class:`dpctl.SyclContext` argument has the
same underlying ``DPCTLSyclContextRef`` object as this
:class:`dpctl.SyclContext` instance.

Returns:
:obj:`bool`: ``True`` if the two :class:`dpctl.SyclContext` objects
point to the same ``DPCTLSyclContextRef`` object, otherwise
``False``.
"""
if isinstance(other, SyclContext):
return self.equals(<SyclContext> other)
else:
return False

cdef DPCTLSyclContextRef get_context_ref(self):
return self._ctxt_ref

Expand Down
2 changes: 1 addition & 1 deletion dpctl/_sycl_device.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -49,4 +49,4 @@ cdef public class SyclDevice(_SyclDevice) [object PySyclDeviceObject, type PySyc
cdef list create_sub_devices_equally(self, size_t count)
cdef list create_sub_devices_by_counts(self, object counts)
cdef list create_sub_devices_by_affinity(self, _partition_affinity_domain_type domain)
cpdef cpp_bool equals(self, SyclDevice q)
cdef cpp_bool equals(self, SyclDevice q)
2 changes: 1 addition & 1 deletion dpctl/_sycl_device.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -774,7 +774,7 @@ cdef class SyclDevice(_SyclDevice):
return None
return SyclDevice._create(pDRef)

cpdef cpp_bool equals(self, SyclDevice other):
cdef cpp_bool equals(self, SyclDevice other):
""" Returns true if the SyclDevice argument has the same _device_ref
as this SyclDevice.
"""
Expand Down
2 changes: 1 addition & 1 deletion dpctl/_sycl_queue.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ cdef public class SyclQueue (_SyclQueue) [object PySyclQueueObject, type PySyclQ
cdef SyclQueue _create_from_context_and_device(
SyclContext ctx, SyclDevice dev
)
cpdef cpp_bool equals(self, SyclQueue q)
cdef cpp_bool equals(self, SyclQueue q)
cpdef SyclContext get_sycl_context(self)
cpdef SyclDevice get_sycl_device(self)
cdef DPCTLSyclQueueRef get_queue_ref(self)
Expand Down
23 changes: 18 additions & 5 deletions dpctl/_sycl_queue.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -239,16 +239,14 @@ cdef class SyclQueue(_SyclQueue):
raise SyclQueueCreationError(
"SYCL Context could not be created from '{}'.".format(arg)
)
elif status == -4:
elif status == -4 or status == -6:
if len_args == 2:
arg = args
raise SyclQueueCreationError(
"SYCL Queue failed to be created from '{}'.".format(arg)
)
elif status == -5:
raise TypeError("Input capsule {} contains a null pointer or could not be renamed".format(arg))
elif status == -6:
raise "SYCL Queue failed to be created from '{}'.".format(arg)

cdef int _init_queue_from__SyclQueue(self, _SyclQueue other):
""" Copy data container _SyclQueue fields over.
Expand Down Expand Up @@ -521,12 +519,27 @@ cdef class SyclQueue(_SyclQueue):

return ret

cpdef cpp_bool equals(self, SyclQueue q):
""" Returns true if the SyclQueue argument has the same _queue_ref
cdef cpp_bool equals(self, SyclQueue q):
""" Returns true if the SyclQueue argument `q` has the same _queue_ref
as this SyclQueue.
"""
return DPCTLQueue_AreEq(self._queue_ref, q.get_queue_ref())

def __eq__(self, other):
"""
Returns True if two :class:`dpctl.SyclQueue` compared arguments have
the same underlying ``DPCTLSyclQueueRef`` object.

Returns:
:obj:`bool`: ``True`` if the two :class:`dpctl.SyclQueue` objects
point to the same ``DPCTLSyclQueueRef`` object, otherwise
``False``.
"""
if isinstance(other, SyclQueue):
return self.equals(<SyclQueue> other)
else:
return False

def get_sycl_backend (self):
""" Returns the Sycl backend associated with the queue.
"""
Expand Down
4 changes: 2 additions & 2 deletions dpctl/tests/test_sycl_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ def test_context_not_equals():
ctx_cpu = dpctl.SyclContext("cpu")
except ValueError:
pytest.skip()
assert not ctx_cpu.equals(ctx_gpu)
assert ctx_cpu != ctx_gpu


def test_context_equals():
Expand All @@ -358,7 +358,7 @@ def test_context_equals():
ctx0 = dpctl.SyclContext("gpu")
except ValueError:
pytest.skip()
assert ctx0.equals(ctx1)
assert ctx0 == ctx1


def test_context_can_be_used_in_queue(valid_filter):
Expand Down
4 changes: 2 additions & 2 deletions dpctl/tests/test_sycl_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ def test_context_not_equals():
except dpctl.SyclQueueCreationError:
pytest.skip()
ctx_cpu = cpuQ.get_sycl_context()
assert not ctx_cpu.equals(ctx_gpu)
assert ctx_cpu != ctx_gpu


def test_context_equals():
Expand All @@ -360,4 +360,4 @@ def test_context_equals():
pytest.skip()
ctx0 = gpuQ0.get_sycl_context()
ctx1 = gpuQ1.get_sycl_context()
assert ctx0.equals(ctx1)
assert ctx0 == ctx1