Skip to content

Commit 2accbe4

Browse files
Merge 8bb77fa into 5a9daef
2 parents 5a9daef + 8bb77fa commit 2accbe4

File tree

14 files changed

+979
-35
lines changed

14 files changed

+979
-35
lines changed

dpnp/backend/extensions/common/ext/common.hpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,27 @@ struct IsNan
106106
}
107107
};
108108

109+
template <typename T, bool hasValueType>
110+
struct value_type_of_impl;
111+
112+
template <typename T>
113+
struct value_type_of_impl<T, false>
114+
{
115+
using type = T;
116+
};
117+
118+
template <typename T>
119+
struct value_type_of_impl<T, true>
120+
{
121+
using type = typename T::value_type;
122+
};
123+
124+
template <typename T>
125+
using value_type_of = value_type_of_impl<T, type_utils::is_complex_v<T>>;
126+
127+
template <typename T>
128+
using value_type_of_t = typename value_type_of<T>::type;
129+
109130
size_t get_max_local_size(const sycl::device &device);
110131
size_t get_max_local_size(const sycl::device &device,
111132
int cpu_local_size_limit,

dpnp/backend/extensions/common/ext/details/validation_utils_internal.hpp

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,17 @@
2323
// THE POSSIBILITY OF SUCH DAMAGE.
2424
//*****************************************************************************
2525

26+
#include <pybind11/numpy.h>
27+
#include <pybind11/pybind11.h>
28+
29+
#include "ext/common.hpp"
30+
2631
#include "ext/validation_utils.hpp"
2732
#include "utils/memory_overlap.hpp"
2833

34+
namespace td_ns = dpctl::tensor::type_dispatch;
35+
namespace common = ext::common;
36+
2937
namespace ext::validation
3038
{
3139
inline sycl::queue get_queue(const std::vector<array_ptr> &inputs,
@@ -137,6 +145,15 @@ inline void check_num_dims(const array_ptr &arr,
137145
}
138146
}
139147

148+
inline void check_num_dims(const std::vector<array_ptr> &arrays,
149+
const size_t ndim,
150+
const array_names &names)
151+
{
152+
for (const auto &arr : arrays) {
153+
check_num_dims(arr, ndim, names);
154+
}
155+
}
156+
140157
inline void check_max_dims(const array_ptr &arr,
141158
const size_t max_ndim,
142159
const array_names &names)
@@ -163,6 +180,103 @@ inline void check_size_at_least(const array_ptr &arr,
163180
}
164181
}
165182

183+
inline void check_has_dtype(const array_ptr &arr,
184+
const typenum_t dtype,
185+
const array_names &names)
186+
{
187+
if (arr == nullptr) {
188+
return;
189+
}
190+
191+
auto array_types = td_ns::usm_ndarray_types();
192+
int array_type_id = array_types.typenum_to_lookup_id(arr->get_typenum());
193+
int expected_type_id = static_cast<int>(dtype);
194+
195+
if (array_type_id != expected_type_id) {
196+
py::dtype actual_dtype = common::dtype_from_typenum(array_type_id);
197+
py::dtype dtype_py = common::dtype_from_typenum(expected_type_id);
198+
199+
std::string msg = "Array " + name_of(arr, names) + " must have dtype " +
200+
std::string(py::str(dtype_py)) + ", but got " +
201+
std::string(py::str(actual_dtype));
202+
203+
throw py::value_error(msg);
204+
}
205+
}
206+
207+
inline void check_same_dtype(const array_ptr &arr1,
208+
const array_ptr &arr2,
209+
const array_names &names)
210+
{
211+
if (arr1 == nullptr || arr2 == nullptr) {
212+
return;
213+
}
214+
215+
auto array_types = td_ns::usm_ndarray_types();
216+
int first_type_id = array_types.typenum_to_lookup_id(arr1->get_typenum());
217+
int second_type_id = array_types.typenum_to_lookup_id(arr2->get_typenum());
218+
219+
if (first_type_id != second_type_id) {
220+
py::dtype first_dtype = common::dtype_from_typenum(first_type_id);
221+
py::dtype second_dtype = common::dtype_from_typenum(second_type_id);
222+
223+
std::string msg = "Arrays " + name_of(arr1, names) + " and " +
224+
name_of(arr2, names) +
225+
" must have the same dtype, but got " +
226+
std::string(py::str(first_dtype)) + " and " +
227+
std::string(py::str(second_dtype));
228+
229+
throw py::value_error(msg);
230+
}
231+
}
232+
233+
inline void check_same_dtype(const std::vector<array_ptr> &arrays,
234+
const array_names &names)
235+
{
236+
if (arrays.empty()) {
237+
return;
238+
}
239+
240+
const auto *first = arrays[0];
241+
for (size_t i = 1; i < arrays.size(); ++i) {
242+
check_same_dtype(first, arrays[i], names);
243+
}
244+
}
245+
246+
inline void check_same_size(const array_ptr &arr1,
247+
const array_ptr &arr2,
248+
const array_names &names)
249+
{
250+
if (arr1 == nullptr || arr2 == nullptr) {
251+
return;
252+
}
253+
254+
auto size1 = arr1->get_size();
255+
auto size2 = arr2->get_size();
256+
257+
if (size1 != size2) {
258+
std::string msg =
259+
"Arrays " + name_of(arr1, names) + " and " + name_of(arr2, names) +
260+
" must have the same size, but got " + std::to_string(size1) +
261+
" and " + std::to_string(size2);
262+
263+
throw py::value_error(msg);
264+
}
265+
}
266+
267+
inline void check_same_size(const std::vector<array_ptr> &arrays,
268+
const array_names &names)
269+
{
270+
if (arrays.empty()) {
271+
return;
272+
}
273+
274+
auto first = arrays[0];
275+
for (size_t i = 1; i < arrays.size(); ++i) {
276+
check_same_size(first, arrays[i], names);
277+
}
278+
}
279+
166280
inline void common_checks(const std::vector<array_ptr> &inputs,
167281
const std::vector<array_ptr> &outputs,
168282
const array_names &names)

dpnp/backend/extensions/common/ext/validation_utils.hpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ namespace ext::validation
3535
{
3636
using array_ptr = const dpctl::tensor::usm_ndarray *;
3737
using array_names = std::unordered_map<array_ptr, std::string>;
38+
using dpctl::tensor::type_dispatch::typenum_t;
3839

3940
std::string name_of(const array_ptr &arr, const array_names &names);
4041

@@ -56,6 +57,9 @@ void check_no_overlap(const std::vector<array_ptr> &inputs,
5657
void check_num_dims(const array_ptr &arr,
5758
const size_t ndim,
5859
const array_names &names);
60+
void check_num_dims(const std::vector<array_ptr> &arrays,
61+
const size_t ndim,
62+
const array_names &names);
5963
void check_max_dims(const array_ptr &arr,
6064
const size_t max_ndim,
6165
const array_names &names);
@@ -64,6 +68,20 @@ void check_size_at_least(const array_ptr &arr,
6468
const size_t size,
6569
const array_names &names);
6670

71+
void check_has_dtype(const array_ptr &arr,
72+
const typenum_t dtype,
73+
const array_names &names);
74+
75+
void check_same_dtype(const array_ptr &arr1,
76+
const array_ptr &arr2,
77+
const array_names &names);
78+
79+
void check_same_size(const array_ptr &arr1,
80+
const array_ptr &arr2,
81+
const array_names &names);
82+
void check_same_size(const std::vector<array_ptr> &arrays,
83+
const array_names &names);
84+
6785
void common_checks(const std::vector<array_ptr> &inputs,
6886
const std::vector<array_ptr> &outputs,
6987
const array_names &names);

dpnp/backend/extensions/ufunc/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ set(_elementwise_sources
3636
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/gcd.cpp
3737
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/heaviside.cpp
3838
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/i0.cpp
39+
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/interpolate.cpp
3940
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/lcm.cpp
4041
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/ldexp.cpp
4142
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/logaddexp2.cpp
@@ -69,6 +70,7 @@ endif()
6970
set_target_properties(${python_module_name} PROPERTIES CMAKE_POSITION_INDEPENDENT_CODE ON)
7071

7172
target_include_directories(${python_module_name} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../../)
73+
target_include_directories(${python_module_name} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../common)
7274

7375
target_include_directories(${python_module_name} PUBLIC ${Dpctl_INCLUDE_DIR})
7476
target_include_directories(${python_module_name} PUBLIC ${Dpctl_TENSOR_INCLUDE_DIR})

dpnp/backend/extensions/ufunc/elementwise_functions/common.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
#include "gcd.hpp"
3737
#include "heaviside.hpp"
3838
#include "i0.hpp"
39+
#include "interpolate.hpp"
3940
#include "lcm.hpp"
4041
#include "ldexp.hpp"
4142
#include "logaddexp2.hpp"
@@ -64,6 +65,7 @@ void init_elementwise_functions(py::module_ m)
6465
init_gcd(m);
6566
init_heaviside(m);
6667
init_i0(m);
68+
init_interpolate(m);
6769
init_lcm(m);
6870
init_ldexp(m);
6971
init_logaddexp2(m);

0 commit comments

Comments
 (0)