Skip to content

Commit 81d0f9e

Browse files
committed
Factors checks for sufficient memory range into output_validation.hpp
1 parent 00b36ee commit 81d0f9e

15 files changed

+58
-314
lines changed

dpctl/tensor/libtensor/include/utils/output_validation.hpp

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,9 @@ namespace tensor
3737
namespace validation
3838
{
3939

40-
/*! @brief Raises a value error if a function would attempt to write
41-
to an array which is read-only.
40+
/*! @brief Raises a value error if an array is read-only.
4241
43-
This should always be called on an array before it will be written to.*/
42+
This should be called with an array before writing.*/
4443
struct CheckWritable
4544
{
4645
static void throw_if_not_writable(const dpctl::tensor::usm_ndarray &arr)
@@ -52,6 +51,26 @@ struct CheckWritable
5251
}
5352
};
5453

54+
/*! @brief Raises a value error if an array's memory is not sufficiently ample
55+
to accommodate an input number of elements.
56+
57+
This should be called with an array before writing.*/
58+
struct AmpleMemory
59+
{
60+
template <typename T>
61+
static void throw_if_not_ample(const dpctl::tensor::usm_ndarray &arr,
62+
T nelems)
63+
{
64+
auto arr_offsets = arr.get_minmax_offsets();
65+
T range = static_cast<T>(arr_offsets.second - arr_offsets.first);
66+
if (range + 1 < nelems) {
67+
throw py::value_error("Memory addressed by the output array is not "
68+
"sufficiently ample.");
69+
}
70+
return;
71+
}
72+
};
73+
5574
} // namespace validation
5675
} // namespace tensor
5776
} // namespace dpctl

dpctl/tensor/libtensor/source/boolean_advanced_indexing.cpp

Lines changed: 6 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -174,19 +174,8 @@ py_extract(const dpctl::tensor::usm_ndarray &src,
174174
throw py::value_error("Inconsistent array dimensions");
175175
}
176176

177-
// ensure that dst is sufficiently ample
178-
auto dst_offsets = dst.get_minmax_offsets();
179-
// destination must be ample enough to accommodate all elements
180-
{
181-
size_t range =
182-
static_cast<size_t>(dst_offsets.second - dst_offsets.first);
183-
if (range + 1 < static_cast<size_t>(ortho_nelems * masked_dst_nelems)) {
184-
throw py::value_error(
185-
"Memory addressed by the destination array can not "
186-
"accommodate all the "
187-
"array elements.");
188-
}
189-
}
177+
dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(
178+
dst, ortho_nelems * masked_dst_nelems);
190179

191180
auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
192181
// check that dst does not intersect with src, not with cumsum.
@@ -507,19 +496,8 @@ py_place(const dpctl::tensor::usm_ndarray &dst,
507496
throw py::value_error("Inconsistent array dimensions");
508497
}
509498

510-
// ensure that dst is sufficiently ample
511-
auto dst_offsets = dst.get_minmax_offsets();
512-
// destination must be ample enough to accommodate all elements
513-
{
514-
size_t range =
515-
static_cast<size_t>(dst_offsets.second - dst_offsets.first);
516-
if (range + 1 < static_cast<size_t>(ortho_nelems * masked_dst_nelems)) {
517-
throw py::value_error(
518-
"Memory addressed by the destination array can not "
519-
"accommodate all the "
520-
"array elements.");
521-
}
522-
}
499+
dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(
500+
dst, ortho_nelems * masked_dst_nelems);
523501

524502
auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
525503
// check that dst does not intersect with src, not with cumsum.
@@ -794,18 +772,8 @@ py_nonzero(const dpctl::tensor::usm_ndarray
794772
throw py::value_error("Arrays are expected to ave no memory overlap");
795773
}
796774

797-
// ensure that dst is sufficiently ample
798-
auto indexes_offsets = indexes.get_minmax_offsets();
799-
// destination must be ample enough to accommodate all elements
800-
{
801-
size_t range =
802-
static_cast<size_t>(indexes_offsets.second - indexes_offsets.first);
803-
if (range + 1 < static_cast<size_t>(nz_elems * _ndim)) {
804-
throw py::value_error(
805-
"Memory addressed by the destination array can not "
806-
"accommodate all the array elements.");
807-
}
808-
}
775+
dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(
776+
indexes, nz_elems * _ndim);
809777

810778
std::vector<sycl::event> host_task_events;
811779
host_task_events.reserve(2);

dpctl/tensor/libtensor/source/clip.cpp

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -155,19 +155,7 @@ py_clip(const dpctl::tensor::usm_ndarray &src,
155155
"have the same data type");
156156
}
157157

158-
// ensure that dst is sufficiently ample
159-
auto dst_offsets = dst.get_minmax_offsets();
160-
// destination must be ample enough to accommodate all elements
161-
{
162-
size_t range =
163-
static_cast<size_t>(dst_offsets.second - dst_offsets.first);
164-
if (range + 1 < static_cast<size_t>(nelems)) {
165-
throw py::value_error(
166-
"Memory addressed by the destination array can not "
167-
"accommodate all the "
168-
"array elements.");
169-
}
170-
}
158+
dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(dst, nelems);
171159

172160
char *src_data = src.get_data();
173161
char *min_data = min.get_data();

dpctl/tensor/libtensor/source/copy_and_cast_usm_to_usm.cpp

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -101,17 +101,7 @@ copy_usm_ndarray_into_usm_ndarray(const dpctl::tensor::usm_ndarray &src,
101101
return std::make_pair(sycl::event(), sycl::event());
102102
}
103103

104-
// destination must be ample enough to accommodate all elements
105-
{
106-
auto dst_offsets = dst.get_minmax_offsets();
107-
size_t range =
108-
static_cast<size_t>(dst_offsets.second - dst_offsets.first);
109-
if (range + 1 < src_nelems) {
110-
throw py::value_error(
111-
"Destination array can not accommodate all the "
112-
"elements of source array.");
113-
}
114-
}
104+
dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(dst, src_nelems);
115105

116106
// check compatibility of execution queue and allocation queue
117107
if (!dpctl::utils::queues_are_compatible(exec_q, {src, dst})) {

dpctl/tensor/libtensor/source/copy_for_reshape.cpp

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -88,17 +88,7 @@ copy_usm_ndarray_for_reshape(const dpctl::tensor::usm_ndarray &src,
8888
return std::make_pair(sycl::event(), sycl::event());
8989
}
9090

91-
// destination must be ample enough to accommodate all elements
92-
{
93-
auto dst_offsets = dst.get_minmax_offsets();
94-
py::ssize_t range =
95-
static_cast<py::ssize_t>(dst_offsets.second - dst_offsets.first);
96-
if (range + 1 < src_nelems) {
97-
throw py::value_error(
98-
"Destination array can not accommodate all the "
99-
"elements of source array.");
100-
}
101-
}
91+
dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(dst, src_nelems);
10292

10393
// check same contexts
10494
if (!dpctl::utils::queues_are_compatible(exec_q, {src, dst})) {

dpctl/tensor/libtensor/source/copy_for_roll.cpp

Lines changed: 2 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -111,17 +111,7 @@ copy_usm_ndarray_for_roll_1d(const dpctl::tensor::usm_ndarray &src,
111111
return std::make_pair(sycl::event(), sycl::event());
112112
}
113113

114-
// destination must be ample enough to accommodate all elements
115-
{
116-
auto dst_offsets = dst.get_minmax_offsets();
117-
py::ssize_t range =
118-
static_cast<py::ssize_t>(dst_offsets.second - dst_offsets.first);
119-
if (range + 1 < src_nelems) {
120-
throw py::value_error(
121-
"Destination array can not accommodate all the "
122-
"elements of source array.");
123-
}
124-
}
114+
dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(dst, src_nelems);
125115

126116
// check same contexts
127117
if (!dpctl::utils::queues_are_compatible(exec_q, {src, dst})) {
@@ -301,17 +291,7 @@ copy_usm_ndarray_for_roll_nd(const dpctl::tensor::usm_ndarray &src,
301291
return std::make_pair(sycl::event(), sycl::event());
302292
}
303293

304-
// destination must be ample enough to accommodate all elements
305-
{
306-
auto dst_offsets = dst.get_minmax_offsets();
307-
py::ssize_t range =
308-
static_cast<py::ssize_t>(dst_offsets.second - dst_offsets.first);
309-
if (range + 1 < src_nelems) {
310-
throw py::value_error(
311-
"Destination array can not accommodate all the "
312-
"elements of source array.");
313-
}
314-
}
294+
dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(dst, src_nelems);
315295

316296
// check for compatible queues
317297
if (!dpctl::utils::queues_are_compatible(exec_q, {src, dst})) {

dpctl/tensor/libtensor/source/copy_numpy_ndarray_into_usm_ndarray.cpp

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -89,18 +89,7 @@ void copy_numpy_ndarray_into_usm_ndarray(
8989
return;
9090
}
9191

92-
auto dst_offsets = dst.get_minmax_offsets();
93-
// destination must be ample enough to accommodate all elements of source
94-
// array
95-
{
96-
size_t range =
97-
static_cast<size_t>(dst_offsets.second - dst_offsets.first);
98-
if (range + 1 < src_nelems) {
99-
throw py::value_error(
100-
"Destination array can not accommodate all the "
101-
"elements of source array.");
102-
}
103-
}
92+
dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(dst, src_nelems);
10493

10594
if (!dpctl::utils::queues_are_compatible(exec_q, {dst})) {
10695
throw py::value_error("Execution queue is not compatible with the "

dpctl/tensor/libtensor/source/elementwise_functions/elementwise_functions.hpp

Lines changed: 3 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -118,18 +118,7 @@ py_unary_ufunc(const dpctl::tensor::usm_ndarray &src,
118118
return std::make_pair(sycl::event(), sycl::event());
119119
}
120120

121-
// ensure that output is ample enough to accommodate all elements
122-
auto dst_offsets = dst.get_minmax_offsets();
123-
// destination must be ample enough to accommodate all elements
124-
{
125-
size_t range =
126-
static_cast<size_t>(dst_offsets.second - dst_offsets.first);
127-
if (range + 1 < src_nelems) {
128-
throw py::value_error(
129-
"Destination array can not accommodate all the "
130-
"elements of source array.");
131-
}
132-
}
121+
dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(dst, src_nelems);
133122

134123
// check memory overlap
135124
auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
@@ -376,18 +365,7 @@ std::pair<sycl::event, sycl::event> py_binary_ufunc(
376365
return std::make_pair(sycl::event(), sycl::event());
377366
}
378367

379-
// ensure that output is ample enough to accommodate all elements
380-
auto dst_offsets = dst.get_minmax_offsets();
381-
// destination must be ample enough to accommodate all elements
382-
{
383-
size_t range =
384-
static_cast<size_t>(dst_offsets.second - dst_offsets.first);
385-
if (range + 1 < src_nelems) {
386-
throw py::value_error(
387-
"Destination array can not accommodate all the "
388-
"elements of source array.");
389-
}
390-
}
368+
dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(dst, src_nelems);
391369

392370
auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
393371
auto const &same_logical_tensors =
@@ -702,18 +680,7 @@ py_binary_inplace_ufunc(const dpctl::tensor::usm_ndarray &lhs,
702680
return std::make_pair(sycl::event(), sycl::event());
703681
}
704682

705-
// ensure that output is ample enough to accommodate all elements
706-
auto lhs_offsets = lhs.get_minmax_offsets();
707-
// destination must be ample enough to accommodate all elements
708-
{
709-
size_t range =
710-
static_cast<size_t>(lhs_offsets.second - lhs_offsets.first);
711-
if (range + 1 < rhs_nelems) {
712-
throw py::value_error(
713-
"Destination array can not accommodate all the "
714-
"elements of source array.");
715-
}
716-
}
683+
dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(rhs, rhs_nelems);
717684

718685
// check memory overlap
719686
auto const &same_logical_tensors =

dpctl/tensor/libtensor/source/integer_advanced_indexing.cpp

Lines changed: 3 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -352,17 +352,8 @@ usm_ndarray_take(const dpctl::tensor::usm_ndarray &src,
352352
}
353353
}
354354

355-
// destination must be ample enough to accommodate all elements
356-
{
357-
auto dst_offsets = dst.get_minmax_offsets();
358-
size_t range =
359-
static_cast<size_t>(dst_offsets.second - dst_offsets.first);
360-
if ((range + 1) < (orthog_nelems * ind_nelems)) {
361-
throw py::value_error(
362-
"Destination array can not accommodate all the "
363-
"elements of source array.");
364-
}
365-
}
355+
dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(
356+
dst, orthog_nelems * ind_nelems);
366357

367358
int ind_sh_elems = std::max<int>(ind_nd, 1);
368359

@@ -641,17 +632,7 @@ usm_ndarray_put(const dpctl::tensor::usm_ndarray &dst,
641632
py::ssize_t dst_offset = py::ssize_t(0);
642633
py::ssize_t val_offset = py::ssize_t(0);
643634

644-
// destination must be ample enough to accommodate all possible elements
645-
{
646-
auto dst_offsets = dst.get_minmax_offsets();
647-
size_t range =
648-
static_cast<size_t>(dst_offsets.second - dst_offsets.first);
649-
if ((range + 1) < dst_nelems) {
650-
throw py::value_error(
651-
"Destination array can not accommodate all the "
652-
"elements of source array.");
653-
}
654-
}
635+
dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(dst, dst_nelems);
655636

656637
int dst_typenum = dst.get_typenum();
657638
int val_typenum = val.get_typenum();

dpctl/tensor/libtensor/source/linalg_functions/dot.cpp

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -248,19 +248,7 @@ py_dot(const dpctl::tensor::usm_ndarray &x1,
248248
throw py::value_error("dst shape and size mismatch");
249249
}
250250

251-
// ensure that dst is sufficiently ample
252-
auto dst_offsets = dst.get_minmax_offsets();
253-
// destination must be ample enough to accommodate all elements
254-
{
255-
size_t range =
256-
static_cast<size_t>(dst_offsets.second - dst_offsets.first);
257-
if (range + 1 < dst_nelems) {
258-
throw py::value_error(
259-
"Memory addressed by the destination array can not "
260-
"accommodate all the "
261-
"array elements.");
262-
}
263-
}
251+
dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(dst, dst_nelems);
264252

265253
auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
266254
// check that dst does not intersect with x1 or x2

0 commit comments

Comments
 (0)