Skip to content

Commit 00b36ee

Browse files
committed
Adds writable flag checks throughout Python bindings
These checks are implemented through a new CheckWritable struct, which has a static method throw_if_not_writable, which throws if an array is read-only
1 parent 275138b commit 00b36ee

21 files changed

+142
-22
lines changed

dpctl/apis/include/dpctl4pybind11.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1072,7 +1072,7 @@ bool queues_are_compatible(const sycl::queue &exec_q,
10721072
return true;
10731073
}
10741074

1075-
/*! @brief Check if all allocation queues of usm_ndarays are the same as
1075+
/*! @brief Check if all allocation queues of usm_ndarays are the same as
10761076
the execution queue */
10771077
template <std::size_t num>
10781078
bool queues_are_compatible(const sycl::queue &exec_q,
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
//===- output_validation.hpp - Utilities for output array validation
2+
//-*-C++-*===//
3+
//
4+
// Data Parallel Control (dpctl)
5+
//
6+
// Copyright 2020-2022 Intel Corporation
7+
//
8+
// Licensed under the Apache License, Version 2.0 (the "License");
9+
// you may not use this file except in compliance with the License.
10+
// You may obtain a copy of the License at
11+
//
12+
// http://www.apache.org/licenses/LICENSE-2.0
13+
//
14+
// Unless required by applicable law or agreed to in writing, software
15+
// distributed under the License is distributed on an "AS IS" BASIS,
16+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17+
// See the License for the specific language governing permissions and
18+
// limitations under the License.
19+
//
20+
//===----------------------------------------------------------------------===//
21+
///
22+
/// \file
23+
/// This file defines utilities for determining if an array is a valid output
24+
/// array.
25+
//===----------------------------------------------------------------------===//
26+
27+
#pragma once
28+
#include "dpctl4pybind11.hpp"
29+
#include <pybind11/pybind11.h>
30+
31+
namespace dpctl
32+
{
33+
34+
namespace tensor
35+
{
36+
37+
namespace validation
38+
{
39+
40+
/*! @brief Raises a value error if a function would attempt to write
41+
to an array which is read-only.
42+
43+
This should always be called on an array before it will be written to.*/
44+
struct CheckWritable
45+
{
46+
static void throw_if_not_writable(const dpctl::tensor::usm_ndarray &arr)
47+
{
48+
if (!arr.is_writable()) {
49+
throw py::value_error("output array is read-only.");
50+
}
51+
return;
52+
}
53+
};
54+
55+
} // namespace validation
56+
} // namespace tensor
57+
} // namespace dpctl

dpctl/tensor/libtensor/source/accumulators.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
#include "simplify_iteration_space.hpp"
3636
#include "utils/memory_overlap.hpp"
3737
#include "utils/offset_utils.hpp"
38+
#include "utils/output_validation.hpp"
3839
#include "utils/type_dispatch.hpp"
3940

4041
namespace dpctl
@@ -102,6 +103,8 @@ size_t py_mask_positions(const dpctl::tensor::usm_ndarray &mask,
102103
sycl::queue &exec_q,
103104
const std::vector<sycl::event> &depends)
104105
{
106+
dpctl::tensor::validation::CheckWritable::throw_if_not_writable(cumsum);
107+
105108
// cumsum is 1D
106109
if (cumsum.get_ndim() != 1) {
107110
throw py::value_error("Result array must be one-dimensional.");
@@ -274,6 +277,8 @@ size_t py_cumsum_1d(const dpctl::tensor::usm_ndarray &src,
274277
"Execution queue is not compatible with allocation queues");
275278
}
276279

280+
dpctl::tensor::validation::CheckWritable::throw_if_not_writable(cumsum);
281+
277282
if (src_size == 0) {
278283
return 0;
279284
}

dpctl/tensor/libtensor/source/boolean_advanced_indexing.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
#include "simplify_iteration_space.hpp"
3838
#include "utils/memory_overlap.hpp"
3939
#include "utils/offset_utils.hpp"
40+
#include "utils/output_validation.hpp"
4041
#include "utils/type_dispatch.hpp"
4142

4243
namespace dpctl
@@ -118,6 +119,8 @@ py_extract(const dpctl::tensor::usm_ndarray &src,
118119
sycl::queue &exec_q,
119120
const std::vector<sycl::event> &depends)
120121
{
122+
dpctl::tensor::validation::CheckWritable::throw_if_not_writable(dst);
123+
121124
int src_nd = src.get_ndim();
122125
if ((axis_start < 0 || axis_end > src_nd || axis_start >= axis_end)) {
123126
throw py::value_error("Specified axes_start and axes_end are invalid.");
@@ -452,6 +455,8 @@ py_place(const dpctl::tensor::usm_ndarray &dst,
452455
sycl::queue &exec_q,
453456
const std::vector<sycl::event> &depends)
454457
{
458+
dpctl::tensor::validation::CheckWritable::throw_if_not_writable(dst);
459+
455460
int dst_nd = dst.get_ndim();
456461
if ((axis_start < 0 || axis_end > dst_nd || axis_start >= axis_end)) {
457462
throw py::value_error("Specified axes_start and axes_end are invalid.");
@@ -726,6 +731,8 @@ py_nonzero(const dpctl::tensor::usm_ndarray
726731
"Execution queue is not compatible with allocation queues");
727732
}
728733

734+
dpctl::tensor::validation::CheckWritable::throw_if_not_writable(indexes);
735+
729736
int cumsum_nd = cumsum.get_ndim();
730737
if (cumsum_nd != 1 || !cumsum.is_c_contiguous()) {
731738
throw py::value_error("Cumsum array must be a C-contiguous vector");

dpctl/tensor/libtensor/source/clip.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
#include "simplify_iteration_space.hpp"
3838
#include "utils/memory_overlap.hpp"
3939
#include "utils/offset_utils.hpp"
40+
#include "utils/output_validation.hpp"
4041
#include "utils/type_dispatch.hpp"
4142

4243
namespace dpctl
@@ -87,6 +88,8 @@ py_clip(const dpctl::tensor::usm_ndarray &src,
8788
"Execution queue is not compatible with allocation queues");
8889
}
8990

91+
dpctl::tensor::validation::CheckWritable::throw_if_not_writable(dst);
92+
9093
int nd = src.get_ndim();
9194
int min_nd = min.get_ndim();
9295
int max_nd = max.get_ndim();

dpctl/tensor/libtensor/source/copy_and_cast_usm_to_usm.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
#include "dpctl4pybind11.hpp"
3838
#include "kernels/copy_and_cast.hpp"
3939
#include "utils/memory_overlap.hpp"
40+
#include "utils/output_validation.hpp"
4041
#include "utils/type_dispatch.hpp"
4142
#include "utils/type_utils.hpp"
4243

@@ -118,6 +119,8 @@ copy_usm_ndarray_into_usm_ndarray(const dpctl::tensor::usm_ndarray &src,
118119
"Execution queue is not compatible with allocation queues");
119120
}
120121

122+
dpctl::tensor::validation::CheckWritable::throw_if_not_writable(dst);
123+
121124
int src_typenum = src.get_typenum();
122125
int dst_typenum = dst.get_typenum();
123126

dpctl/tensor/libtensor/source/copy_for_reshape.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
#include "copy_for_reshape.hpp"
3030
#include "dpctl4pybind11.hpp"
3131
#include "kernels/copy_and_cast.hpp"
32+
#include "utils/output_validation.hpp"
3233
#include "utils/type_dispatch.hpp"
3334
#include <pybind11/pybind11.h>
3435

@@ -105,6 +106,8 @@ copy_usm_ndarray_for_reshape(const dpctl::tensor::usm_ndarray &src,
105106
"Execution queue is not compatible with allocation queues");
106107
}
107108

109+
dpctl::tensor::validation::CheckWritable::throw_if_not_writable(dst);
110+
108111
if (src_nelems == 1) {
109112
// handle special case of 1-element array
110113
int src_elemsize = src.get_elemsize();

dpctl/tensor/libtensor/source/copy_for_roll.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
#include "copy_for_roll.hpp"
3030
#include "dpctl4pybind11.hpp"
3131
#include "kernels/copy_and_cast.hpp"
32+
#include "utils/output_validation.hpp"
3233
#include "utils/type_dispatch.hpp"
3334
#include <pybind11/pybind11.h>
3435

@@ -128,6 +129,8 @@ copy_usm_ndarray_for_roll_1d(const dpctl::tensor::usm_ndarray &src,
128129
"Execution queue is not compatible with allocation queues");
129130
}
130131

132+
dpctl::tensor::validation::CheckWritable::throw_if_not_writable(dst);
133+
131134
if (src_nelems == 1) {
132135
// handle special case of 1-element array
133136
int src_elemsize = src.get_elemsize();

dpctl/tensor/libtensor/source/copy_numpy_ndarray_into_usm_ndarray.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
#include <pybind11/pybind11.h>
3232

3333
#include "kernels/copy_and_cast.hpp"
34+
#include "utils/output_validation.hpp"
3435
#include "utils/type_dispatch.hpp"
3536

3637
#include "copy_numpy_ndarray_into_usm_ndarray.hpp"
@@ -106,6 +107,8 @@ void copy_numpy_ndarray_into_usm_ndarray(
106107
"allocation queue");
107108
}
108109

110+
dpctl::tensor::validation::CheckWritable::throw_if_not_writable(dst);
111+
109112
// here we assume that NumPy's type numbers agree with ours for types
110113
// supported in both
111114
int src_typenum =

dpctl/tensor/libtensor/source/elementwise_functions/elementwise_functions.hpp

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
#include "simplify_iteration_space.hpp"
3939
#include "utils/memory_overlap.hpp"
4040
#include "utils/offset_utils.hpp"
41+
#include "utils/output_validation.hpp"
4142
#include "utils/type_dispatch.hpp"
4243

4344
namespace py = pybind11;
@@ -69,10 +70,6 @@ py_unary_ufunc(const dpctl::tensor::usm_ndarray &src,
6970
const contig_dispatchT &contig_dispatch_vector,
7071
const strided_dispatchT &strided_dispatch_vector)
7172
{
72-
if (!dst.is_writable()) {
73-
throw py::value_error("Output array is read-only.");
74-
}
75-
7673
int src_typenum = src.get_typenum();
7774
int dst_typenum = dst.get_typenum();
7875

@@ -94,6 +91,8 @@ py_unary_ufunc(const dpctl::tensor::usm_ndarray &src,
9491
"Execution queue is not compatible with allocation queues");
9592
}
9693

94+
dpctl::tensor::validation::CheckWritable::throw_if_not_writable(dst);
95+
9796
// check that dimensions are the same
9897
int src_nd = src.get_ndim();
9998
if (src_nd != dst.get_ndim()) {
@@ -324,9 +323,6 @@ std::pair<sycl::event, sycl::event> py_binary_ufunc(
324323
const contig_row_matrix_dispatchT
325324
&contig_row_matrix_broadcast_dispatch_table)
326325
{
327-
if (!dst.is_writable()) {
328-
throw py::value_error("Output array is read-only.");
329-
}
330326
// check type_nums
331327
int src1_typenum = src1.get_typenum();
332328
int src2_typenum = src2.get_typenum();
@@ -350,6 +346,8 @@ std::pair<sycl::event, sycl::event> py_binary_ufunc(
350346
"Execution queue is not compatible with allocation queues");
351347
}
352348

349+
dpctl::tensor::validation::CheckWritable::throw_if_not_writable(dst);
350+
353351
// check shapes, broadcasting is assumed done by caller
354352
// check that dimensions are the same
355353
int dst_nd = dst.get_ndim();
@@ -655,9 +653,7 @@ py_binary_inplace_ufunc(const dpctl::tensor::usm_ndarray &lhs,
655653
const contig_row_matrix_dispatchT
656654
&contig_row_matrix_broadcast_dispatch_table)
657655
{
658-
if (!lhs.is_writable()) {
659-
throw py::value_error("Output array is read-only.");
660-
}
656+
dpctl::tensor::validation::CheckWritable::throw_if_not_writable(lhs);
661657

662658
// check type_nums
663659
int rhs_typenum = rhs.get_typenum();

dpctl/tensor/libtensor/source/eye_ctor.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131

3232
#include "eye_ctor.hpp"
3333
#include "kernels/constructors.hpp"
34+
#include "utils/output_validation.hpp"
3435
#include "utils/type_dispatch.hpp"
3536

3637
namespace py = pybind11;
@@ -66,6 +67,8 @@ usm_ndarray_eye(py::ssize_t k,
6667
"allocation queue");
6768
}
6869

70+
dpctl::tensor::validation::CheckWritable::throw_if_not_writable(dst);
71+
6972
auto array_types = td_ns::usm_ndarray_types();
7073
int dst_typenum = dst.get_typenum();
7174
int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum);

dpctl/tensor/libtensor/source/full_ctor.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
#include <vector>
3232

3333
#include "kernels/constructors.hpp"
34+
#include "utils/output_validation.hpp"
3435
#include "utils/type_dispatch.hpp"
3536
#include "utils/type_utils.hpp"
3637

@@ -126,6 +127,8 @@ usm_ndarray_full(const py::object &py_value,
126127
"Execution queue is not compatible with the allocation queue");
127128
}
128129

130+
dpctl::tensor::validation::CheckWritable::throw_if_not_writable(dst);
131+
129132
auto array_types = td_ns::usm_ndarray_types();
130133
int dst_typenum = dst.get_typenum();
131134
int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum);

dpctl/tensor/libtensor/source/integer_advanced_indexing.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
#include "dpctl4pybind11.hpp"
3636
#include "kernels/integer_advanced_indexing.hpp"
3737
#include "utils/memory_overlap.hpp"
38+
#include "utils/output_validation.hpp"
3839
#include "utils/type_dispatch.hpp"
3940
#include "utils/type_utils.hpp"
4041

@@ -259,6 +260,8 @@ usm_ndarray_take(const dpctl::tensor::usm_ndarray &src,
259260
throw py::value_error("Mode must be 0 or 1.");
260261
}
261262

263+
dpctl::tensor::validation::CheckWritable::throw_if_not_writable(dst);
264+
262265
const dpctl::tensor::usm_ndarray ind_rep = ind[0];
263266

264267
int src_nd = src.get_ndim();
@@ -570,9 +573,7 @@ usm_ndarray_put(const dpctl::tensor::usm_ndarray &dst,
570573
throw py::value_error("Mode must be 0 or 1.");
571574
}
572575

573-
if (!dst.is_writable()) {
574-
throw py::value_error("Output array is read-only.");
575-
}
576+
dpctl::tensor::validation::CheckWritable::throw_if_not_writable(dst);
576577

577578
const dpctl::tensor::usm_ndarray ind_rep = ind[0];
578579

dpctl/tensor/libtensor/source/linalg_functions/dot.cpp

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "simplify_iteration_space.hpp"
1818
#include "utils/memory_overlap.hpp"
1919
#include "utils/offset_utils.hpp"
20+
#include "utils/output_validation.hpp"
2021

2122
namespace dpctl
2223
{
@@ -175,20 +176,17 @@ py_dot(const dpctl::tensor::usm_ndarray &x1,
175176
sycl::queue &exec_q,
176177
const std::vector<sycl::event> &depends)
177178
{
178-
179-
if (!dst.is_writable()) {
180-
throw py::value_error("Output array is read-only.");
179+
if (!dpctl::utils::queues_are_compatible(exec_q, {x1, x2, dst})) {
180+
throw py::value_error(
181+
"Execution queue is not compatible with allocation queues");
181182
}
182183

184+
dpctl::tensor::validation::CheckWritable::throw_if_not_writable(dst);
185+
183186
if (inner_dims == 0) {
184187
throw py::value_error("No inner dimension for dot");
185188
}
186189

187-
if (!dpctl::utils::queues_are_compatible(exec_q, {x1, x2, dst})) {
188-
throw py::value_error(
189-
"Execution queue is not compatible with allocation queues");
190-
}
191-
192190
int x1_nd = x1.get_ndim();
193191
int x2_nd = x2.get_ndim();
194192
if (x1_nd != (batch_dims + x1_outer_dims + inner_dims) ||

dpctl/tensor/libtensor/source/linear_sequences.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
#include <vector>
3232

3333
#include "kernels/constructors.hpp"
34+
#include "utils/output_validation.hpp"
3435
#include "utils/type_dispatch.hpp"
3536
#include "utils/type_utils.hpp"
3637

@@ -191,6 +192,8 @@ usm_ndarray_linear_sequence_step(const py::object &start,
191192
"Execution queue is not compatible with the allocation queue");
192193
}
193194

195+
dpctl::tensor::validation::CheckWritable::throw_if_not_writable(dst);
196+
194197
auto array_types = td_ns::usm_ndarray_types();
195198
int dst_typenum = dst.get_typenum();
196199
int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum);
@@ -239,6 +242,8 @@ usm_ndarray_linear_sequence_affine(const py::object &start,
239242
"Execution queue context is not the same as allocation context");
240243
}
241244

245+
dpctl::tensor::validation::CheckWritable::throw_if_not_writable(dst);
246+
242247
auto array_types = td_ns::usm_ndarray_types();
243248
int dst_typenum = dst.get_typenum();
244249
int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum);

0 commit comments

Comments
 (0)