Skip to content

Commit a53090a

Browse files
committed
impl_elementwise_func_round
1 parent 73a2b68 commit a53090a

File tree

5 files changed

+528
-3
lines changed

5 files changed

+528
-3
lines changed

dpctl/tensor/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@
126126
pow,
127127
proj,
128128
real,
129+
round,
129130
sin,
130131
sqrt,
131132
square,
@@ -243,6 +244,7 @@
243244
"pow",
244245
"proj",
245246
"real",
247+
"round",
246248
"sin",
247249
"sqrt",
248250
"square",

dpctl/tensor/_elementwise_funcs.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -956,7 +956,29 @@
956956
# FIXME: implement B22
957957

958958
# U28: ==== ROUND (x)
959-
# FIXME: implement U28
959+
# Docstring attached to the `round` elementwise function object created below.
_round_docstring = """
round(x, out=None, order='K')

Rounds each element `x_i` of the input array `x` to
the nearest integer-valued number.

When two integers are equally close to `x_i`, the result is the nearest even
integer to `x_i`.

Args:
    x (usm_ndarray):
        Input array, expected to have numeric data type.
    out ({None, usm_ndarray}, optional):
        Output array to populate.
        Array must have the correct shape and the expected data type.
    order ("C","F","A","K", optional):
        Memory layout of the new output array, if parameter `out` is `None`.
        Default: "K".
Returns:
    usm_ndarray:
        An array containing the element-wise rounded value. The data type
        of the returned array is determined by the Type Promotion Rules.
"""

# Result-type query and kernel launcher are the `_round_result_type` and
# `_round` entry points of the `ti` extension module.
round = UnaryElementwiseFunc(
    "round", ti._round_result_type, ti._round, _round_docstring
)
960982

961983
# U29: ==== SIGN (x)
962984
# FIXME: implement U29
Lines changed: 228 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,228 @@
1+
//=== round.hpp - Unary function ROUND ------ *-C++-*--/===//
2+
//
3+
// Data Parallel Control (dpctl)
4+
//
5+
// Copyright 2020-2023 Intel Corporation
6+
//
7+
// Licensed under the Apache License, Version 2.0 (the "License");
8+
// you may not use this file except in compliance with the License.
9+
// You may obtain a copy of the License at
10+
//
11+
// http://www.apache.org/licenses/LICENSE-2.0
12+
//
13+
// Unless required by applicable law or agreed to in writing, software
14+
// distributed under the License is distributed on an "AS IS" BASIS,
15+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16+
// See the License for the specific language governing permissions and
17+
// limitations under the License.
18+
//
19+
//===---------------------------------------------------------------------===//
20+
///
21+
/// \file
22+
/// This file defines kernels for elementwise evaluation of ROUND(x) function.
23+
//===---------------------------------------------------------------------===//
24+
25+
#pragma once
26+
#include <CL/sycl.hpp>
27+
#include <cmath>
28+
#include <cstddef>
29+
#include <cstdint>
30+
#include <type_traits>
31+
32+
#include "kernels/elementwise_functions/common.hpp"
33+
34+
#include "utils/offset_utils.hpp"
35+
#include "utils/type_dispatch.hpp"
36+
#include "utils/type_utils.hpp"
37+
#include <pybind11/pybind11.h>
38+
39+
namespace dpctl
40+
{
41+
namespace tensor
42+
{
43+
namespace kernels
44+
{
45+
namespace round
46+
{
47+
48+
namespace py = pybind11;
49+
namespace td_ns = dpctl::tensor::type_dispatch;
50+
51+
using dpctl::tensor::type_utils::is_complex;
52+
53+
template <typename argT, typename resT> struct RoundFunctor
54+
{
55+
56+
// is function constant for given argT
57+
using is_constant = typename std::false_type;
58+
// constant value, if constant
59+
// constexpr resT constant_value = resT{};
60+
// is function defined for sycl::vec
61+
using supports_vec = typename std::false_type;
62+
// do both argTy and resTy support sugroup store/load operation
63+
using supports_sg_loadstore = typename std::negation<
64+
std::disjunction<is_complex<resT>, is_complex<argT>>>;
65+
66+
resT operator()(const argT &in)
67+
{
68+
if constexpr (std::is_integral_v<argT>) {
69+
return in;
70+
}
71+
else if constexpr (is_complex<argT>::value) {
72+
using realT = typename argT::value_type;
73+
74+
const realT x = std::real(in);
75+
const realT y = std::imag(in);
76+
realT x_round, y_round;
77+
if (std::abs(x - std::floor(x)) == std::abs(x - std::ceil(x))) {
78+
x_round = static_cast<int>(std::ceil(x)) % 2 == 0
79+
? std::ceil(x)
80+
: std::floor(x);
81+
}
82+
else {
83+
x_round = std::round(x);
84+
}
85+
if (std::abs(y - std::floor(y)) == std::abs(y - std::ceil(y))) {
86+
y_round = static_cast<int>(std::ceil(y)) % 2 == 0
87+
? std::ceil(y)
88+
: std::floor(y);
89+
}
90+
else {
91+
y_round = std::round(y);
92+
}
93+
return resT{x_round, y_round};
94+
}
95+
else {
96+
if (in == 0) {
97+
return in;
98+
}
99+
else if (std::abs(in - std::floor(in)) ==
100+
std::abs(in - std::ceil(in))) {
101+
return static_cast<int>(std::ceil(in)) % 2 == 0
102+
? std::ceil(in)
103+
: std::floor(in);
104+
}
105+
else {
106+
return std::round(in);
107+
}
108+
}
109+
}
110+
};
111+
112+
template <typename argTy,
113+
typename resTy = argTy,
114+
unsigned int vec_sz = 4,
115+
unsigned int n_vecs = 2>
116+
using RoundContigFunctor =
117+
elementwise_common::UnaryContigFunctor<argTy,
118+
resTy,
119+
RoundFunctor<argTy, resTy>,
120+
vec_sz,
121+
n_vecs>;
122+
123+
template <typename argTy, typename resTy, typename IndexerT>
124+
using RoundStridedFunctor = elementwise_common::
125+
UnaryStridedFunctor<argTy, resTy, IndexerT, RoundFunctor<argTy, resTy>>;
126+
127+
template <typename T> struct RoundOutputType
128+
{
129+
using value_type = typename std::disjunction< // disjunction is C++17
130+
// feature, supported by DPC++
131+
td_ns::TypeMapResultEntry<T, std::uint8_t>,
132+
td_ns::TypeMapResultEntry<T, std::uint16_t>,
133+
td_ns::TypeMapResultEntry<T, std::uint32_t>,
134+
td_ns::TypeMapResultEntry<T, std::uint64_t>,
135+
td_ns::TypeMapResultEntry<T, std::int8_t>,
136+
td_ns::TypeMapResultEntry<T, std::int16_t>,
137+
td_ns::TypeMapResultEntry<T, std::int32_t>,
138+
td_ns::TypeMapResultEntry<T, std::int64_t>,
139+
td_ns::TypeMapResultEntry<T, sycl::half>,
140+
td_ns::TypeMapResultEntry<T, float>,
141+
td_ns::TypeMapResultEntry<T, double>,
142+
td_ns::TypeMapResultEntry<T, std::complex<float>>,
143+
td_ns::TypeMapResultEntry<T, std::complex<double>>,
144+
td_ns::DefaultResultEntry<void>>::result_type;
145+
};
146+
147+
template <typename T1, typename T2, unsigned int vec_sz, unsigned int n_vecs>
148+
class round_contig_kernel;
149+
150+
template <typename argTy>
151+
sycl::event round_contig_impl(sycl::queue exec_q,
152+
size_t nelems,
153+
const char *arg_p,
154+
char *res_p,
155+
const std::vector<sycl::event> &depends = {})
156+
{
157+
return elementwise_common::unary_contig_impl<
158+
argTy, RoundOutputType, RoundContigFunctor, round_contig_kernel>(
159+
exec_q, nelems, arg_p, res_p, depends);
160+
}
161+
162+
template <typename fnT, typename T> struct RoundContigFactory
163+
{
164+
fnT get()
165+
{
166+
if constexpr (std::is_same_v<typename RoundOutputType<T>::value_type,
167+
void>) {
168+
fnT fn = nullptr;
169+
return fn;
170+
}
171+
else {
172+
fnT fn = round_contig_impl<T>;
173+
return fn;
174+
}
175+
}
176+
};
177+
178+
template <typename fnT, typename T> struct RoundTypeMapFactory
179+
{
180+
/*! @brief get typeid for output type of sycl::round(T x) */
181+
std::enable_if_t<std::is_same<fnT, int>::value, int> get()
182+
{
183+
using rT = typename RoundOutputType<T>::value_type;
184+
return td_ns::GetTypeid<rT>{}.get();
185+
}
186+
};
187+
188+
template <typename T1, typename T2, typename T3> class round_strided_kernel;
189+
190+
template <typename argTy>
191+
sycl::event
192+
round_strided_impl(sycl::queue exec_q,
193+
size_t nelems,
194+
int nd,
195+
const py::ssize_t *shape_and_strides,
196+
const char *arg_p,
197+
py::ssize_t arg_offset,
198+
char *res_p,
199+
py::ssize_t res_offset,
200+
const std::vector<sycl::event> &depends,
201+
const std::vector<sycl::event> &additional_depends)
202+
{
203+
return elementwise_common::unary_strided_impl<
204+
argTy, RoundOutputType, RoundStridedFunctor, round_strided_kernel>(
205+
exec_q, nelems, nd, shape_and_strides, arg_p, arg_offset, res_p,
206+
res_offset, depends, additional_depends);
207+
}
208+
209+
template <typename fnT, typename T> struct RoundStridedFactory
210+
{
211+
fnT get()
212+
{
213+
if constexpr (std::is_same_v<typename RoundOutputType<T>::value_type,
214+
void>) {
215+
fnT fn = nullptr;
216+
return fn;
217+
}
218+
else {
219+
fnT fn = round_strided_impl<T>;
220+
return fn;
221+
}
222+
}
223+
};
224+
225+
} // namespace round
226+
} // namespace kernels
227+
} // namespace tensor
228+
} // namespace dpctl

dpctl/tensor/libtensor/source/elementwise_functions.cpp

Lines changed: 53 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@
6565
#include "kernels/elementwise_functions/pow.hpp"
6666
#include "kernels/elementwise_functions/proj.hpp"
6767
#include "kernels/elementwise_functions/real.hpp"
68+
#include "kernels/elementwise_functions/round.hpp"
6869
#include "kernels/elementwise_functions/sin.hpp"
6970
#include "kernels/elementwise_functions/sqrt.hpp"
7071
#include "kernels/elementwise_functions/square.hpp"
@@ -1627,7 +1628,37 @@ namespace impl
16271628
// U28: ==== ROUND (x)
16281629
namespace impl
16291630
{
1630-
// FIXME: add code for U28
1631+
1632+
namespace round_fn_ns = dpctl::tensor::kernels::round;
1633+
1634+
static unary_contig_impl_fn_ptr_t
1635+
round_contig_dispatch_vector[td_ns::num_types];
1636+
static int round_output_typeid_vector[td_ns::num_types];
1637+
static unary_strided_impl_fn_ptr_t
1638+
round_strided_dispatch_vector[td_ns::num_types];
1639+
1640+
void populate_round_dispatch_vectors(void)
1641+
{
1642+
using namespace td_ns;
1643+
namespace fn_ns = round_fn_ns;
1644+
1645+
using fn_ns::RoundContigFactory;
1646+
DispatchVectorBuilder<unary_contig_impl_fn_ptr_t, RoundContigFactory,
1647+
num_types>
1648+
dvb1;
1649+
dvb1.populate_dispatch_vector(round_contig_dispatch_vector);
1650+
1651+
using fn_ns::RoundStridedFactory;
1652+
DispatchVectorBuilder<unary_strided_impl_fn_ptr_t, RoundStridedFactory,
1653+
num_types>
1654+
dvb2;
1655+
dvb2.populate_dispatch_vector(round_strided_dispatch_vector);
1656+
1657+
using fn_ns::RoundTypeMapFactory;
1658+
DispatchVectorBuilder<int, RoundTypeMapFactory, num_types> dvb3;
1659+
dvb3.populate_dispatch_vector(round_output_typeid_vector);
1660+
}
1661+
16311662
} // namespace impl
16321663

16331664
// U29: ==== SIGN (x)
@@ -3029,7 +3060,27 @@ void init_elementwise_functions(py::module_ m)
30293060
// FIXME:
30303061

30313062
// U28: ==== ROUND (x)
3032-
// FIXME:
3063+
{
    // The tables must be filled before the bound functions read them.
    impl::populate_round_dispatch_vectors();

    using impl::round_contig_dispatch_vector;
    using impl::round_output_typeid_vector;
    using impl::round_strided_dispatch_vector;

    // Python entry point `_round`: hands the arrays, queue and dependencies
    // to py_unary_ufunc together with the three ROUND tables.
    auto py_round = [&](arrayT src, arrayT dst, sycl::queue exec_q,
                        const event_vecT &depends = {}) {
        return py_unary_ufunc(src, dst, exec_q, depends,
                              round_output_typeid_vector,
                              round_contig_dispatch_vector,
                              round_strided_dispatch_vector);
    };
    m.def("_round", py_round, "", py::arg("src"), py::arg("dst"),
          py::arg("sycl_queue"), py::arg("depends") = py::list());

    // Python entry point `_round_result_type`: looks the dtype up in the
    // result type id table via py_unary_ufunc_result_type.
    auto py_round_result_type = [&](py::dtype dtype) {
        return py_unary_ufunc_result_type(dtype, round_output_typeid_vector);
    };
    m.def("_round_result_type", py_round_result_type);
}
30333084

30343085
// U29: ==== SIGN (x)
30353086
// FIXME:

0 commit comments

Comments
 (0)