Skip to content

Commit e51a7fd

Browse files
committed
Refactors comparison functions to use math_utils implementations
1 parent 9679a69 commit e51a7fd

File tree

4 files changed

+12
-24
lines changed

4 files changed

+12
-24
lines changed

dpctl/tensor/libtensor/include/kernels/elementwise_functions/greater.hpp

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
#include <cstdint>
3131
#include <type_traits>
3232

33+
#include "utils/math_utils.hpp"
3334
#include "utils/offset_utils.hpp"
3435
#include "utils/type_dispatch.hpp"
3536
#include "utils/type_utils.hpp"
@@ -67,12 +68,8 @@ template <typename argT1, typename argT2, typename resT> struct GreaterFunctor
6768
tu_ns::is_complex<argT2>::value)
6869
{
6970
static_assert(std::is_same_v<argT1, argT2>);
70-
using realT = typename argT1::value_type;
71-
realT real1 = std::real(in1);
72-
realT real2 = std::real(in2);
73-
74-
return (real1 == real2) ? (std::imag(in1) > std::imag(in2))
75-
: real1 > real2;
71+
using dpctl::tensor::math_utils::greater_complex;
72+
return greater_complex<argT1>(in1, in2);
7673
}
7774
else {
7875
return (in1 > in2);

dpctl/tensor/libtensor/include/kernels/elementwise_functions/greater_equal.hpp

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
#include <cstdint>
3131
#include <type_traits>
3232

33+
#include "utils/math_utils.hpp"
3334
#include "utils/offset_utils.hpp"
3435
#include "utils/type_dispatch.hpp"
3536
#include "utils/type_utils.hpp"
@@ -68,12 +69,8 @@ struct GreaterEqualFunctor
6869
tu_ns::is_complex<argT2>::value)
6970
{
7071
static_assert(std::is_same_v<argT1, argT2>);
71-
using realT = typename argT1::value_type;
72-
realT real1 = std::real(in1);
73-
realT real2 = std::real(in2);
74-
75-
return (real1 == real2) ? (std::imag(in1) >= std::imag(in2))
76-
: real1 >= real2;
72+
using dpctl::tensor::math_utils::greater_equal_complex;
73+
return greater_equal_complex<argT1>(in1, in2);
7774
}
7875
else {
7976
return (in1 >= in2);

dpctl/tensor/libtensor/include/kernels/elementwise_functions/less.hpp

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
#include <cstdint>
3030
#include <type_traits>
3131

32+
#include "utils/math_utils.hpp"
3233
#include "utils/offset_utils.hpp"
3334
#include "utils/type_dispatch.hpp"
3435
#include "utils/type_utils.hpp"
@@ -66,12 +67,8 @@ template <typename argT1, typename argT2, typename resT> struct LessFunctor
6667
tu_ns::is_complex<argT2>::value)
6768
{
6869
static_assert(std::is_same_v<argT1, argT2>);
69-
using realT = typename argT1::value_type;
70-
realT real1 = std::real(in1);
71-
realT real2 = std::real(in2);
72-
73-
return (real1 == real2) ? (std::imag(in1) < std::imag(in2))
74-
: real1 < real2;
70+
using dpctl::tensor::math_utils::less_complex;
71+
return less_complex<argT1>(in1, in2);
7572
}
7673
else {
7774
return (in1 < in2);

dpctl/tensor/libtensor/include/kernels/elementwise_functions/less_equal.hpp

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
#include <cstdint>
3131
#include <type_traits>
3232

33+
#include "utils/math_utils.hpp"
3334
#include "utils/offset_utils.hpp"
3435
#include "utils/type_dispatch.hpp"
3536
#include "utils/type_utils.hpp"
@@ -67,12 +68,8 @@ template <typename argT1, typename argT2, typename resT> struct LessEqualFunctor
6768
tu_ns::is_complex<argT2>::value)
6869
{
6970
static_assert(std::is_same_v<argT1, argT2>);
70-
using realT = typename argT1::value_type;
71-
realT real1 = std::real(in1);
72-
realT real2 = std::real(in2);
73-
74-
return (real1 == real2) ? (std::imag(in1) <= std::imag(in2))
75-
: real1 <= real2;
71+
using dpctl::tensor::math_utils::less_equal_complex;
72+
return less_equal_complex<argT1>(in1, in2);
7673
}
7774
else {
7875
return (in1 <= in2);

0 commit comments

Comments
 (0)