31
31
#include < pybind11/stl.h>
32
32
33
33
// dpctl tensor headers
34
- #include " utils/output_validation.hpp"
35
34
#include " utils/type_dispatch.hpp"
35
+ #include " utils/type_utils.hpp"
36
36
37
37
#include " kernels/elementwise_functions/interpolate.hpp"
38
38
41
41
42
42
namespace py = pybind11;
43
43
namespace td_ns = dpctl::tensor::type_dispatch;
44
+ namespace type_utils = dpctl::tensor::type_utils;
44
45
45
46
using ext::common::value_type_of;
46
47
using ext::validation::array_names;
@@ -57,18 +58,18 @@ template <typename T>
57
58
using value_type_of_t = typename value_type_of<T>::type;
58
59
59
60
typedef sycl::event (*interpolate_fn_ptr_t )(sycl::queue &,
60
- const void *, // x
61
- const void *, // idx
62
- const void *, // xp
63
- const void *, // fp
64
- const void *, // left
65
- const void *, // right
66
- void *, // out
67
- std::size_t , // n
68
- std::size_t , // xp_size
61
+ const void *, // x
62
+ const void *, // idx
63
+ const void *, // xp
64
+ const void *, // fp
65
+ const void *, // left
66
+ const void *, // right
67
+ void *, // out
68
+ const std::size_t , // n
69
+ const std::size_t , // xp_size
69
70
const std::vector<sycl::event> &);
70
71
71
- template <typename T>
72
+ template <typename T, typename TIdx = std:: int64_t >
72
73
sycl::event interpolate_call (sycl::queue &exec_q,
73
74
const void *vx,
74
75
const void *vidx,
@@ -77,15 +78,15 @@ sycl::event interpolate_call(sycl::queue &exec_q,
77
78
const void *vleft,
78
79
const void *vright,
79
80
void *vout,
80
- std::size_t n,
81
- std::size_t xp_size,
81
+ const std::size_t n,
82
+ const std::size_t xp_size,
82
83
const std::vector<sycl::event> &depends)
83
84
{
84
- using dpctl::tensor:: type_utils::is_complex_v;
85
+ using type_utils::is_complex_v;
85
86
using TCoord = std::conditional_t <is_complex_v<T>, value_type_of_t <T>, T>;
86
87
87
88
const TCoord *x = static_cast <const TCoord *>(vx);
88
- const std:: int64_t *idx = static_cast <const std:: int64_t *>(vidx);
89
+ const TIdx *idx = static_cast <const TIdx *>(vidx);
89
90
const TCoord *xp = static_cast <const TCoord *>(vxp);
90
91
const T *fp = static_cast <const T *>(vfp);
91
92
const T *left = static_cast <const T *>(vleft);
@@ -114,6 +115,7 @@ void common_interpolate_checks(
114
115
115
116
auto array_types = td_ns::usm_ndarray_types ();
116
117
int x_type_id = array_types.typenum_to_lookup_id (x.get_typenum ());
118
+ int idx_type_id = array_types.typenum_to_lookup_id (idx.get_typenum ());
117
119
int xp_type_id = array_types.typenum_to_lookup_id (xp.get_typenum ());
118
120
int fp_type_id = array_types.typenum_to_lookup_id (fp.get_typenum ());
119
121
int out_type_id = array_types.typenum_to_lookup_id (out.get_typenum ());
@@ -124,38 +126,41 @@ void common_interpolate_checks(
124
126
if (fp_type_id != out_type_id) {
125
127
throw py::value_error (" fp and out must have the same dtype" );
126
128
}
129
+ if (idx_type_id != static_cast <int >(td_ns::typenum_t ::INT64)) {
130
+ throw py::value_error (" The type of idx must be int64" );
131
+ }
127
132
128
- if ( left) {
129
- const auto &l = left. value ();
130
- names.insert ({&l , " left" });
131
- if (l. get_ndim () != 0 ) {
133
+ auto left_v = left ? &left. value () : nullptr ;
134
+ if (left_v) {
135
+ names.insert ({left_v , " left" });
136
+ if (left_v-> get_ndim () != 0 ) {
132
137
throw py::value_error (" left must be a zero-dimensional array" );
133
138
}
134
139
135
- int left_type_id = array_types.typenum_to_lookup_id (l.get_typenum ());
140
+ int left_type_id =
141
+ array_types.typenum_to_lookup_id (left_v->get_typenum ());
136
142
if (left_type_id != fp_type_id) {
137
143
throw py::value_error (
138
144
" left must have the same dtype as fp and out" );
139
145
}
140
146
}
141
147
142
- if ( right) {
143
- const auto &r = right. value ();
144
- names.insert ({&r , " right" });
145
- if (r. get_ndim () != 0 ) {
148
+ auto right_v = right ? &right. value () : nullptr ;
149
+ if (right_v) {
150
+ names.insert ({right_v , " right" });
151
+ if (right_v-> get_ndim () != 0 ) {
146
152
throw py::value_error (" right must be a zero-dimensional array" );
147
153
}
148
154
149
- int right_type_id = array_types.typenum_to_lookup_id (r.get_typenum ());
155
+ int right_type_id =
156
+ array_types.typenum_to_lookup_id (right_v->get_typenum ());
150
157
if (right_type_id != fp_type_id) {
151
158
throw py::value_error (
152
159
" right must have the same dtype as fp and out" );
153
160
}
154
161
}
155
162
156
- common_checks ({&x, &xp, &fp, left ? &left.value () : nullptr ,
157
- right ? &right.value () : nullptr },
158
- {&out}, names);
163
+ common_checks ({&x, &xp, &fp, left_v, right_v}, {&out}, names);
159
164
160
165
if (x.get_ndim () != 1 || xp.get_ndim () != 1 || fp.get_ndim () != 1 ||
161
166
idx.get_ndim () != 1 || out.get_ndim () != 1 )
@@ -167,6 +172,10 @@ void common_interpolate_checks(
167
172
throw py::value_error (" xp and fp must have the same size" );
168
173
}
169
174
175
+ if (xp.get_size () == 0 ) {
176
+ throw py::value_error (" array of sample points is empty" );
177
+ }
178
+
170
179
if (x.get_size () != out.get_size () || x.get_size () != idx.get_size ()) {
171
180
throw py::value_error (" x, idx, and out must have the same size" );
172
181
}
@@ -183,12 +192,12 @@ std::pair<sycl::event, sycl::event>
183
192
sycl::queue &exec_q,
184
193
const std::vector<sycl::event> &depends)
185
194
{
195
+ common_interpolate_checks (x, idx, xp, fp, out, left, right);
196
+
186
197
if (x.get_size () == 0 ) {
187
198
return {sycl::event (), sycl::event ()};
188
199
}
189
200
190
- common_interpolate_checks (x, idx, xp, fp, out, left, right);
191
-
192
201
int out_typenum = out.get_typenum ();
193
202
194
203
auto array_types = td_ns::usm_ndarray_types ();
@@ -215,13 +224,10 @@ std::pair<sycl::event, sycl::event>
215
224
args_ev = dpctl::utils::keep_args_alive (
216
225
exec_q, {x, idx, xp, fp, out, left.value (), right.value ()}, {ev});
217
226
}
218
- else if (left) {
219
- args_ev = dpctl::utils::keep_args_alive (
220
- exec_q, {x, idx, xp, fp, out, left.value ()}, {ev});
221
- }
222
- else if (right) {
227
+ else if (left || right) {
223
228
args_ev = dpctl::utils::keep_args_alive (
224
- exec_q, {x, idx, xp, fp, out, right.value ()}, {ev});
229
+ exec_q, {x, idx, xp, fp, out, left ? left.value () : right.value ()},
230
+ {ev});
225
231
}
226
232
else {
227
233
args_ev =
0 commit comments