Skip to content

Commit e4acd3e

Browse files
antonwolfyvtavana
andauthored
Add implementation of dpnp.fix function (#1971)
* Implement dpnp.fix() * Add tests to cover function * Update dpnp/dpnp_iface_mathematical.py Co-authored-by: vtavana <[email protected]> --------- Co-authored-by: vtavana <[email protected]>
1 parent 689aeeb commit e4acd3e

File tree

11 files changed

+349
-4
lines changed

11 files changed

+349
-4
lines changed

dpnp/backend/extensions/ufunc/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ set(_elementwise_sources
2727
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/common.cpp
2828
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/degrees.cpp
2929
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/fabs.cpp
30+
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/fix.cpp
3031
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/float_power.cpp
3132
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/fmax.cpp
3233
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/fmin.cpp

dpnp/backend/extensions/ufunc/elementwise_functions/common.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727

2828
#include "degrees.hpp"
2929
#include "fabs.hpp"
30+
#include "fix.hpp"
3031
#include "float_power.hpp"
3132
#include "fmax.hpp"
3233
#include "fmin.hpp"
@@ -45,6 +46,7 @@ void init_elementwise_functions(py::module_ m)
4546
{
4647
init_degrees(m);
4748
init_fabs(m);
49+
init_fix(m);
4850
init_float_power(m);
4951
init_fmax(m);
5052
init_fmin(m);
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2024, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#include <sycl/sycl.hpp>
27+
28+
#include "dpctl4pybind11.hpp"
29+
30+
#include "fix.hpp"
31+
#include "kernels/elementwise_functions/fix.hpp"
32+
#include "populate.hpp"
33+
34+
// include a local copy of elementwise common header from dpctl tensor:
35+
// dpctl/tensor/libtensor/source/elementwise_functions/elementwise_functions.hpp
36+
// TODO: replace by including dpctl header once available
37+
#include "../../elementwise_functions/elementwise_functions.hpp"
38+
39+
// dpctl tensor headers
40+
#include "kernels/elementwise_functions/common.hpp"
41+
#include "utils/type_dispatch.hpp"
42+
43+
namespace dpnp::extensions::ufunc
44+
{
45+
namespace py = pybind11;
46+
namespace py_int = dpnp::extensions::py_internal;
47+
48+
namespace impl
49+
{
50+
namespace ew_cmn_ns = dpctl::tensor::kernels::elementwise_common;
51+
namespace td_ns = dpctl::tensor::type_dispatch;
52+
53+
/**
54+
* @brief A factory to define pairs of supported types for which
55+
* sycl::fix<T> function is available.
56+
*
57+
* @tparam T Type of input vector `a` and of result vector `y`.
58+
*/
59+
template <typename T>
60+
struct OutputType
61+
{
62+
using value_type =
63+
typename std::disjunction<td_ns::TypeMapResultEntry<T, sycl::half>,
64+
td_ns::TypeMapResultEntry<T, float>,
65+
td_ns::TypeMapResultEntry<T, double>,
66+
td_ns::DefaultResultEntry<void>>::result_type;
67+
};
68+
69+
using dpnp::kernels::fix::FixFunctor;
70+
71+
template <typename argT,
72+
typename resT = argT,
73+
unsigned int vec_sz = 4,
74+
unsigned int n_vecs = 2,
75+
bool enable_sg_loadstore = true>
76+
using ContigFunctor = ew_cmn_ns::UnaryContigFunctor<argT,
77+
resT,
78+
FixFunctor<argT, resT>,
79+
vec_sz,
80+
n_vecs,
81+
enable_sg_loadstore>;
82+
83+
template <typename argTy, typename resTy, typename IndexerT>
84+
using StridedFunctor = ew_cmn_ns::
85+
UnaryStridedFunctor<argTy, resTy, IndexerT, FixFunctor<argTy, resTy>>;
86+
87+
using ew_cmn_ns::unary_contig_impl_fn_ptr_t;
88+
using ew_cmn_ns::unary_strided_impl_fn_ptr_t;
89+
90+
static unary_contig_impl_fn_ptr_t fix_contig_dispatch_vector[td_ns::num_types];
91+
static int fix_output_typeid_vector[td_ns::num_types];
92+
static unary_strided_impl_fn_ptr_t
93+
fix_strided_dispatch_vector[td_ns::num_types];
94+
95+
MACRO_POPULATE_DISPATCH_VECTORS(fix);
96+
} // namespace impl
97+
98+
void init_fix(py::module_ m)
99+
{
100+
using arrayT = dpctl::tensor::usm_ndarray;
101+
using event_vecT = std::vector<sycl::event>;
102+
{
103+
impl::populate_fix_dispatch_vectors();
104+
using impl::fix_contig_dispatch_vector;
105+
using impl::fix_output_typeid_vector;
106+
using impl::fix_strided_dispatch_vector;
107+
108+
auto fix_pyapi = [&](const arrayT &src, const arrayT &dst,
109+
sycl::queue &exec_q,
110+
const event_vecT &depends = {}) {
111+
return py_int::py_unary_ufunc(
112+
src, dst, exec_q, depends, fix_output_typeid_vector,
113+
fix_contig_dispatch_vector, fix_strided_dispatch_vector);
114+
};
115+
m.def("_fix", fix_pyapi, "", py::arg("src"), py::arg("dst"),
116+
py::arg("sycl_queue"), py::arg("depends") = py::list());
117+
118+
auto fix_result_type_pyapi = [&](const py::dtype &dtype) {
119+
return py_int::py_unary_ufunc_result_type(dtype,
120+
fix_output_typeid_vector);
121+
};
122+
m.def("_fix_result_type", fix_result_type_pyapi);
123+
}
124+
}
125+
} // namespace dpnp::extensions::ufunc
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2024, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#pragma once
27+
28+
#include <pybind11/pybind11.h>
29+
30+
namespace py = pybind11;
31+
32+
namespace dpnp::extensions::ufunc
33+
{
34+
void init_fix(py::module_ m);
35+
} // namespace dpnp::extensions::ufunc
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2024, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#pragma once
27+
28+
#include <sycl/sycl.hpp>
29+
30+
namespace dpnp::kernels::fix
31+
{
32+
template <typename argT, typename resT>
33+
struct FixFunctor
34+
{
35+
// is function constant for given argT
36+
using is_constant = typename std::false_type;
37+
// constant value, if constant
38+
// constexpr resT constant_value = resT{};
39+
// is function defined for sycl::vec
40+
using supports_vec = typename std::false_type;
41+
// do both argT and resT support subgroup store/load operation
42+
using supports_sg_loadstore = typename std::true_type;
43+
44+
resT operator()(const argT &x) const
45+
{
46+
return (x >= 0.0) ? sycl::floor(x) : sycl::ceil(x);
47+
}
48+
};
49+
} // namespace dpnp::kernels::fix

dpnp/dpnp_iface_mathematical.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@
9696
"divide",
9797
"ediff1d",
9898
"fabs",
99+
"fix",
99100
"float_power",
100101
"floor",
101102
"floor_divide",
@@ -533,6 +534,7 @@ def around(x, /, decimals=0, out=None):
533534
:obj:`dpnp.round` : Equivalent function; see for details.
534535
:obj:`dpnp.ndarray.round` : Equivalent function.
535536
:obj:`dpnp.rint` : Round elements of the array to the nearest integer.
537+
:obj:`dpnp.fix` : Round to nearest integer towards zero, element-wise.
536538
:obj:`dpnp.ceil` : Compute the ceiling of the input, element-wise.
537539
:obj:`dpnp.floor` : Return the floor of the input, element-wise.
538540
:obj:`dpnp.trunc` : Return the truncated value of the input, element-wise.
@@ -578,6 +580,8 @@ def around(x, /, decimals=0, out=None):
578580
--------
579581
:obj:`dpnp.floor` : Return the floor of the input, element-wise.
580582
:obj:`dpnp.trunc` : Return the truncated value of the input, element-wise.
583+
:obj:`dpnp.rint` : Round elements of the array to the nearest integer.
584+
:obj:`dpnp.fix` : Round to nearest integer towards zero, element-wise.
581585
582586
Examples
583587
--------
@@ -1371,6 +1375,64 @@ def ediff1d(x1, to_end=None, to_begin=None):
13711375
)
13721376

13731377

1378+
_FIX_DOCSTRING = """
1379+
Round to nearest integer towards zero.
1380+
1381+
Round an array of floats element-wise to nearest integer towards zero.
1382+
The rounded values are returned as floats.
1383+
1384+
For full documentation refer to :obj:`numpy.fix`.
1385+
1386+
Parameters
1387+
----------
1388+
x : {dpnp.ndarray, usm_ndarray}
1389+
An array of floats to be rounded.
1390+
out : {None, dpnp.ndarray, usm_ndarray}, optional
1391+
Output array to populate.
1392+
Array must have the correct shape and the expected data type.
1393+
Default: ``None``.
1394+
order : {"C", "F", "A", "K"}, optional
1395+
Memory layout of the newly output array, if parameter `out` is ``None``.
1396+
Default: ``"K"``.
1397+
1398+
Returns
1399+
-------
1400+
out : dpnp.ndarray
1401+
An array with the rounded values and with the same dimensions as the input.
1402+
The returned array will have the default floating point data type for the
1403+
device where `a` is allocated.
1404+
If `out` is ``None`` then a float array is returned with the rounded values.
1405+
Otherwise the result is stored there and the return value `out` is
1406+
a reference to that array.
1407+
1408+
See Also
1409+
--------
1410+
:obj:`dpnp.round` : Round to given number of decimals.
1411+
:obj:`dpnp.rint` : Round elements of the array to the nearest integer.
1412+
:obj:`dpnp.trunc` : Return the truncated value of the input, element-wise.
1413+
:obj:`dpnp.floor` : Return the floor of the input, element-wise.
1414+
:obj:`dpnp.ceil` : Return the ceiling of the input, element-wise.
1415+
1416+
Examples
1417+
--------
1418+
>>> import dpnp as np
1419+
>>> np.fix(np.array(3.14))
1420+
array(3.)
1421+
>>> np.fix(np.array(3))
1422+
array(3.)
1423+
>>> a = np.array([2.1, 2.9, -2.1, -2.9])
1424+
>>> np.fix(a)
1425+
array([ 2., 2., -2., -2.])
1426+
"""
1427+
1428+
fix = DPNPUnaryFunc(
1429+
"fix",
1430+
ufi._fix_result_type,
1431+
ufi._fix,
1432+
_FIX_DOCSTRING,
1433+
)
1434+
1435+
13741436
_FLOAT_POWER_DOCSTRING = """
13751437
Calculates `x1_i` raised to `x2_i` for each element `x1_i` of the input array
13761438
`x1` with the respective element `x2_i` of the input array `x2`.
@@ -1504,6 +1566,8 @@ def ediff1d(x1, to_end=None, to_begin=None):
15041566
--------
15051567
:obj:`dpnp.ceil` : Compute the ceiling of the input, element-wise.
15061568
:obj:`dpnp.trunc` : Return the truncated value of the input, element-wise.
1569+
:obj:`dpnp.rint` : Round elements of the array to the nearest integer.
1570+
:obj:`dpnp.fix` : Round to nearest integer towards zero, element-wise.
15071571
15081572
Notes
15091573
-----
@@ -3048,6 +3112,7 @@ def prod(
30483112
See Also
30493113
--------
30503114
:obj:`dpnp.round` : Evenly round to the given number of decimals.
3115+
:obj:`dpnp.fix` : Round to nearest integer towards zero, element-wise.
30513116
:obj:`dpnp.ceil` : Compute the ceiling of the input, element-wise.
30523117
:obj:`dpnp.floor` : Return the floor of the input, element-wise.
30533118
:obj:`dpnp.trunc` : Return the truncated value of the input, element-wise.
@@ -3103,6 +3168,7 @@ def prod(
31033168
:obj:`dpnp.ndarray.round` : Equivalent function.
31043169
:obj:`dpnp.rint` : Round elements of the array to the nearest integer.
31053170
:obj:`dpnp.ceil` : Compute the ceiling of the input, element-wise.
3171+
:obj:`dpnp.fix` : Round to nearest integer towards zero, element-wise.
31063172
:obj:`dpnp.floor` : Return the floor of the input, element-wise.
31073173
:obj:`dpnp.trunc` : Return the truncated value of the input, element-wise.
31083174
@@ -3536,6 +3602,8 @@ def trapz(y1, x1=None, dx=1.0, axis=-1):
35363602
--------
35373603
:obj:`dpnp.floor` : Round a number to the nearest integer toward minus infinity.
35383604
:obj:`dpnp.ceil` : Round a number to the nearest integer toward infinity.
3605+
:obj:`dpnp.rint` : Round elements of the array to the nearest integer.
3606+
:obj:`dpnp.fix` : Round to nearest integer towards zero, element-wise.
35393607
35403608
Examples
35413609
--------

tests/skipped_tests.tbl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -228,8 +228,6 @@ tests/third_party/cupy/math_tests/test_misc.py::TestMisc::test_interp_inf_to_nan
228228
tests/third_party/cupy/math_tests/test_misc.py::TestMisc::test_heaviside
229229
tests/third_party/cupy/math_tests/test_misc.py::TestMisc::test_heaviside_nan_inf
230230

231-
tests/third_party/cupy/math_tests/test_rounding.py::TestRounding::test_fix
232-
233231
tests/third_party/cupy/random_tests/test_distributions.py::TestDistributionsBeta_param_0_{a_shape=(), b_shape=(), shape=(4, 3, 2)}::test_beta
234232
tests/third_party/cupy/random_tests/test_distributions.py::TestDistributionsBeta_param_1_{a_shape=(), b_shape=(), shape=(3, 2)}::test_beta
235233
tests/third_party/cupy/random_tests/test_distributions.py::TestDistributionsBeta_param_2_{a_shape=(), b_shape=(3, 2), shape=(4, 3, 2)}::test_beta

tests/skipped_tests_gpu.tbl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -282,8 +282,6 @@ tests/third_party/cupy/math_tests/test_misc.py::TestMisc::test_interp_inf_to_nan
282282
tests/third_party/cupy/math_tests/test_misc.py::TestMisc::test_heaviside
283283
tests/third_party/cupy/math_tests/test_misc.py::TestMisc::test_heaviside_nan_inf
284284

285-
tests/third_party/cupy/math_tests/test_rounding.py::TestRounding::test_fix
286-
287285
tests/third_party/cupy/random_tests/test_distributions.py::TestDistributionsBeta_param_0_{a_shape=(), b_shape=(), shape=(4, 3, 2)}::test_beta
288286
tests/third_party/cupy/random_tests/test_distributions.py::TestDistributionsBeta_param_1_{a_shape=(), b_shape=(), shape=(3, 2)}::test_beta
289287
tests/third_party/cupy/random_tests/test_distributions.py::TestDistributionsBeta_param_2_{a_shape=(), b_shape=(3, 2), shape=(4, 3, 2)}::test_beta

0 commit comments

Comments
 (0)