@@ -149,6 +149,19 @@ cdef bint _is_host_cpu(object dl_device):
149
149
return (dl_type == DLDeviceType.kDLCPU) and (dl_id == 0 )
150
150
151
151
152
+ cdef void _validate_and_use_stream(object stream, c_dpctl.SyclQueue self_queue) except * :
153
+ if (stream is None or stream == self_queue):
154
+ pass
155
+ else :
156
+ if not isinstance (stream, dpctl.SyclQueue):
157
+ raise TypeError (
158
+ " stream argument type was expected to be dpctl.SyclQueue,"
159
+ f" got {type(stream)} instead"
160
+ )
161
+ ev = self_queue.submit_barrier()
162
+ stream.submit_barrier(dependent_events = [ev])
163
+
164
+
152
165
cdef class usm_ndarray:
153
166
""" usm_ndarray(shape, dtype=None, strides=None, buffer="device", \
154
167
offset=0, order="C", buffer_ctor_kwargs=dict(), \
@@ -1025,16 +1038,7 @@ cdef class usm_ndarray:
1025
1038
cdef c_dpmem._Memory arr_buf
1026
1039
d = Device.create_device(target_device)
1027
1040
1028
- if (stream is None or stream == self .sycl_queue):
1029
- pass
1030
- else :
1031
- if not isinstance (stream, dpctl.SyclQueue):
1032
- raise TypeError (
1033
- " stream argument type was expected to be dpctl.SyclQueue,"
1034
- f" got {type(stream)} instead"
1035
- )
1036
- ev = self .sycl_queue.submit_barrier()
1037
- stream.submit_barrier(dependent_events = [ev])
1041
+ _validate_and_use_stream(stream, self .sycl_queue)
1038
1042
1039
1043
if (d.sycl_context == self .sycl_context):
1040
1044
arr_buf = < c_dpmem._Memory> self .usm_data
@@ -1207,17 +1211,7 @@ cdef class usm_ndarray:
1207
1211
# legacy path for DLManagedTensor
1208
1212
# copy kwarg ignored because copy flag can't be set
1209
1213
_caps = c_dlpack.to_dlpack_capsule(self )
1210
- if (stream is None or stream == self .sycl_queue):
1211
- pass
1212
- else :
1213
- if not isinstance (stream, dpctl.SyclQueue):
1214
- raise TypeError (
1215
- " stream keyword argument type is expected to "
1216
- " be dpctl.SyclQueue, "
1217
- f" got {type(stream)} instead"
1218
- )
1219
- ev = self .sycl_queue.submit_barrier()
1220
- stream.submit_barrier(dependent_events = [ev])
1214
+ _validate_and_use_stream(stream, self .sycl_queue)
1221
1215
return _caps
1222
1216
else :
1223
1217
if not isinstance (max_version, tuple ) or len (max_version) != 2 :
@@ -1259,12 +1253,7 @@ cdef class usm_ndarray:
1259
1253
copy = False
1260
1254
# TODO: strategy for handling stream on different device from dl_device
1261
1255
if copy:
1262
- if (stream is None or type (stream) is not dpctl.SyclQueue or
1263
- stream == self .sycl_queue):
1264
- pass
1265
- else :
1266
- ev = self .sycl_queue.submit_barrier()
1267
- stream.submit_barrier(dependent_events = [ev])
1256
+ _validate_and_use_stream(stream, self .sycl_queue)
1268
1257
nbytes = self .usm_data.nbytes
1269
1258
copy_buffer = type (self .usm_data)(
1270
1259
nbytes, queue = self .sycl_queue
@@ -1281,22 +1270,12 @@ cdef class usm_ndarray:
1281
1270
_caps = c_dlpack.to_dlpack_versioned_capsule(_copied_arr, copy)
1282
1271
else :
1283
1272
_caps = c_dlpack.to_dlpack_versioned_capsule(self , copy)
1284
- if (stream is None or type (stream) is not dpctl.SyclQueue or
1285
- stream == self .sycl_queue):
1286
- pass
1287
- else :
1288
- ev = self .sycl_queue.submit_barrier()
1289
- stream.submit_barrier(dependent_events = [ev])
1273
+ _validate_and_use_stream(stream, self .sycl_queue)
1290
1274
return _caps
1291
1275
else :
1292
1276
# legacy path for DLManagedTensor
1293
1277
_caps = c_dlpack.to_dlpack_capsule(self )
1294
- if (stream is None or type (stream) is not dpctl.SyclQueue or
1295
- stream == self .sycl_queue):
1296
- pass
1297
- else :
1298
- ev = self .sycl_queue.submit_barrier()
1299
- stream.submit_barrier(dependent_events = [ev])
1278
+ _validate_and_use_stream(stream, self .sycl_queue)
1300
1279
return _caps
1301
1280
1302
1281
def __dlpack_device__ (self ):
0 commit comments