Skip to content

Commit bacfa59

Browse files
committed
Fixes bugs with writable flag setting
`writable` flag was not being set correctly for indexing, real views, imaginary views, tranposes, and where shape is set directly Also fixes cases where flag could be overridden by functions with `out` kwarg
1 parent 89266a7 commit bacfa59

File tree

4 files changed

+39
-11
lines changed

4 files changed

+39
-11
lines changed

dpctl/tensor/_clip.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,11 @@ def _clip_none(x, val, out, order, _binary_fn):
262262
f"output array must be of usm_ndarray type, got {type(out)}"
263263
)
264264

265+
if not out.flags.writable:
266+
raise ValueError(
267+
"provided `out` array is read-only"
268+
)
269+
265270
if out.shape != res_shape:
266271
raise ValueError(
267272
"The shape of input and output arrays are inconsistent. "
@@ -600,6 +605,11 @@ def clip(x, /, min=None, max=None, out=None, order="K"):
600605
f"{type(out)}"
601606
)
602607

608+
if not out.flags.writable:
609+
raise ValueError(
610+
"provided `out` array is read-only"
611+
)
612+
603613
if out.shape != res_shape:
604614
raise ValueError(
605615
"The shape of input and output arrays are "

dpctl/tensor/_elementwise_common.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,11 @@ def __call__(self, x, out=None, order="K"):
202202
f"output array must be of usm_ndarray type, got {type(out)}"
203203
)
204204

205+
if not out.flags.writable:
206+
raise ValueError(
207+
"provided `out` array is read-only"
208+
)
209+
205210
if out.shape != x.shape:
206211
raise ValueError(
207212
"The shape of input and output arrays are inconsistent. "
@@ -601,6 +606,11 @@ def __call__(self, o1, o2, out=None, order="K"):
601606
f"output array must be of usm_ndarray type, got {type(out)}"
602607
)
603608

609+
if not out.flags.writable:
610+
raise ValueError(
611+
"provided `out` array is read-only"
612+
)
613+
604614
if out.shape != res_shape:
605615
raise ValueError(
606616
"The shape of input and output arrays are inconsistent. "

dpctl/tensor/_linear_algebra_functions.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -738,6 +738,11 @@ def matmul(x1, x2, out=None, dtype=None, order="K"):
738738
f"output array must be of usm_ndarray type, got {type(out)}"
739739
)
740740

741+
if not out.flags.writable:
742+
raise ValueError(
743+
"provided `out` array is read-only"
744+
)
745+
741746
if out.shape != res_shape:
742747
raise ValueError(
743748
"The shape of input and output arrays are inconsistent. "

dpctl/tensor/_usmarray.pyx

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -546,7 +546,7 @@ cdef class usm_ndarray:
546546
PyMem_Free(self.shape_)
547547
if (self.strides_):
548548
PyMem_Free(self.strides_)
549-
self.flags_ = contig_flag
549+
self.flags_ = (contig_flag | (self.flags_ & USM_ARRAY_WRITABLE))
550550
self.nd_ = new_nd
551551
self.shape_ = shape_ptr
552552
self.strides_ = strides_ptr
@@ -725,13 +725,13 @@ cdef class usm_ndarray:
725725
buffer=self.base_,
726726
offset=_meta[2]
727727
)
728-
res.flags_ |= (self.flags_ & USM_ARRAY_WRITABLE)
729728
res.array_namespace_ = self.array_namespace_
730729

731730
adv_ind = _meta[3]
732731
adv_ind_start_p = _meta[4]
733732

734733
if adv_ind_start_p < 0:
734+
res.flags_ ^= (~self.flags_ & USM_ARRAY_WRITABLE)
735735
return res
736736

737737
from ._copy_utils import _extract_impl, _nonzero_impl, _take_multi_index
@@ -749,6 +749,7 @@ cdef class usm_ndarray:
749749
if not matching:
750750
raise IndexError("boolean index did not match indexed array in dimensions")
751751
res = _extract_impl(res, key_, axis=adv_ind_start_p)
752+
res.flags_ ^= (~self.flags_ & USM_ARRAY_WRITABLE)
752753
return res
753754

754755
if any(ind.dtype == dpt_bool for ind in adv_ind):
@@ -758,10 +759,13 @@ cdef class usm_ndarray:
758759
adv_ind_int.extend(_nonzero_impl(ind))
759760
else:
760761
adv_ind_int.append(ind)
761-
return _take_multi_index(res, tuple(adv_ind_int), adv_ind_start_p)
762-
763-
return _take_multi_index(res, adv_ind, adv_ind_start_p)
762+
res = _take_multi_index(res, tuple(adv_ind_int), adv_ind_start_p)
763+
res.flags_ ^= (~self.flags_ & USM_ARRAY_WRITABLE)
764+
return res
764765

766+
res = _take_multi_index(res, adv_ind, adv_ind_start_p)
767+
res.flags_ ^= (~self.flags_ & USM_ARRAY_WRITABLE)
768+
return res
765769

766770
def to_device(self, target, stream=None):
767771
""" to_device(target_device)
@@ -1040,8 +1044,7 @@ cdef class usm_ndarray:
10401044
buffer=self.base_,
10411045
offset=_meta[2],
10421046
)
1043-
# set flags and namespace
1044-
Xv.flags_ |= (self.flags_ & USM_ARRAY_WRITABLE)
1047+
# set namespace
10451048
Xv.array_namespace_ = self.array_namespace_
10461049

10471050
from ._copy_utils import (
@@ -1225,7 +1228,7 @@ cdef usm_ndarray _real_view(usm_ndarray ary):
12251228
offset=offset_elems,
12261229
order=('C' if (ary.flags_ & USM_ARRAY_C_CONTIGUOUS) else 'F')
12271230
)
1228-
r.flags_ |= (ary.flags_ & USM_ARRAY_WRITABLE)
1231+
r.flags_ ^= (~ary.flags_ & USM_ARRAY_WRITABLE)
12291232
r.array_namespace_ = ary.array_namespace_
12301233
return r
12311234

@@ -1257,7 +1260,7 @@ cdef usm_ndarray _imag_view(usm_ndarray ary):
12571260
offset=offset_elems,
12581261
order=('C' if (ary.flags_ & USM_ARRAY_C_CONTIGUOUS) else 'F')
12591262
)
1260-
r.flags_ |= (ary.flags_ & USM_ARRAY_WRITABLE)
1263+
r.flags_ ^= (~ary.flags_ & USM_ARRAY_WRITABLE)
12611264
r.array_namespace_ = ary.array_namespace_
12621265
return r
12631266

@@ -1277,7 +1280,7 @@ cdef usm_ndarray _transpose(usm_ndarray ary):
12771280
order=('F' if (ary.flags_ & USM_ARRAY_C_CONTIGUOUS) else 'C'),
12781281
offset=ary.get_offset()
12791282
)
1280-
r.flags_ |= (ary.flags_ & USM_ARRAY_WRITABLE)
1283+
r.flags_ ^= (~ary.flags_ & USM_ARRAY_WRITABLE)
12811284
return r
12821285

12831286

@@ -1294,7 +1297,7 @@ cdef usm_ndarray _m_transpose(usm_ndarray ary):
12941297
order=('F' if (ary.flags_ & USM_ARRAY_C_CONTIGUOUS) else 'C'),
12951298
offset=ary.get_offset()
12961299
)
1297-
r.flags_ |= (ary.flags_ & USM_ARRAY_WRITABLE)
1300+
r.flags_ ^= (~ary.flags_ & USM_ARRAY_WRITABLE)
12981301
return r
12991302

13001303

0 commit comments

Comments
 (0)