Skip to content

Commit 74066bb

Browse files
Modularize stream validation into a function.
Stream keyword validation and deployment was copy-pasted in several places. Created function _stream_validate_and_use, and used it in a couple of places. This brings uniformity of error messages, and should improve coverage and maintainability.
1 parent df8c1f3 commit 74066bb

File tree

1 file changed

+18
-39
lines changed

1 file changed

+18
-39
lines changed

dpctl/tensor/_usmarray.pyx

Lines changed: 18 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,19 @@ cdef bint _is_host_cpu(object dl_device):
149149
return (dl_type == DLDeviceType.kDLCPU) and (dl_id == 0)
150150

151151

152+
cdef void _validate_and_use_stream(object stream, c_dpctl.SyclQueue self_queue) except *:
153+
if (stream is None or stream == self_queue):
154+
pass
155+
else:
156+
if not isinstance(stream, dpctl.SyclQueue):
157+
raise TypeError(
158+
"stream argument type was expected to be dpctl.SyclQueue,"
159+
f" got {type(stream)} instead"
160+
)
161+
ev = self_queue.submit_barrier()
162+
stream.submit_barrier(dependent_events=[ev])
163+
164+
152165
cdef class usm_ndarray:
153166
""" usm_ndarray(shape, dtype=None, strides=None, buffer="device", \
154167
offset=0, order="C", buffer_ctor_kwargs=dict(), \
@@ -1025,16 +1038,7 @@ cdef class usm_ndarray:
10251038
cdef c_dpmem._Memory arr_buf
10261039
d = Device.create_device(target_device)
10271040

1028-
if (stream is None or stream == self.sycl_queue):
1029-
pass
1030-
else:
1031-
if not isinstance(stream, dpctl.SyclQueue):
1032-
raise TypeError(
1033-
"stream argument type was expected to be dpctl.SyclQueue,"
1034-
f" got {type(stream)} instead"
1035-
)
1036-
ev = self.sycl_queue.submit_barrier()
1037-
stream.submit_barrier(dependent_events=[ev])
1041+
_validate_and_use_stream(stream, self.sycl_queue)
10381042

10391043
if (d.sycl_context == self.sycl_context):
10401044
arr_buf = <c_dpmem._Memory> self.usm_data
@@ -1207,17 +1211,7 @@ cdef class usm_ndarray:
12071211
# legacy path for DLManagedTensor
12081212
# copy kwarg ignored because copy flag can't be set
12091213
_caps = c_dlpack.to_dlpack_capsule(self)
1210-
if (stream is None or stream == self.sycl_queue):
1211-
pass
1212-
else:
1213-
if not isinstance(stream, dpctl.SyclQueue):
1214-
raise TypeError(
1215-
"stream keyword argument type is expected to "
1216-
"be dpctl.SyclQueue, "
1217-
f" got {type(stream)} instead"
1218-
)
1219-
ev = self.sycl_queue.submit_barrier()
1220-
stream.submit_barrier(dependent_events=[ev])
1214+
_validate_and_use_stream(stream, self.sycl_queue)
12211215
return _caps
12221216
else:
12231217
if not isinstance(max_version, tuple) or len(max_version) != 2:
@@ -1259,12 +1253,7 @@ cdef class usm_ndarray:
12591253
copy = False
12601254
# TODO: strategy for handling stream on different device from dl_device
12611255
if copy:
1262-
if (stream is None or type(stream) is not dpctl.SyclQueue or
1263-
stream == self.sycl_queue):
1264-
pass
1265-
else:
1266-
ev = self.sycl_queue.submit_barrier()
1267-
stream.submit_barrier(dependent_events=[ev])
1256+
_validate_and_use_stream(stream, self.sycl_queue)
12681257
nbytes = self.usm_data.nbytes
12691258
copy_buffer = type(self.usm_data)(
12701259
nbytes, queue=self.sycl_queue
@@ -1281,22 +1270,12 @@ cdef class usm_ndarray:
12811270
_caps = c_dlpack.to_dlpack_versioned_capsule(_copied_arr, copy)
12821271
else:
12831272
_caps = c_dlpack.to_dlpack_versioned_capsule(self, copy)
1284-
if (stream is None or type(stream) is not dpctl.SyclQueue or
1285-
stream == self.sycl_queue):
1286-
pass
1287-
else:
1288-
ev = self.sycl_queue.submit_barrier()
1289-
stream.submit_barrier(dependent_events=[ev])
1273+
_validate_and_use_stream(stream, self.sycl_queue)
12901274
return _caps
12911275
else:
12921276
# legacy path for DLManagedTensor
12931277
_caps = c_dlpack.to_dlpack_capsule(self)
1294-
if (stream is None or type(stream) is not dpctl.SyclQueue or
1295-
stream == self.sycl_queue):
1296-
pass
1297-
else:
1298-
ev = self.sycl_queue.submit_barrier()
1299-
stream.submit_barrier(dependent_events=[ev])
1278+
_validate_and_use_stream(stream, self.sycl_queue)
13001279
return _caps
13011280

13021281
def __dlpack_device__(self):

0 commit comments

Comments
 (0)