Skip to content

Commit 5025510

Browse files
Add stream argument validation
1 parent 2e8c9c0 commit 5025510

File tree

1 file changed

+20
-11
lines changed

1 file changed

+20
-11
lines changed

dpctl/tensor/_usmarray.pyx

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1025,10 +1025,14 @@ cdef class usm_ndarray:
10251025
cdef c_dpmem._Memory arr_buf
10261026
d = Device.create_device(target_device)
10271027

1028-
if (stream is None or not isinstance(stream, dpctl.SyclQueue) or
1029-
stream == self.sycl_queue):
1028+
if (stream is None or stream == self.sycl_queue):
10301029
pass
10311030
else:
1031+
if not isinstance(stream, dpctl.SyclQueue):
1032+
raise TypeError(
1033+
"stream argument type was expected to be dpctl.SyclQueue,"
1034+
f" got {type(stream)} instead"
1035+
)
10321036
ev = self.sycl_queue.submit_barrier()
10331037
stream.submit_barrier(dependent_events=[ev])
10341038

@@ -1203,10 +1207,15 @@ cdef class usm_ndarray:
12031207
# legacy path for DLManagedTensor
12041208
# copy kwarg ignored because copy flag can't be set
12051209
_caps = c_dlpack.to_dlpack_capsule(self)
1206-
if (stream is None or type(stream) is not dpctl.SyclQueue or
1207-
stream == self.sycl_queue):
1210+
if (stream is None or stream == self.sycl_queue):
12081211
pass
12091212
else:
1213+
if not isinstance(stream, dpctl.SyclQueue):
1214+
raise TypeError(
1215+
"stream keyword argument type is expected to "
1216+
"be dpctl.SyclQueue, "
1217+
f" got {type(stream)} instead"
1218+
)
12101219
ev = self.sycl_queue.submit_barrier()
12111220
stream.submit_barrier(dependent_events=[ev])
12121221
return _caps
@@ -1555,17 +1564,17 @@ cdef class usm_ndarray:
15551564
def __array__(self, dtype=None, /, *, copy=None):
15561565
"""NumPy's array protocol method to disallow implicit conversion.
15571566
1558-
Without this definition, `numpy.asarray(usm_ar)` converts
1559-
usm_ndarray instance into NumPy array with data type `object`
1560-
and every element being 0d usm_ndarray.
1567+
Without this definition, `numpy.asarray(usm_ar)` converts
1568+
usm_ndarray instance into NumPy array with data type `object`
1569+
and every element being 0d usm_ndarray.
15611570
15621571
https://github.com/IntelPython/dpctl/pull/1384#issuecomment-1707212972
1563-
"""
1572+
"""
15641573
raise TypeError(
15651574
"Implicit conversion to a NumPy array is not allowed. "
1566-
"Use `dpctl.tensor.asnumpy` to copy data from this "
1567-
"`dpctl.tensor.usm_ndarray` instance to NumPy array"
1568-
)
1575+
"Use `dpctl.tensor.asnumpy` to copy data from this "
1576+
"`dpctl.tensor.usm_ndarray` instance to NumPy array"
1577+
)
15691578

15701579

15711580
cdef usm_ndarray _real_view(usm_ndarray ary):

0 commit comments

Comments
 (0)