Skip to content

Commit da413a6

Browse files
committed
Adds _copy_writable for copying the writable flag between arrays
1 parent 71408b0 commit da413a6

File tree

1 file changed

+11
-8
lines changed

1 file changed

+11
-8
lines changed

dpctl/tensor/_usmarray.pyx

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,9 @@ cdef object _as_zero_dim_ndarray(object usm_ary):
6060
view.shape = tuple()
6161
return view
6262

63+
cdef int _copy_writable(int lhs_flags, int rhs_flags):
64+
"Copy the WRITABLE flag to lhs_flags from rhs_flags"
65+
return (lhs_flag & ~USM_ARRAY_WRITABLE) | (rhs_flag & USM_ARRAY_WRITABLE)
6366

6467
cdef class usm_ndarray:
6568
""" usm_ndarray(shape, dtype=None, strides=None, buffer="device", \
@@ -731,7 +734,7 @@ cdef class usm_ndarray:
731734
adv_ind_start_p = _meta[4]
732735

733736
if adv_ind_start_p < 0:
734-
res.flags_ = (res.flags_ & ~USM_ARRAY_WRITABLE) | (self.flags_ & USM_ARRAY_WRITABLE)
737+
res.flags_ = _copy_writable(res.flags_, self.flags_)
735738
return res
736739

737740
from ._copy_utils import _extract_impl, _nonzero_impl, _take_multi_index
@@ -749,7 +752,7 @@ cdef class usm_ndarray:
749752
if not matching:
750753
raise IndexError("boolean index did not match indexed array in dimensions")
751754
res = _extract_impl(res, key_, axis=adv_ind_start_p)
752-
res.flags_ = (res.flags_ & ~USM_ARRAY_WRITABLE) | (self.flags_ & USM_ARRAY_WRITABLE)
755+
res.flags_ = _copy_writable(res.flags_, self.flags_)
753756
return res
754757

755758
if any(ind.dtype == dpt_bool for ind in adv_ind):
@@ -760,11 +763,11 @@ cdef class usm_ndarray:
760763
else:
761764
adv_ind_int.append(ind)
762765
res = _take_multi_index(res, tuple(adv_ind_int), adv_ind_start_p)
763-
res.flags_ = (res.flags_ & ~USM_ARRAY_WRITABLE) | (self.flags_ & USM_ARRAY_WRITABLE)
766+
res.flags_ = _copy_writable(res.flags_, self.flags_)
764767
return res
765768

766769
res = _take_multi_index(res, adv_ind, adv_ind_start_p)
767-
res.flags_ = (res.flags_ & ~USM_ARRAY_WRITABLE) | (self.flags_ & USM_ARRAY_WRITABLE)
770+
res.flags_ = _copy_writable(res.flags_, self.flags_)
768771
return res
769772

770773
def to_device(self, target, stream=None):
@@ -1228,7 +1231,7 @@ cdef usm_ndarray _real_view(usm_ndarray ary):
12281231
offset=offset_elems,
12291232
order=('C' if (ary.flags_ & USM_ARRAY_C_CONTIGUOUS) else 'F')
12301233
)
1231-
r.flags_ = (r.flags_ & ~USM_ARRAY_WRITABLE) | (ary.flags_ & USM_ARRAY_WRITABLE)
1234+
r.flags_ = _copy_writable(r.flags_, ary.flags_)
12321235
r.array_namespace_ = ary.array_namespace_
12331236
return r
12341237

@@ -1260,7 +1263,7 @@ cdef usm_ndarray _imag_view(usm_ndarray ary):
12601263
offset=offset_elems,
12611264
order=('C' if (ary.flags_ & USM_ARRAY_C_CONTIGUOUS) else 'F')
12621265
)
1263-
r.flags_ = (r.flags_ & ~USM_ARRAY_WRITABLE) | (ary.flags_ & USM_ARRAY_WRITABLE)
1266+
r.flags_ = _copy_writable(r.flags_, ary.flags_)
12641267
r.array_namespace_ = ary.array_namespace_
12651268
return r
12661269

@@ -1280,7 +1283,7 @@ cdef usm_ndarray _transpose(usm_ndarray ary):
12801283
order=('F' if (ary.flags_ & USM_ARRAY_C_CONTIGUOUS) else 'C'),
12811284
offset=ary.get_offset()
12821285
)
1283-
r.flags_ = (r.flags_ & ~USM_ARRAY_WRITABLE) | (ary.flags_ & USM_ARRAY_WRITABLE)
1286+
r.flags_ = _copy_writable(r.flags_, ary.flags_)
12841287
return r
12851288

12861289

@@ -1297,7 +1300,7 @@ cdef usm_ndarray _m_transpose(usm_ndarray ary):
12971300
order=('F' if (ary.flags_ & USM_ARRAY_C_CONTIGUOUS) else 'C'),
12981301
offset=ary.get_offset()
12991302
)
1300-
r.flags_ = (r.flags_ & ~USM_ARRAY_WRITABLE) | (ary.flags_ & USM_ARRAY_WRITABLE)
1303+
r.flags_ = _copy_writable(r.flags_, ary.flags_)
13011304
return r
13021305

13031306

0 commit comments

Comments
 (0)