Skip to content

Commit 71408b0

Browse files
committed
Removes assumption that new array is writable
Now flags are set based on input regardless of whether a new array is writable per review feedback
1 parent 87cd798 commit 71408b0

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

dpctl/tensor/_usmarray.pyx

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -731,7 +731,7 @@ cdef class usm_ndarray:
731731
adv_ind_start_p = _meta[4]
732732

733733
if adv_ind_start_p < 0:
734-
res.flags_ ^= (~self.flags_ & USM_ARRAY_WRITABLE)
734+
res.flags_ = (res.flags_ & ~USM_ARRAY_WRITABLE) | (self.flags_ & USM_ARRAY_WRITABLE)
735735
return res
736736

737737
from ._copy_utils import _extract_impl, _nonzero_impl, _take_multi_index
@@ -749,7 +749,7 @@ cdef class usm_ndarray:
749749
if not matching:
750750
raise IndexError("boolean index did not match indexed array in dimensions")
751751
res = _extract_impl(res, key_, axis=adv_ind_start_p)
752-
res.flags_ ^= (~self.flags_ & USM_ARRAY_WRITABLE)
752+
res.flags_ = (res.flags_ & ~USM_ARRAY_WRITABLE) | (self.flags_ & USM_ARRAY_WRITABLE)
753753
return res
754754

755755
if any(ind.dtype == dpt_bool for ind in adv_ind):
@@ -760,11 +760,11 @@ cdef class usm_ndarray:
760760
else:
761761
adv_ind_int.append(ind)
762762
res = _take_multi_index(res, tuple(adv_ind_int), adv_ind_start_p)
763-
res.flags_ ^= (~self.flags_ & USM_ARRAY_WRITABLE)
763+
res.flags_ = (res.flags_ & ~USM_ARRAY_WRITABLE) | (self.flags_ & USM_ARRAY_WRITABLE)
764764
return res
765765

766766
res = _take_multi_index(res, adv_ind, adv_ind_start_p)
767-
res.flags_ ^= (~self.flags_ & USM_ARRAY_WRITABLE)
767+
res.flags_ = (res.flags_ & ~USM_ARRAY_WRITABLE) | (self.flags_ & USM_ARRAY_WRITABLE)
768768
return res
769769

770770
def to_device(self, target, stream=None):
@@ -1228,7 +1228,7 @@ cdef usm_ndarray _real_view(usm_ndarray ary):
12281228
offset=offset_elems,
12291229
order=('C' if (ary.flags_ & USM_ARRAY_C_CONTIGUOUS) else 'F')
12301230
)
1231-
r.flags_ ^= (~ary.flags_ & USM_ARRAY_WRITABLE)
1231+
r.flags_ = (r.flags_ & ~USM_ARRAY_WRITABLE) | (ary.flags_ & USM_ARRAY_WRITABLE)
12321232
r.array_namespace_ = ary.array_namespace_
12331233
return r
12341234

@@ -1260,7 +1260,7 @@ cdef usm_ndarray _imag_view(usm_ndarray ary):
12601260
offset=offset_elems,
12611261
order=('C' if (ary.flags_ & USM_ARRAY_C_CONTIGUOUS) else 'F')
12621262
)
1263-
r.flags_ ^= (~ary.flags_ & USM_ARRAY_WRITABLE)
1263+
r.flags_ = (r.flags_ & ~USM_ARRAY_WRITABLE) | (ary.flags_ & USM_ARRAY_WRITABLE)
12641264
r.array_namespace_ = ary.array_namespace_
12651265
return r
12661266

@@ -1280,7 +1280,7 @@ cdef usm_ndarray _transpose(usm_ndarray ary):
12801280
order=('F' if (ary.flags_ & USM_ARRAY_C_CONTIGUOUS) else 'C'),
12811281
offset=ary.get_offset()
12821282
)
1283-
r.flags_ ^= (~ary.flags_ & USM_ARRAY_WRITABLE)
1283+
r.flags_ = (r.flags_ & ~USM_ARRAY_WRITABLE) | (ary.flags_ & USM_ARRAY_WRITABLE)
12841284
return r
12851285

12861286

@@ -1297,7 +1297,7 @@ cdef usm_ndarray _m_transpose(usm_ndarray ary):
12971297
order=('F' if (ary.flags_ & USM_ARRAY_C_CONTIGUOUS) else 'C'),
12981298
offset=ary.get_offset()
12991299
)
1300-
r.flags_ ^= (~ary.flags_ & USM_ARRAY_WRITABLE)
1300+
r.flags_ = (r.flags_ & ~USM_ARRAY_WRITABLE) | (ary.flags_ & USM_ARRAY_WRITABLE)
13011301
return r
13021302

13031303

0 commit comments

Comments
 (0)