Skip to content

Commit 0c9fbeb

Browse files
committed
address reviewer's comments
1 parent b989d36 commit 0c9fbeb

File tree

3 files changed

+35
-25
lines changed

3 files changed

+35
-25
lines changed

dpctl/tensor/_elementwise_funcs.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1196,7 +1196,7 @@
11961196
Default: "K".
11971197
Returns:
11981198
usm_ndarray:
1199-
An array containing the element-wise products. The data type of
1199+
An array containing the element-wise result. The data type of
12001200
the returned array is determined by the Type Promotion Rules.
12011201
"""
12021202
maximum = BinaryElementwiseFunc(
@@ -1429,6 +1429,12 @@
14291429
First input array, expected to have a real-valued data type.
14301430
x2 (usm_ndarray):
14311431
Second input array, also expected to have a real-valued data type.
1432+
out ({None, usm_ndarray}, optional):
1433+
Output array to populate.
1434+
Array must have the correct shape and the expected data type.
1435+
order ("C","F","A","K", optional):
1436+
Memory layout of the newly allocated output array, if parameter `out` is `None`.
1437+
Default: "K".
14321438
Returns:
14331439
usm_ndarray:
14341440
an array containing the element-wise remainders. The data type of

dpctl/tensor/libtensor/include/kernels/elementwise_functions/maximum.hpp

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -71,18 +71,16 @@ template <typename argT1, typename argT2, typename resT> struct MaximumFunctor
7171
realT imag1 = std::imag(in1);
7272
realT imag2 = std::imag(in2);
7373

74-
if (std::isnan(real1) || std::isnan(imag1))
75-
return in1;
76-
else if (std::isnan(real2) || std::isnan(imag2))
77-
return in2;
78-
else if (real1 == real2)
79-
return imag1 > imag2 ? in1 : in2;
80-
else
81-
return real1 > real2 ? in1 : in2;
82-
}
83-
else {
84-
return (in1 != in1 || in1 > in2) ? in1 : in2;
74+
bool gt = (real1 == real2) ? (imag1 > imag2)
75+
: (real1 > real2 && !std::isnan(imag1) &&
76+
!std::isnan(imag2));
77+
return (std::isnan(real1) || std::isnan(imag1) || gt) ? in1 : in2;
8578
}
79+
else if constexpr (std::is_floating_point_v<argT1> ||
80+
std::is_same_v<argT1, sycl::half>)
81+
return (std::isnan(in1) || in1 > in2) ? in1 : in2;
82+
else
83+
return (in1 > in2) ? in1 : in2;
8684
}
8785

8886
template <int vec_sz>
@@ -92,7 +90,11 @@ template <typename argT1, typename argT2, typename resT> struct MaximumFunctor
9290
sycl::vec<resT, vec_sz> res;
9391
#pragma unroll
9492
for (int i = 0; i < vec_sz; ++i) {
95-
res[i] = (in1[i] != in1[i] || in1[i] > in2[i]) ? in1[i] : in2[i];
93+
if constexpr (std::is_floating_point_v<argT1>)
94+
res[i] =
95+
(sycl::isnan(in1[i]) || in1[i] > in2[i]) ? in1[i] : in2[i];
96+
else
97+
res[i] = (in1[i] > in2[i]) ? in1[i] : in2[i];
9698
}
9799
return res;
98100
}

dpctl/tensor/libtensor/include/kernels/elementwise_functions/minimum.hpp

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -71,18 +71,16 @@ template <typename argT1, typename argT2, typename resT> struct MinimumFunctor
7171
realT imag1 = std::imag(in1);
7272
realT imag2 = std::imag(in2);
7373

74-
if (std::isnan(real1) || std::isnan(imag1))
75-
return in1;
76-
else if (std::isnan(real2) || std::isnan(imag2))
77-
return in2;
78-
else if (real1 == real2)
79-
return imag1 < imag2 ? in1 : in2;
80-
else
81-
return real1 < real2 ? in1 : in2;
82-
}
83-
else {
84-
return (in1 != in1 || in1 < in2) ? in1 : in2;
74+
bool lt = (real1 == real2) ? (imag1 < imag2)
75+
: (real1 < real2 && !std::isnan(imag1) &&
76+
!std::isnan(imag2));
77+
return (std::isnan(real1) || std::isnan(imag1) || lt) ? in1 : in2;
8578
}
79+
else if constexpr (std::is_floating_point_v<argT1> ||
80+
std::is_same_v<argT1, sycl::half>)
81+
return (std::isnan(in1) || in1 < in2) ? in1 : in2;
82+
else
83+
return (in1 < in2) ? in1 : in2;
8684
}
8785

8886
template <int vec_sz>
@@ -92,7 +90,11 @@ template <typename argT1, typename argT2, typename resT> struct MinimumFunctor
9290
sycl::vec<resT, vec_sz> res;
9391
#pragma unroll
9492
for (int i = 0; i < vec_sz; ++i) {
95-
res[i] = (in1[i] != in1[i] || in1[i] < in2[i]) ? in1[i] : in2[i];
93+
if constexpr (std::is_floating_point_v<argT1>)
94+
res[i] =
95+
(sycl::isnan(in1[i]) || in1[i] < in2[i]) ? in1[i] : in2[i];
96+
else
97+
res[i] = (in1[i] < in2[i]) ? in1[i] : in2[i];
9698
}
9799
return res;
98100
}

0 commit comments

Comments
 (0)