@@ -546,7 +546,7 @@ cdef class usm_ndarray:
546
546
PyMem_Free(self .shape_)
547
547
if (self .strides_):
548
548
PyMem_Free(self .strides_)
549
- self .flags_ = contig_flag
549
+ self .flags_ = ( contig_flag | ( self .flags_ & USM_ARRAY_WRITABLE))
550
550
self .nd_ = new_nd
551
551
self .shape_ = shape_ptr
552
552
self .strides_ = strides_ptr
@@ -725,13 +725,13 @@ cdef class usm_ndarray:
725
725
buffer = self .base_,
726
726
offset = _meta[2 ]
727
727
)
728
- res.flags_ |= (self .flags_ & USM_ARRAY_WRITABLE)
729
728
res.array_namespace_ = self .array_namespace_
730
729
731
730
adv_ind = _meta[3 ]
732
731
adv_ind_start_p = _meta[4 ]
733
732
734
733
if adv_ind_start_p < 0 :
734
+ res.flags_ ^= (~ self .flags_ & USM_ARRAY_WRITABLE)
735
735
return res
736
736
737
737
from ._copy_utils import _extract_impl, _nonzero_impl, _take_multi_index
@@ -749,6 +749,7 @@ cdef class usm_ndarray:
749
749
if not matching:
750
750
raise IndexError (" boolean index did not match indexed array in dimensions" )
751
751
res = _extract_impl(res, key_, axis = adv_ind_start_p)
752
+ res.flags_ ^= (~ self .flags_ & USM_ARRAY_WRITABLE)
752
753
return res
753
754
754
755
if any (ind.dtype == dpt_bool for ind in adv_ind):
@@ -758,10 +759,13 @@ cdef class usm_ndarray:
758
759
adv_ind_int.extend(_nonzero_impl(ind))
759
760
else :
760
761
adv_ind_int.append(ind)
761
- return _take_multi_index(res, tuple (adv_ind_int), adv_ind_start_p)
762
-
763
- return _take_multi_index(res, adv_ind, adv_ind_start_p)
762
+ res = _take_multi_index(res, tuple (adv_ind_int), adv_ind_start_p)
763
+ res.flags_ ^= ( ~ self .flags_ & USM_ARRAY_WRITABLE)
764
+ return res
764
765
766
+ res = _take_multi_index(res, adv_ind, adv_ind_start_p)
767
+ res.flags_ ^= (~ self .flags_ & USM_ARRAY_WRITABLE)
768
+ return res
765
769
766
770
def to_device (self , target , stream = None ):
767
771
""" to_device(target_device)
@@ -1040,8 +1044,7 @@ cdef class usm_ndarray:
1040
1044
buffer = self .base_,
1041
1045
offset = _meta[2 ],
1042
1046
)
1043
- # set flags and namespace
1044
- Xv.flags_ |= (self .flags_ & USM_ARRAY_WRITABLE)
1047
+ # set namespace
1045
1048
Xv.array_namespace_ = self .array_namespace_
1046
1049
1047
1050
from ._copy_utils import (
@@ -1225,7 +1228,7 @@ cdef usm_ndarray _real_view(usm_ndarray ary):
1225
1228
offset = offset_elems,
1226
1229
order = (' C' if (ary.flags_ & USM_ARRAY_C_CONTIGUOUS) else ' F' )
1227
1230
)
1228
- r.flags_ | = (ary.flags_ & USM_ARRAY_WRITABLE)
1231
+ r.flags_ ^ = (~ ary.flags_ & USM_ARRAY_WRITABLE)
1229
1232
r.array_namespace_ = ary.array_namespace_
1230
1233
return r
1231
1234
@@ -1257,7 +1260,7 @@ cdef usm_ndarray _imag_view(usm_ndarray ary):
1257
1260
offset = offset_elems,
1258
1261
order = (' C' if (ary.flags_ & USM_ARRAY_C_CONTIGUOUS) else ' F' )
1259
1262
)
1260
- r.flags_ | = (ary.flags_ & USM_ARRAY_WRITABLE)
1263
+ r.flags_ ^ = (~ ary.flags_ & USM_ARRAY_WRITABLE)
1261
1264
r.array_namespace_ = ary.array_namespace_
1262
1265
return r
1263
1266
@@ -1277,7 +1280,7 @@ cdef usm_ndarray _transpose(usm_ndarray ary):
1277
1280
order = (' F' if (ary.flags_ & USM_ARRAY_C_CONTIGUOUS) else ' C' ),
1278
1281
offset = ary.get_offset()
1279
1282
)
1280
- r.flags_ | = (ary.flags_ & USM_ARRAY_WRITABLE)
1283
+ r.flags_ ^ = (~ ary.flags_ & USM_ARRAY_WRITABLE)
1281
1284
return r
1282
1285
1283
1286
@@ -1294,7 +1297,7 @@ cdef usm_ndarray _m_transpose(usm_ndarray ary):
1294
1297
order = (' F' if (ary.flags_ & USM_ARRAY_C_CONTIGUOUS) else ' C' ),
1295
1298
offset = ary.get_offset()
1296
1299
)
1297
- r.flags_ | = (ary.flags_ & USM_ARRAY_WRITABLE)
1300
+ r.flags_ ^ = (~ ary.flags_ & USM_ARRAY_WRITABLE)
1298
1301
return r
1299
1302
1300
1303
0 commit comments