Skip to content

Commit aa7de12

Browse files
authored
Merge pull request #1299 from IntelPython/impl_elementwise_func_round
impl_elementwise_func_round
2 parents 7e798b4 + fb35be7 commit aa7de12

File tree

5 files changed

+506
-3
lines changed

5 files changed

+506
-3
lines changed

dpctl/tensor/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@
135135
pow,
136136
proj,
137137
real,
138+
round,
138139
sin,
139140
sinh,
140141
sqrt,
@@ -264,6 +265,7 @@
264265
"logaddexp",
265266
"proj",
266267
"real",
268+
"round",
267269
"sin",
268270
"sinh",
269271
"sqrt",

dpctl/tensor/_elementwise_funcs.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1152,7 +1152,30 @@
11521152
# FIXME: implement B22
11531153

11541154
# U28: ==== ROUND (x)
1155-
# FIXME: implement U28
1155+
_round_docstring = """
1156+
round(x, out=None, order='K')
1157+
1158+
Rounds each element `x_i` of the input array `x` to
1159+
the nearest integer-valued number.
1160+
1161+
Args:
1162+
x (usm_ndarray):
1163+
Input array, expected to have numeric data type.
1164+
out ({None, usm_ndarray}, optional):
1165+
Output array to populate.
1166+
Array must have the correct shape and the expected data type.
1167+
order ("C","F","A","K", optional):
1168+
Memory layout of the newly allocated output array, if parameter `out` is `None`.
1169+
Default: "K".
1170+
Returns:
1171+
usm_ndarray:
1172+
An array containing the element-wise rounded value. The data type
1173+
of the returned array is determined by the Type Promotion Rules.
1174+
"""
1175+
1176+
round = UnaryElementwiseFunc(
1177+
"round", ti._round_result_type, ti._round, _round_docstring
1178+
)
11561179

11571180
# U29: ==== SIGN (x)
11581181
# FIXME: implement U29
Lines changed: 205 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,205 @@
1+
//=== round.hpp - Unary function ROUND ------ *-C++-*--/===//
2+
//
3+
// Data Parallel Control (dpctl)
4+
//
5+
// Copyright 2020-2023 Intel Corporation
6+
//
7+
// Licensed under the Apache License, Version 2.0 (the "License");
8+
// you may not use this file except in compliance with the License.
9+
// You may obtain a copy of the License at
10+
//
11+
// http://www.apache.org/licenses/LICENSE-2.0
12+
//
13+
// Unless required by applicable law or agreed to in writing, software
14+
// distributed under the License is distributed on an "AS IS" BASIS,
15+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16+
// See the License for the specific language governing permissions and
17+
// limitations under the License.
18+
//
19+
//===---------------------------------------------------------------------===//
20+
///
21+
/// \file
22+
/// This file defines kernels for elementwise evaluation of ROUND(x) function.
23+
//===---------------------------------------------------------------------===//
24+
25+
#pragma once
26+
#include <CL/sycl.hpp>
27+
#include <cmath>
28+
#include <cstddef>
29+
#include <cstdint>
30+
#include <type_traits>
31+
32+
#include "kernels/elementwise_functions/common.hpp"
33+
34+
#include "utils/offset_utils.hpp"
35+
#include "utils/type_dispatch.hpp"
36+
#include "utils/type_utils.hpp"
37+
#include <pybind11/pybind11.h>
38+
39+
namespace dpctl
40+
{
41+
namespace tensor
42+
{
43+
namespace kernels
44+
{
45+
namespace round
46+
{
47+
48+
namespace py = pybind11;
49+
namespace td_ns = dpctl::tensor::type_dispatch;
50+
51+
using dpctl::tensor::type_utils::is_complex;
52+
53+
template <typename argT, typename resT> struct RoundFunctor
54+
{
55+
56+
// is function constant for given argT
57+
using is_constant = typename std::false_type;
58+
// constant value, if constant
59+
// constexpr resT constant_value = resT{};
60+
// is function defined for sycl::vec
61+
using supports_vec = typename std::false_type;
62+
// do both argT and resT support subgroup store/load operation
63+
using supports_sg_loadstore = typename std::negation<
64+
std::disjunction<is_complex<resT>, is_complex<argT>>>;
65+
66+
resT operator()(const argT &in)
67+
{
68+
69+
if constexpr (std::is_integral_v<argT>) {
70+
return in;
71+
}
72+
else if constexpr (is_complex<argT>::value) {
73+
using realT = typename argT::value_type;
74+
return resT{round_func<realT>(std::real(in)),
75+
round_func<realT>(std::imag(in))};
76+
}
77+
else {
78+
return round_func<argT>(in);
79+
}
80+
}
81+
82+
private:
83+
template <typename T> T round_func(const T &input) const
84+
{
85+
return std::rint(input);
86+
}
87+
};
88+
89+
template <typename argTy,
90+
typename resTy = argTy,
91+
unsigned int vec_sz = 4,
92+
unsigned int n_vecs = 2>
93+
using RoundContigFunctor =
94+
elementwise_common::UnaryContigFunctor<argTy,
95+
resTy,
96+
RoundFunctor<argTy, resTy>,
97+
vec_sz,
98+
n_vecs>;
99+
100+
template <typename argTy, typename resTy, typename IndexerT>
101+
using RoundStridedFunctor = elementwise_common::
102+
UnaryStridedFunctor<argTy, resTy, IndexerT, RoundFunctor<argTy, resTy>>;
103+
104+
template <typename T> struct RoundOutputType
105+
{
106+
using value_type = typename std::disjunction< // disjunction is C++17
107+
// feature, supported by DPC++
108+
td_ns::TypeMapResultEntry<T, std::uint8_t>,
109+
td_ns::TypeMapResultEntry<T, std::uint16_t>,
110+
td_ns::TypeMapResultEntry<T, std::uint32_t>,
111+
td_ns::TypeMapResultEntry<T, std::uint64_t>,
112+
td_ns::TypeMapResultEntry<T, std::int8_t>,
113+
td_ns::TypeMapResultEntry<T, std::int16_t>,
114+
td_ns::TypeMapResultEntry<T, std::int32_t>,
115+
td_ns::TypeMapResultEntry<T, std::int64_t>,
116+
td_ns::TypeMapResultEntry<T, sycl::half>,
117+
td_ns::TypeMapResultEntry<T, float>,
118+
td_ns::TypeMapResultEntry<T, double>,
119+
td_ns::TypeMapResultEntry<T, std::complex<float>>,
120+
td_ns::TypeMapResultEntry<T, std::complex<double>>,
121+
td_ns::DefaultResultEntry<void>>::result_type;
122+
};
123+
124+
template <typename T1, typename T2, unsigned int vec_sz, unsigned int n_vecs>
125+
class round_contig_kernel;
126+
127+
template <typename argTy>
128+
sycl::event round_contig_impl(sycl::queue exec_q,
129+
size_t nelems,
130+
const char *arg_p,
131+
char *res_p,
132+
const std::vector<sycl::event> &depends = {})
133+
{
134+
return elementwise_common::unary_contig_impl<
135+
argTy, RoundOutputType, RoundContigFunctor, round_contig_kernel>(
136+
exec_q, nelems, arg_p, res_p, depends);
137+
}
138+
139+
template <typename fnT, typename T> struct RoundContigFactory
140+
{
141+
fnT get()
142+
{
143+
if constexpr (std::is_same_v<typename RoundOutputType<T>::value_type,
144+
void>) {
145+
fnT fn = nullptr;
146+
return fn;
147+
}
148+
else {
149+
fnT fn = round_contig_impl<T>;
150+
return fn;
151+
}
152+
}
153+
};
154+
155+
template <typename fnT, typename T> struct RoundTypeMapFactory
156+
{
157+
/*! @brief get typeid for the output type of ROUND applied to argument type T */
158+
std::enable_if_t<std::is_same<fnT, int>::value, int> get()
159+
{
160+
using rT = typename RoundOutputType<T>::value_type;
161+
return td_ns::GetTypeid<rT>{}.get();
162+
}
163+
};
164+
165+
template <typename T1, typename T2, typename T3> class round_strided_kernel;
166+
167+
template <typename argTy>
168+
sycl::event
169+
round_strided_impl(sycl::queue exec_q,
170+
size_t nelems,
171+
int nd,
172+
const py::ssize_t *shape_and_strides,
173+
const char *arg_p,
174+
py::ssize_t arg_offset,
175+
char *res_p,
176+
py::ssize_t res_offset,
177+
const std::vector<sycl::event> &depends,
178+
const std::vector<sycl::event> &additional_depends)
179+
{
180+
return elementwise_common::unary_strided_impl<
181+
argTy, RoundOutputType, RoundStridedFunctor, round_strided_kernel>(
182+
exec_q, nelems, nd, shape_and_strides, arg_p, arg_offset, res_p,
183+
res_offset, depends, additional_depends);
184+
}
185+
186+
template <typename fnT, typename T> struct RoundStridedFactory
187+
{
188+
fnT get()
189+
{
190+
if constexpr (std::is_same_v<typename RoundOutputType<T>::value_type,
191+
void>) {
192+
fnT fn = nullptr;
193+
return fn;
194+
}
195+
else {
196+
fnT fn = round_strided_impl<T>;
197+
return fn;
198+
}
199+
}
200+
};
201+
202+
} // namespace round
203+
} // namespace kernels
204+
} // namespace tensor
205+
} // namespace dpctl

dpctl/tensor/libtensor/source/elementwise_functions.cpp

Lines changed: 53 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@
7474
#include "kernels/elementwise_functions/pow.hpp"
7575
#include "kernels/elementwise_functions/proj.hpp"
7676
#include "kernels/elementwise_functions/real.hpp"
77+
#include "kernels/elementwise_functions/round.hpp"
7778
#include "kernels/elementwise_functions/sin.hpp"
7879
#include "kernels/elementwise_functions/sinh.hpp"
7980
#include "kernels/elementwise_functions/sqrt.hpp"
@@ -1877,7 +1878,37 @@ namespace impl
18771878
// U28: ==== ROUND (x)
18781879
namespace impl
18791880
{
1880-
// FIXME: add code for U28
1881+
1882+
namespace round_fn_ns = dpctl::tensor::kernels::round;
1883+
1884+
static unary_contig_impl_fn_ptr_t
1885+
round_contig_dispatch_vector[td_ns::num_types];
1886+
static int round_output_typeid_vector[td_ns::num_types];
1887+
static unary_strided_impl_fn_ptr_t
1888+
round_strided_dispatch_vector[td_ns::num_types];
1889+
1890+
void populate_round_dispatch_vectors(void)
1891+
{
1892+
using namespace td_ns;
1893+
namespace fn_ns = round_fn_ns;
1894+
1895+
using fn_ns::RoundContigFactory;
1896+
DispatchVectorBuilder<unary_contig_impl_fn_ptr_t, RoundContigFactory,
1897+
num_types>
1898+
dvb1;
1899+
dvb1.populate_dispatch_vector(round_contig_dispatch_vector);
1900+
1901+
using fn_ns::RoundStridedFactory;
1902+
DispatchVectorBuilder<unary_strided_impl_fn_ptr_t, RoundStridedFactory,
1903+
num_types>
1904+
dvb2;
1905+
dvb2.populate_dispatch_vector(round_strided_dispatch_vector);
1906+
1907+
using fn_ns::RoundTypeMapFactory;
1908+
DispatchVectorBuilder<int, RoundTypeMapFactory, num_types> dvb3;
1909+
dvb3.populate_dispatch_vector(round_output_typeid_vector);
1910+
}
1911+
18811912
} // namespace impl
18821913

18831914
// U29: ==== SIGN (x)
@@ -3580,7 +3611,27 @@ void init_elementwise_functions(py::module_ m)
35803611
// FIXME:
35813612

35823613
// U28: ==== ROUND (x)
3583-
// FIXME:
3614+
{
3615+
impl::populate_round_dispatch_vectors();
3616+
using impl::round_contig_dispatch_vector;
3617+
using impl::round_output_typeid_vector;
3618+
using impl::round_strided_dispatch_vector;
3619+
3620+
auto round_pyapi = [&](arrayT src, arrayT dst, sycl::queue exec_q,
3621+
const event_vecT &depends = {}) {
3622+
return py_unary_ufunc(
3623+
src, dst, exec_q, depends, round_output_typeid_vector,
3624+
round_contig_dispatch_vector, round_strided_dispatch_vector);
3625+
};
3626+
m.def("_round", round_pyapi, "", py::arg("src"), py::arg("dst"),
3627+
py::arg("sycl_queue"), py::arg("depends") = py::list());
3628+
3629+
auto round_result_type_pyapi = [&](py::dtype dtype) {
3630+
return py_unary_ufunc_result_type(dtype,
3631+
round_output_typeid_vector);
3632+
};
3633+
m.def("_round_result_type", round_result_type_pyapi);
3634+
}
35843635

35853636
// U29: ==== SIGN (x)
35863637
// FIXME:

0 commit comments

Comments
 (0)