Skip to content

Commit 88c59bc

Browse files
committed
Refactors maximum and minimum to use math_utils implementations
1 parent e51a7fd commit 88c59bc

File tree

2 files changed

+6
-20
lines changed

2 files changed

+6
-20
lines changed

dpctl/tensor/libtensor/include/kernels/elementwise_functions/maximum.hpp

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
#include <cstdint>
3030
#include <type_traits>
3131

32+
#include "utils/math_utils.hpp"
3233
#include "utils/offset_utils.hpp"
3334
#include "utils/type_dispatch.hpp"
3435
#include "utils/type_utils.hpp"
@@ -65,16 +66,8 @@ template <typename argT1, typename argT2, typename resT> struct MaximumFunctor
6566
tu_ns::is_complex<argT2>::value)
6667
{
6768
static_assert(std::is_same_v<argT1, argT2>);
68-
using realT = typename argT1::value_type;
69-
realT real1 = std::real(in1);
70-
realT real2 = std::real(in2);
71-
realT imag1 = std::imag(in1);
72-
realT imag2 = std::imag(in2);
73-
74-
bool gt = (real1 == real2) ? (imag1 > imag2)
75-
: (real1 > real2 && !std::isnan(imag1) &&
76-
!std::isnan(imag2));
77-
return (std::isnan(real1) || std::isnan(imag1) || gt) ? in1 : in2;
69+
using dpctl::tensor::math_utils::max_complex;
70+
return max_complex<argT1>(in1, in2);
7871
}
7972
else if constexpr (std::is_floating_point_v<argT1> ||
8073
std::is_same_v<argT1, sycl::half>)

dpctl/tensor/libtensor/include/kernels/elementwise_functions/minimum.hpp

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
#include <cstdint>
3030
#include <type_traits>
3131

32+
#include "utils/math_utils.hpp"
3233
#include "utils/offset_utils.hpp"
3334
#include "utils/type_dispatch.hpp"
3435
#include "utils/type_utils.hpp"
@@ -65,16 +66,8 @@ template <typename argT1, typename argT2, typename resT> struct MinimumFunctor
6566
tu_ns::is_complex<argT2>::value)
6667
{
6768
static_assert(std::is_same_v<argT1, argT2>);
68-
using realT = typename argT1::value_type;
69-
realT real1 = std::real(in1);
70-
realT real2 = std::real(in2);
71-
realT imag1 = std::imag(in1);
72-
realT imag2 = std::imag(in2);
73-
74-
bool lt = (real1 == real2) ? (imag1 < imag2)
75-
: (real1 < real2 && !std::isnan(imag1) &&
76-
!std::isnan(imag2));
77-
return (std::isnan(real1) || std::isnan(imag1) || lt) ? in1 : in2;
69+
using dpctl::tensor::math_utils::min_complex;
70+
return min_complex<argT1>(in1, in2);
7871
}
7972
else if constexpr (std::is_floating_point_v<argT1> ||
8073
std::is_same_v<argT1, sycl::half>)

0 commit comments

Comments
 (0)