Skip to content

Commit bdb2f75

Browse files
Replaced use of synchronizing __sycl_usm_array_interface__ atribute
Instead of relying on SUAI attribute which has to synchronize to get the offset, use `X._element_offset` attribute directly.
1 parent 3f0f935 commit bdb2f75

File tree

2 files changed

+8
-9
lines changed

2 files changed

+8
-9
lines changed

dpctl/tensor/_copy_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def _copy_to_numpy(ary):
4545
h = np.ndarray(nb, dtype="u1", buffer=hh).view(ary.dtype)
4646
itsz = ary.itemsize
4747
strides_bytes = tuple(si * itsz for si in ary.strides)
48-
offset = ary.__sycl_usm_array_interface__.get("offset", 0) * itsz
48+
offset = ary._element_offset * itsz
4949
# ensure that content of ary.usm_data is final
5050
q.wait()
5151
hh.copy_from_device(ary.usm_data)
@@ -645,7 +645,7 @@ def astype(
645645
target_dtype, d.has_aspect_fp16, d.has_aspect_fp64
646646
):
647647
raise ValueError(
648-
f"Requested dtype `{target_dtype}` is not supported by the "
648+
f"Requested dtype '{target_dtype}' is not supported by the "
649649
"target device"
650650
)
651651
usm_ary = usm_ary.to_device(device)

dpctl/tensor/_manipulation_functions.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ def permute_dims(X, /, axes):
116116
dtype=X.dtype,
117117
buffer=X,
118118
strides=newstrides,
119-
offset=X.__sycl_usm_array_interface__.get("offset", 0),
119+
offset=X._element_offset,
120120
)
121121

122122

@@ -244,7 +244,7 @@ def broadcast_to(X, /, shape):
244244
dtype=X.dtype,
245245
buffer=X,
246246
strides=new_sts,
247-
offset=X.__sycl_usm_array_interface__.get("offset", 0),
247+
offset=X._element_offset,
248248
)
249249

250250

@@ -817,8 +817,8 @@ def repeat(x, repeats, /, *, axis=None):
817817
dpctl.utils.validate_usm_type(usm_type, allow_none=False)
818818
if not dpt.can_cast(repeats.dtype, dpt.int64, casting="same_kind"):
819819
raise TypeError(
820-
f"`repeats` data type `{repeats.dtype}` cannot be cast to "
821-
"`int64` according to the casting rule ''safe.''"
820+
f"'repeats' data type {repeats.dtype} cannot be cast to "
821+
"'int64' according to the casting rule ''safe.''"
822822
)
823823
if repeats.size == 1:
824824
scalar = True
@@ -829,11 +829,11 @@ def repeat(x, repeats, /, *, axis=None):
829829
else:
830830
if repeats.size != axis_size:
831831
raise ValueError(
832-
"`repeats` array must be broadcastable to the size of "
832+
"'repeats' array must be broadcastable to the size of "
833833
"the repeated axis"
834834
)
835835
if not dpt.all(repeats >= 0):
836-
raise ValueError("`repeats` elements must be positive")
836+
raise ValueError("'repeats' elements must be positive")
837837

838838
elif isinstance(repeats, (tuple, list, range)):
839839
usm_type = x.usm_type
@@ -862,7 +862,6 @@ def repeat(x, repeats, /, *, axis=None):
862862
f"got {type(repeats)}"
863863
)
864864

865-
866865
_manager = dputils.SequentialOrderManager[exec_q]
867866
dep_evs = _manager.submitted_events
868867
if scalar:

0 commit comments

Comments
 (0)