Skip to content

Commit 7149317

Browse files
authored
Merge 6ce12ed into 4ad786f
2 parents 4ad786f + 6ce12ed commit 7149317

File tree

11 files changed

+402
-14
lines changed

11 files changed

+402
-14
lines changed

dpnp/backend/extensions/vm/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ set(_elementwise_sources
5252
${CMAKE_CURRENT_SOURCE_DIR}/log1p.cpp
5353
${CMAKE_CURRENT_SOURCE_DIR}/log2.cpp
5454
${CMAKE_CURRENT_SOURCE_DIR}/mul.cpp
55+
${CMAKE_CURRENT_SOURCE_DIR}/nextafter.cpp
5556
${CMAKE_CURRENT_SOURCE_DIR}/pow.cpp
5657
${CMAKE_CURRENT_SOURCE_DIR}/rint.cpp
5758
${CMAKE_CURRENT_SOURCE_DIR}/sin.cpp
Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2024, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#include <oneapi/mkl.hpp>
27+
#include <sycl/sycl.hpp>
28+
29+
#include "dpctl4pybind11.hpp"
30+
31+
#include "common.hpp"
32+
#include "nextafter.hpp"
33+
34+
// include a local copy of elementwise common header from dpctl tensor:
35+
// dpctl/tensor/libtensor/source/elementwise_functions/elementwise_functions.hpp
36+
// TODO: replace by including dpctl header once available
37+
#include "../elementwise_functions/elementwise_functions.hpp"
38+
39+
// dpctl tensor headers
40+
#include "kernels/elementwise_functions/common.hpp"
41+
#include "utils/type_dispatch.hpp"
42+
#include "utils/type_utils.hpp"
43+
44+
namespace dpnp::extensions::vm
45+
{
46+
namespace ew_cmn_ns = dpctl::tensor::kernels::elementwise_common;
47+
namespace py = pybind11;
48+
namespace py_int = dpnp::extensions::py_internal;
49+
namespace td_ns = dpctl::tensor::type_dispatch;
50+
namespace tu_ns = dpctl::tensor::type_utils;
51+
52+
namespace impl
53+
{
54+
// OneMKL namespace with VM functions
55+
namespace mkl_vm = oneapi::mkl::vm;
56+
57+
/**
58+
* @brief A factory to define pairs of supported types for which
59+
* MKL VM library provides support in oneapi::mkl::vm::nextafter<T> function.
60+
*
61+
* @tparam T Type of input vectors `a` and `b` and of result vector `y`.
62+
*/
63+
template <typename T1, typename T2>
64+
struct OutputType
65+
{
66+
using value_type = typename std::disjunction<
67+
td_ns::BinaryTypeMapResultEntry<T1, double, T2, double, double>,
68+
td_ns::BinaryTypeMapResultEntry<T1, float, T2, float, float>,
69+
td_ns::DefaultResultEntry<void>>::result_type;
70+
};
71+
72+
template <typename T1, typename T2>
73+
static sycl::event
74+
nextafter_contig_impl(sycl::queue &exec_q,
75+
std::size_t in_n,
76+
const char *in_a,
77+
py::ssize_t a_offset,
78+
const char *in_b,
79+
py::ssize_t b_offset,
80+
char *out_y,
81+
py::ssize_t out_offset,
82+
const std::vector<sycl::event> &depends)
83+
{
84+
tu_ns::validate_type_for_device<T1>(exec_q);
85+
tu_ns::validate_type_for_device<T2>(exec_q);
86+
87+
if ((a_offset != 0) || (b_offset != 0) || (out_offset != 0)) {
88+
throw std::runtime_error("Arrays offsets have to be equals to 0");
89+
}
90+
91+
std::int64_t n = static_cast<std::int64_t>(in_n);
92+
const T1 *a = reinterpret_cast<const T1 *>(in_a);
93+
const T2 *b = reinterpret_cast<const T2 *>(in_b);
94+
95+
using resTy = typename OutputType<T1, T2>::value_type;
96+
resTy *y = reinterpret_cast<resTy *>(out_y);
97+
98+
return mkl_vm::nextafter(
99+
exec_q,
100+
n, // number of elements to be calculated
101+
a, // pointer `a` containing 1st input vector of size n
102+
b, // pointer `b` containing 2nd input vector of size n
103+
y, // pointer `y` to the output vector of size n
104+
depends);
105+
}
106+
107+
using ew_cmn_ns::binary_contig_impl_fn_ptr_t;
108+
using ew_cmn_ns::binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t;
109+
using ew_cmn_ns::binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t;
110+
using ew_cmn_ns::binary_strided_impl_fn_ptr_t;
111+
112+
static int output_typeid_vector[td_ns::num_types][td_ns::num_types];
113+
static binary_contig_impl_fn_ptr_t contig_dispatch_vector[td_ns::num_types]
114+
[td_ns::num_types];
115+
116+
MACRO_POPULATE_DISPATCH_TABLES(nextafter);
117+
} // namespace impl
118+
119+
void init_nextafter(py::module_ m)
120+
{
121+
using arrayT = dpctl::tensor::usm_ndarray;
122+
using event_vecT = std::vector<sycl::event>;
123+
124+
impl::populate_dispatch_tables();
125+
using impl::contig_dispatch_vector;
126+
using impl::output_typeid_vector;
127+
128+
auto nextafter_pyapi = [&](sycl::queue &exec_q, const arrayT &src1,
129+
const arrayT &src2, const arrayT &dst,
130+
const event_vecT &depends = {}) {
131+
return py_int::py_binary_ufunc(
132+
src1, src2, dst, exec_q, depends, output_typeid_vector,
133+
contig_dispatch_vector,
134+
// no support of strided implementation in OneMKL
135+
td_ns::NullPtrTable<impl::binary_strided_impl_fn_ptr_t>{},
136+
// no support of C-contig row with broadcasting in OneMKL
137+
td_ns::NullPtrTable<
138+
impl::
139+
binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{},
140+
td_ns::NullPtrTable<
141+
impl::
142+
binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{});
143+
};
144+
m.def(
145+
"_nextafter", nextafter_pyapi,
146+
"Call `nextafter` function from OneMKL VM library to return `dst` of "
147+
"elements containing the next representable floating-point values "
148+
"following the values from the elements of `src1` in the direction of "
149+
"the corresponding elements of `src2`",
150+
py::arg("sycl_queue"), py::arg("src1"), py::arg("src2"), py::arg("dst"),
151+
py::arg("depends") = py::list());
152+
153+
auto nextafter_need_to_call_pyapi = [&](sycl::queue &exec_q,
154+
const arrayT &src1,
155+
const arrayT &src2,
156+
const arrayT &dst) {
157+
return py_internal::need_to_call_binary_ufunc(exec_q, src1, src2, dst,
158+
output_typeid_vector,
159+
contig_dispatch_vector);
160+
};
161+
m.def("_mkl_nextafter_to_call", nextafter_need_to_call_pyapi,
162+
"Check input arguments to answer if `nextafter` function from "
163+
"OneMKL VM library can be used",
164+
py::arg("sycl_queue"), py::arg("src1"), py::arg("src2"),
165+
py::arg("dst"));
166+
}
167+
} // namespace dpnp::extensions::vm
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2024, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#pragma once
27+
28+
#include <pybind11/pybind11.h>
29+
30+
namespace py = pybind11;
31+
32+
namespace dpnp::extensions::vm
33+
{
34+
void init_nextafter(py::module_ m);
35+
} // namespace dpnp::extensions::vm

dpnp/backend/extensions/vm/vm_py.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555
#include "log1p.hpp"
5656
#include "log2.hpp"
5757
#include "mul.hpp"
58+
#include "nextafter.hpp"
5859
#include "pow.hpp"
5960
#include "rint.hpp"
6061
#include "sin.hpp"
@@ -98,6 +99,7 @@ PYBIND11_MODULE(_vm_impl, m)
9899
vm_ns::init_log1p(m);
99100
vm_ns::init_log2(m);
100101
vm_ns::init_mul(m);
102+
vm_ns::init_nextafter(m);
101103
vm_ns::init_pow(m);
102104
vm_ns::init_rint(m);
103105
vm_ns::init_sin(m);

dpnp/dpnp_iface_mathematical.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@
111111
"modf",
112112
"multiply",
113113
"negative",
114+
"nextafter",
114115
"positive",
115116
"power",
116117
"prod",
@@ -2359,6 +2360,63 @@ def modf(x1, **kwargs):
23592360
)
23602361

23612362

2363+
_NEXTAFTER_DOCSTRING = """
2364+
Return the next floating-point value after `x1` towards `x2`, element-wise.
2365+
2366+
For full documentation refer to :obj:`numpy.nextafter`.
2367+
2368+
Parameters
2369+
----------
2370+
x1 : {dpnp.ndarray, usm_ndarray, scalar}
2371+
Values to find the next representable value of.
2372+
Both inputs `x1` and `x2` can not be scalars at the same time.
2373+
x2 : {dpnp.ndarray, usm_ndarray, scalar}
2374+
The direction where to look for the next representable value of `x1`.
2375+
Both inputs `x1` and `x2` can not be scalars at the same time.
2376+
out : {None, dpnp.ndarray, usm_ndarray}, optional
2377+
Output array to populate. Array must have the correct shape and
2378+
the expected data type.
2379+
Default: ``None``.
2380+
order : {"C", "F", "A", "K"}, optional
2381+
Output array, if parameter `out` is ``None``.
2382+
Default: ``"K"``.
2383+
2384+
Returns
2385+
-------
2386+
out : dpnp.ndarray
2387+
The next representable values of `x1` in the direction of `x2`. The data
2388+
type of the returned array is determined by the Type Promotion Rules.
2389+
2390+
Limitations
2391+
-----------
2392+
Parameters `where` and `subok` are supported with their default values.
2393+
Keyword argument `kwargs` is currently unsupported.
2394+
Otherwise ``NotImplementedError`` exception will be raised.
2395+
2396+
Examples
2397+
--------
2398+
>>> import dpnp as np
2399+
>>> eps = np.finfo(np.float64).eps
2400+
>>> np.nextafter(np.array(1), 2) == eps + 1
2401+
array(True)
2402+
2403+
>>> a = np.array([1, 2])
2404+
>>> b = np.array([2, 1])
2405+
>>> c = np.array([eps + 1, 2 - eps])
2406+
>>> np.nextafter(a, b) == c
2407+
array([ True, True])
2408+
"""
2409+
2410+
nextafter = DPNPBinaryFunc(
2411+
"nextafter",
2412+
ti._nextafter_result_type,
2413+
ti._nextafter,
2414+
_NEXTAFTER_DOCSTRING,
2415+
mkl_fn_to_call=vmi._mkl_nextafter_to_call,
2416+
mkl_impl_fn=vmi._nextafter,
2417+
)
2418+
2419+
23622420
_POSITIVE_DOCSTRING = """
23632421
Computes the numerical positive for each element `x_i` of input array `x`.
23642422

tests/skipped_tests.tbl

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,6 @@ tests/test_umath.py::test_umaths[('ldexp', 'di')]
5757
tests/test_umath.py::test_umaths[('ldexp', 'dl')]
5858
tests/test_umath.py::test_umaths[('logaddexp2', 'ff')]
5959
tests/test_umath.py::test_umaths[('logaddexp2', 'dd')]
60-
tests/test_umath.py::test_umaths[('nextafter', 'ff')]
61-
tests/test_umath.py::test_umaths[('nextafter', 'dd')]
6260
tests/test_umath.py::test_umaths[('spacing', 'f')]
6361
tests/test_umath.py::test_umaths[('spacing', 'd')]
6462

@@ -217,11 +215,6 @@ tests/third_party/cupy/manipulation_tests/test_dims.py::TestInvalidBroadcast_par
217215
tests/third_party/cupy/math_tests/test_explog.py::TestExplog::test_logaddexp2
218216
tests/third_party/cupy/math_tests/test_explog.py::TestExplog::test_logaddexp2_infinities
219217

220-
tests/third_party/cupy/math_tests/test_floating.py::TestFloating::test_frexp
221-
tests/third_party/cupy/math_tests/test_floating.py::TestFloating::test_ldexp
222-
tests/third_party/cupy/math_tests/test_floating.py::TestFloating::test_nextafter_combination
223-
tests/third_party/cupy/math_tests/test_floating.py::TestFloating::test_nextafter_float
224-
225218
tests/third_party/cupy/math_tests/test_misc.py::TestMisc::test_nan_to_num
226219
tests/third_party/cupy/math_tests/test_misc.py::TestMisc::test_nan_to_num_negative
227220
tests/third_party/cupy/math_tests/test_misc.py::TestMisc::test_nan_to_num_for_old_numpy

tests/skipped_tests_gpu.tbl

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,6 @@ tests/test_umath.py::test_umaths[('ldexp', 'di')]
3737
tests/test_umath.py::test_umaths[('ldexp', 'dl')]
3838
tests/test_umath.py::test_umaths[('logaddexp2', 'ff')]
3939
tests/test_umath.py::test_umaths[('logaddexp2', 'dd')]
40-
tests/test_umath.py::test_umaths[('nextafter', 'ff')]
41-
tests/test_umath.py::test_umaths[('nextafter', 'dd')]
4240
tests/test_umath.py::test_umaths[('spacing', 'f')]
4341
tests/test_umath.py::test_umaths[('spacing', 'd')]
4442

@@ -268,11 +266,6 @@ tests/third_party/cupy/manipulation_tests/test_dims.py::TestInvalidBroadcast_par
268266
tests/third_party/cupy/math_tests/test_explog.py::TestExplog::test_logaddexp2
269267
tests/third_party/cupy/math_tests/test_explog.py::TestExplog::test_logaddexp2_infinities
270268

271-
tests/third_party/cupy/math_tests/test_floating.py::TestFloating::test_frexp
272-
tests/third_party/cupy/math_tests/test_floating.py::TestFloating::test_ldexp
273-
tests/third_party/cupy/math_tests/test_floating.py::TestFloating::test_nextafter_combination
274-
tests/third_party/cupy/math_tests/test_floating.py::TestFloating::test_nextafter_float
275-
276269
tests/third_party/cupy/math_tests/test_misc.py::TestMisc::test_nan_to_num
277270
tests/third_party/cupy/math_tests/test_misc.py::TestMisc::test_nan_to_num_negative
278271
tests/third_party/cupy/math_tests/test_misc.py::TestMisc::test_nan_to_num_for_old_numpy

0 commit comments

Comments
 (0)