Skip to content

Commit fb35be7

Browse files
committed
address reviewer's comments
1 parent a53090a commit fb35be7

File tree

2 files changed

+12
-34
lines changed

2 files changed

+12
-34
lines changed

dpctl/tensor/_elementwise_funcs.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -959,7 +959,8 @@
959959
_round_docstring = """
960960
round(x, out=None, order='K')
961961
962-
Computes cosine for each element `x_i` for input array `x`.
962+
Rounds each element `x_i` of the input array `x` to
963+
the nearest integer-valued number.
963964
964965
Args:
965966
x (usm_ndarray):

dpctl/tensor/libtensor/include/kernels/elementwise_functions/round.hpp

Lines changed: 10 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -65,48 +65,25 @@ template <typename argT, typename resT> struct RoundFunctor
6565

6666
resT operator()(const argT &in)
6767
{
68+
6869
if constexpr (std::is_integral_v<argT>) {
6970
return in;
7071
}
7172
else if constexpr (is_complex<argT>::value) {
7273
using realT = typename argT::value_type;
73-
74-
const realT x = std::real(in);
75-
const realT y = std::imag(in);
76-
realT x_round, y_round;
77-
if (std::abs(x - std::floor(x)) == std::abs(x - std::ceil(x))) {
78-
x_round = static_cast<int>(std::ceil(x)) % 2 == 0
79-
? std::ceil(x)
80-
: std::floor(x);
81-
}
82-
else {
83-
x_round = std::round(x);
84-
}
85-
if (std::abs(y - std::floor(y)) == std::abs(y - std::ceil(y))) {
86-
y_round = static_cast<int>(std::ceil(y)) % 2 == 0
87-
? std::ceil(y)
88-
: std::floor(y);
89-
}
90-
else {
91-
y_round = std::round(y);
92-
}
93-
return resT{x_round, y_round};
74+
return resT{round_func<realT>(std::real(in)),
75+
round_func<realT>(std::imag(in))};
9476
}
9577
else {
96-
if (in == 0) {
97-
return in;
98-
}
99-
else if (std::abs(in - std::floor(in)) ==
100-
std::abs(in - std::ceil(in))) {
101-
return static_cast<int>(std::ceil(in)) % 2 == 0
102-
? std::ceil(in)
103-
: std::floor(in);
104-
}
105-
else {
106-
return std::round(in);
107-
}
78+
return round_func<argT>(in);
10879
}
10980
}
81+
82+
private:
83+
template <typename T> T round_func(const T &input) const
84+
{
85+
return std::rint(input);
86+
}
11087
};
11188

11289
template <typename argTy,

0 commit comments

Comments
 (0)