@@ -60,6 +60,9 @@ cdef object _as_zero_dim_ndarray(object usm_ary):
60
60
view.shape = tuple ()
61
61
return view
62
62
63
+ cdef int _copy_writable(int lhs_flags, int rhs_flags):
64
+ " Copy the WRITABLE flag to lhs_flags from rhs_flags"
65
+ return (lhs_flag & ~ USM_ARRAY_WRITABLE) | (rhs_flag & USM_ARRAY_WRITABLE)
63
66
64
67
cdef class usm_ndarray:
65
68
""" usm_ndarray(shape, dtype=None, strides=None, buffer="device", \
@@ -731,7 +734,7 @@ cdef class usm_ndarray:
731
734
adv_ind_start_p = _meta[4 ]
732
735
733
736
if adv_ind_start_p < 0 :
734
- res.flags_ = (res.flags_ & ~ USM_ARRAY_WRITABLE) | ( self .flags_ & USM_ARRAY_WRITABLE )
737
+ res.flags_ = _copy_writable (res.flags_, self .flags_)
735
738
return res
736
739
737
740
from ._copy_utils import _extract_impl, _nonzero_impl, _take_multi_index
@@ -749,7 +752,7 @@ cdef class usm_ndarray:
749
752
if not matching:
750
753
raise IndexError (" boolean index did not match indexed array in dimensions" )
751
754
res = _extract_impl(res, key_, axis = adv_ind_start_p)
752
- res.flags_ = (res.flags_ & ~ USM_ARRAY_WRITABLE) | ( self .flags_ & USM_ARRAY_WRITABLE )
755
+ res.flags_ = _copy_writable (res.flags_, self .flags_)
753
756
return res
754
757
755
758
if any (ind.dtype == dpt_bool for ind in adv_ind):
@@ -760,11 +763,11 @@ cdef class usm_ndarray:
760
763
else :
761
764
adv_ind_int.append(ind)
762
765
res = _take_multi_index(res, tuple (adv_ind_int), adv_ind_start_p)
763
- res.flags_ = (res.flags_ & ~ USM_ARRAY_WRITABLE) | ( self .flags_ & USM_ARRAY_WRITABLE )
766
+ res.flags_ = _copy_writable (res.flags_, self .flags_)
764
767
return res
765
768
766
769
res = _take_multi_index(res, adv_ind, adv_ind_start_p)
767
- res.flags_ = (res.flags_ & ~ USM_ARRAY_WRITABLE) | ( self .flags_ & USM_ARRAY_WRITABLE )
770
+ res.flags_ = _copy_writable (res.flags_, self .flags_)
768
771
return res
769
772
770
773
def to_device (self , target , stream = None ):
@@ -1228,7 +1231,7 @@ cdef usm_ndarray _real_view(usm_ndarray ary):
1228
1231
offset = offset_elems,
1229
1232
order = (' C' if (ary.flags_ & USM_ARRAY_C_CONTIGUOUS) else ' F' )
1230
1233
)
1231
- r.flags_ = (r.flags_ & ~ USM_ARRAY_WRITABLE) | ( ary.flags_ & USM_ARRAY_WRITABLE )
1234
+ r.flags_ = _copy_writable (r.flags_, ary.flags_)
1232
1235
r.array_namespace_ = ary.array_namespace_
1233
1236
return r
1234
1237
@@ -1260,7 +1263,7 @@ cdef usm_ndarray _imag_view(usm_ndarray ary):
1260
1263
offset = offset_elems,
1261
1264
order = (' C' if (ary.flags_ & USM_ARRAY_C_CONTIGUOUS) else ' F' )
1262
1265
)
1263
- r.flags_ = (r.flags_ & ~ USM_ARRAY_WRITABLE) | ( ary.flags_ & USM_ARRAY_WRITABLE )
1266
+ r.flags_ = _copy_writable (r.flags_, ary.flags_)
1264
1267
r.array_namespace_ = ary.array_namespace_
1265
1268
return r
1266
1269
@@ -1280,7 +1283,7 @@ cdef usm_ndarray _transpose(usm_ndarray ary):
1280
1283
order = (' F' if (ary.flags_ & USM_ARRAY_C_CONTIGUOUS) else ' C' ),
1281
1284
offset = ary.get_offset()
1282
1285
)
1283
- r.flags_ = (r.flags_ & ~ USM_ARRAY_WRITABLE) | ( ary.flags_ & USM_ARRAY_WRITABLE )
1286
+ r.flags_ = _copy_writable (r.flags_, ary.flags_)
1284
1287
return r
1285
1288
1286
1289
@@ -1297,7 +1300,7 @@ cdef usm_ndarray _m_transpose(usm_ndarray ary):
1297
1300
order = (' F' if (ary.flags_ & USM_ARRAY_C_CONTIGUOUS) else ' C' ),
1298
1301
offset = ary.get_offset()
1299
1302
)
1300
- r.flags_ = (r.flags_ & ~ USM_ARRAY_WRITABLE) | ( ary.flags_ & USM_ARRAY_WRITABLE )
1303
+ r.flags_ = _copy_writable (r.flags_, ary.flags_)
1301
1304
return r
1302
1305
1303
1306
0 commit comments