Skip to content

Commit 36ee455

Browse files
Address remarks
1 parent b89f41a commit 36ee455

File tree

4 files changed

+74
-45
lines changed

4 files changed

+74
-45
lines changed

dpnp/backend/extensions/ufunc/elementwise_functions/interpolate.cpp

Lines changed: 42 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,8 @@
3131
#include <pybind11/stl.h>
3232

3333
// dpctl tensor headers
34-
#include "utils/output_validation.hpp"
3534
#include "utils/type_dispatch.hpp"
35+
#include "utils/type_utils.hpp"
3636

3737
#include "kernels/elementwise_functions/interpolate.hpp"
3838

@@ -41,6 +41,7 @@
4141

4242
namespace py = pybind11;
4343
namespace td_ns = dpctl::tensor::type_dispatch;
44+
namespace type_utils = dpctl::tensor::type_utils;
4445

4546
using ext::common::value_type_of;
4647
using ext::validation::array_names;
@@ -57,18 +58,18 @@ template <typename T>
5758
using value_type_of_t = typename value_type_of<T>::type;
5859

5960
typedef sycl::event (*interpolate_fn_ptr_t)(sycl::queue &,
60-
const void *, // x
61-
const void *, // idx
62-
const void *, // xp
63-
const void *, // fp
64-
const void *, // left
65-
const void *, // right
66-
void *, // out
67-
std::size_t, // n
68-
std::size_t, // xp_size
61+
const void *, // x
62+
const void *, // idx
63+
const void *, // xp
64+
const void *, // fp
65+
const void *, // left
66+
const void *, // right
67+
void *, // out
68+
const std::size_t, // n
69+
const std::size_t, // xp_size
6970
const std::vector<sycl::event> &);
7071

71-
template <typename T>
72+
template <typename T, typename TIdx = std::int64_t>
7273
sycl::event interpolate_call(sycl::queue &exec_q,
7374
const void *vx,
7475
const void *vidx,
@@ -77,15 +78,15 @@ sycl::event interpolate_call(sycl::queue &exec_q,
7778
const void *vleft,
7879
const void *vright,
7980
void *vout,
80-
std::size_t n,
81-
std::size_t xp_size,
81+
const std::size_t n,
82+
const std::size_t xp_size,
8283
const std::vector<sycl::event> &depends)
8384
{
84-
using dpctl::tensor::type_utils::is_complex_v;
85+
using type_utils::is_complex_v;
8586
using TCoord = std::conditional_t<is_complex_v<T>, value_type_of_t<T>, T>;
8687

8788
const TCoord *x = static_cast<const TCoord *>(vx);
88-
const std::int64_t *idx = static_cast<const std::int64_t *>(vidx);
89+
const TIdx *idx = static_cast<const TIdx *>(vidx);
8990
const TCoord *xp = static_cast<const TCoord *>(vxp);
9091
const T *fp = static_cast<const T *>(vfp);
9192
const T *left = static_cast<const T *>(vleft);
@@ -114,6 +115,7 @@ void common_interpolate_checks(
114115

115116
auto array_types = td_ns::usm_ndarray_types();
116117
int x_type_id = array_types.typenum_to_lookup_id(x.get_typenum());
118+
int idx_type_id = array_types.typenum_to_lookup_id(idx.get_typenum());
117119
int xp_type_id = array_types.typenum_to_lookup_id(xp.get_typenum());
118120
int fp_type_id = array_types.typenum_to_lookup_id(fp.get_typenum());
119121
int out_type_id = array_types.typenum_to_lookup_id(out.get_typenum());
@@ -124,38 +126,41 @@ void common_interpolate_checks(
124126
if (fp_type_id != out_type_id) {
125127
throw py::value_error("fp and out must have the same dtype");
126128
}
129+
if (idx_type_id != static_cast<int>(td_ns::typenum_t::INT64)) {
130+
throw py::value_error("The type of idx must be int64");
131+
}
127132

128-
if (left) {
129-
const auto &l = left.value();
130-
names.insert({&l, "left"});
131-
if (l.get_ndim() != 0) {
133+
auto left_v = left ? &left.value() : nullptr;
134+
if (left_v) {
135+
names.insert({left_v, "left"});
136+
if (left_v->get_ndim() != 0) {
132137
throw py::value_error("left must be a zero-dimensional array");
133138
}
134139

135-
int left_type_id = array_types.typenum_to_lookup_id(l.get_typenum());
140+
int left_type_id =
141+
array_types.typenum_to_lookup_id(left_v->get_typenum());
136142
if (left_type_id != fp_type_id) {
137143
throw py::value_error(
138144
"left must have the same dtype as fp and out");
139145
}
140146
}
141147

142-
if (right) {
143-
const auto &r = right.value();
144-
names.insert({&r, "right"});
145-
if (r.get_ndim() != 0) {
148+
auto right_v = right ? &right.value() : nullptr;
149+
if (right_v) {
150+
names.insert({right_v, "right"});
151+
if (right_v->get_ndim() != 0) {
146152
throw py::value_error("right must be a zero-dimensional array");
147153
}
148154

149-
int right_type_id = array_types.typenum_to_lookup_id(r.get_typenum());
155+
int right_type_id =
156+
array_types.typenum_to_lookup_id(right_v->get_typenum());
150157
if (right_type_id != fp_type_id) {
151158
throw py::value_error(
152159
"right must have the same dtype as fp and out");
153160
}
154161
}
155162

156-
common_checks({&x, &xp, &fp, left ? &left.value() : nullptr,
157-
right ? &right.value() : nullptr},
158-
{&out}, names);
163+
common_checks({&x, &xp, &fp, left_v, right_v}, {&out}, names);
159164

160165
if (x.get_ndim() != 1 || xp.get_ndim() != 1 || fp.get_ndim() != 1 ||
161166
idx.get_ndim() != 1 || out.get_ndim() != 1)
@@ -167,6 +172,10 @@ void common_interpolate_checks(
167172
throw py::value_error("xp and fp must have the same size");
168173
}
169174

175+
if (xp.get_size() == 0) {
176+
throw py::value_error("array of sample points is empty");
177+
}
178+
170179
if (x.get_size() != out.get_size() || x.get_size() != idx.get_size()) {
171180
throw py::value_error("x, idx, and out must have the same size");
172181
}
@@ -183,12 +192,12 @@ std::pair<sycl::event, sycl::event>
183192
sycl::queue &exec_q,
184193
const std::vector<sycl::event> &depends)
185194
{
195+
common_interpolate_checks(x, idx, xp, fp, out, left, right);
196+
186197
if (x.get_size() == 0) {
187198
return {sycl::event(), sycl::event()};
188199
}
189200

190-
common_interpolate_checks(x, idx, xp, fp, out, left, right);
191-
192201
int out_typenum = out.get_typenum();
193202

194203
auto array_types = td_ns::usm_ndarray_types();
@@ -215,13 +224,10 @@ std::pair<sycl::event, sycl::event>
215224
args_ev = dpctl::utils::keep_args_alive(
216225
exec_q, {x, idx, xp, fp, out, left.value(), right.value()}, {ev});
217226
}
218-
else if (left) {
219-
args_ev = dpctl::utils::keep_args_alive(
220-
exec_q, {x, idx, xp, fp, out, left.value()}, {ev});
221-
}
222-
else if (right) {
227+
else if (left || right) {
223228
args_ev = dpctl::utils::keep_args_alive(
224-
exec_q, {x, idx, xp, fp, out, right.value()}, {ev});
229+
exec_q, {x, idx, xp, fp, out, left ? left.value() : right.value()},
230+
{ev});
225231
}
226232
else {
227233
args_ev =

dpnp/backend/kernels/elementwise_functions/interpolate.hpp

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,43 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2025, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
126
#pragma once
227

328
#include <sycl/sycl.hpp>
429
#include <vector>
530

631
#include "ext/common.hpp"
7-
#include "utils/type_utils.hpp"
8-
9-
namespace type_utils = dpctl::tensor::type_utils;
1032

1133
using ext::common::IsNan;
1234

1335
namespace dpnp::kernels::interpolate
1436
{
15-
template <typename TCoord, typename TValue>
37+
template <typename TCoord, typename TValue, typename TIdx = std::int64_t>
1638
sycl::event interpolate_impl(sycl::queue &q,
1739
const TCoord *x,
18-
const std::int64_t *idx,
40+
const TIdx *idx,
1941
const TCoord *xp,
2042
const TValue *fp,
2143
const TValue *left,

dpnp/dpnp_iface_mathematical.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -366,7 +366,8 @@ def _validate_interp_param(param, name, exec_q, usm_type, dtype=None):
366366
)
367367
if dpu.get_execution_queue([exec_q, param.sycl_queue]) is None:
368368
raise ValueError(
369-
"input arrays and {name} must be on the same SYCL queue"
369+
f"input arrays and {name} must be allocated "
370+
"on the same SYCL queue"
370371
)
371372
if dtype is not None:
372373
param = param.astype(dtype)

dpnp/tests/test_sycl_queue.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1454,9 +1454,9 @@ def test_choose(device):
14541454

14551455

14561456
@pytest.mark.parametrize("device", valid_dev, ids=dev_ids)
1457-
@pytest.mark.parametrize("left", [None, dpnp.array(-1.0)])
1458-
@pytest.mark.parametrize("right", [None, dpnp.array(99.0)])
1459-
@pytest.mark.parametrize("period", [None, dpnp.array(180.0)])
1457+
@pytest.mark.parametrize("left", [None, -1.0])
1458+
@pytest.mark.parametrize("right", [None, 99.0])
1459+
@pytest.mark.parametrize("period", [None, 180.0])
14601460
def test_interp(device, left, right, period):
14611461
x = dpnp.linspace(0.1, 9.9, 20, device=device)
14621462
xp = dpnp.linspace(0.0, 10.0, 5, sycl_queue=x.sycl_queue)

0 commit comments

Comments
 (0)