Skip to content

Commit 597cd4b

Browse files
authored
Refactor complex comparison functions (#1320)
* Adds `math_utils.hpp` with complex comparison functions * Refactors comparison functions to use math_utils implementations * Refactors maximum and minimum to use math_utils implementations
1 parent b055ff9 commit 597cd4b

File tree

7 files changed

+138
-44
lines changed

7 files changed

+138
-44
lines changed

dpctl/tensor/libtensor/include/kernels/elementwise_functions/greater.hpp

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
#include <cstdint>
3131
#include <type_traits>
3232

33+
#include "utils/math_utils.hpp"
3334
#include "utils/offset_utils.hpp"
3435
#include "utils/type_dispatch.hpp"
3536
#include "utils/type_utils.hpp"
@@ -67,12 +68,8 @@ template <typename argT1, typename argT2, typename resT> struct GreaterFunctor
6768
tu_ns::is_complex<argT2>::value)
6869
{
6970
static_assert(std::is_same_v<argT1, argT2>);
70-
using realT = typename argT1::value_type;
71-
realT real1 = std::real(in1);
72-
realT real2 = std::real(in2);
73-
74-
return (real1 == real2) ? (std::imag(in1) > std::imag(in2))
75-
: real1 > real2;
71+
using dpctl::tensor::math_utils::greater_complex;
72+
return greater_complex<argT1>(in1, in2);
7673
}
7774
else {
7875
return (in1 > in2);

dpctl/tensor/libtensor/include/kernels/elementwise_functions/greater_equal.hpp

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
#include <cstdint>
3131
#include <type_traits>
3232

33+
#include "utils/math_utils.hpp"
3334
#include "utils/offset_utils.hpp"
3435
#include "utils/type_dispatch.hpp"
3536
#include "utils/type_utils.hpp"
@@ -68,12 +69,8 @@ struct GreaterEqualFunctor
6869
tu_ns::is_complex<argT2>::value)
6970
{
7071
static_assert(std::is_same_v<argT1, argT2>);
71-
using realT = typename argT1::value_type;
72-
realT real1 = std::real(in1);
73-
realT real2 = std::real(in2);
74-
75-
return (real1 == real2) ? (std::imag(in1) >= std::imag(in2))
76-
: real1 >= real2;
72+
using dpctl::tensor::math_utils::greater_equal_complex;
73+
return greater_equal_complex<argT1>(in1, in2);
7774
}
7875
else {
7976
return (in1 >= in2);

dpctl/tensor/libtensor/include/kernels/elementwise_functions/less.hpp

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
#include <cstdint>
3030
#include <type_traits>
3131

32+
#include "utils/math_utils.hpp"
3233
#include "utils/offset_utils.hpp"
3334
#include "utils/type_dispatch.hpp"
3435
#include "utils/type_utils.hpp"
@@ -66,12 +67,8 @@ template <typename argT1, typename argT2, typename resT> struct LessFunctor
6667
tu_ns::is_complex<argT2>::value)
6768
{
6869
static_assert(std::is_same_v<argT1, argT2>);
69-
using realT = typename argT1::value_type;
70-
realT real1 = std::real(in1);
71-
realT real2 = std::real(in2);
72-
73-
return (real1 == real2) ? (std::imag(in1) < std::imag(in2))
74-
: real1 < real2;
70+
using dpctl::tensor::math_utils::less_complex;
71+
return less_complex<argT1>(in1, in2);
7572
}
7673
else {
7774
return (in1 < in2);

dpctl/tensor/libtensor/include/kernels/elementwise_functions/less_equal.hpp

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
#include <cstdint>
3131
#include <type_traits>
3232

33+
#include "utils/math_utils.hpp"
3334
#include "utils/offset_utils.hpp"
3435
#include "utils/type_dispatch.hpp"
3536
#include "utils/type_utils.hpp"
@@ -67,12 +68,8 @@ template <typename argT1, typename argT2, typename resT> struct LessEqualFunctor
6768
tu_ns::is_complex<argT2>::value)
6869
{
6970
static_assert(std::is_same_v<argT1, argT2>);
70-
using realT = typename argT1::value_type;
71-
realT real1 = std::real(in1);
72-
realT real2 = std::real(in2);
73-
74-
return (real1 == real2) ? (std::imag(in1) <= std::imag(in2))
75-
: real1 <= real2;
71+
using dpctl::tensor::math_utils::less_equal_complex;
72+
return less_equal_complex<argT1>(in1, in2);
7673
}
7774
else {
7875
return (in1 <= in2);

dpctl/tensor/libtensor/include/kernels/elementwise_functions/maximum.hpp

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
#include <cstdint>
3030
#include <type_traits>
3131

32+
#include "utils/math_utils.hpp"
3233
#include "utils/offset_utils.hpp"
3334
#include "utils/type_dispatch.hpp"
3435
#include "utils/type_utils.hpp"
@@ -65,16 +66,8 @@ template <typename argT1, typename argT2, typename resT> struct MaximumFunctor
6566
tu_ns::is_complex<argT2>::value)
6667
{
6768
static_assert(std::is_same_v<argT1, argT2>);
68-
using realT = typename argT1::value_type;
69-
realT real1 = std::real(in1);
70-
realT real2 = std::real(in2);
71-
realT imag1 = std::imag(in1);
72-
realT imag2 = std::imag(in2);
73-
74-
bool gt = (real1 == real2) ? (imag1 > imag2)
75-
: (real1 > real2 && !std::isnan(imag1) &&
76-
!std::isnan(imag2));
77-
return (std::isnan(real1) || std::isnan(imag1) || gt) ? in1 : in2;
69+
using dpctl::tensor::math_utils::max_complex;
70+
return max_complex<argT1>(in1, in2);
7871
}
7972
else if constexpr (std::is_floating_point_v<argT1> ||
8073
std::is_same_v<argT1, sycl::half>)

dpctl/tensor/libtensor/include/kernels/elementwise_functions/minimum.hpp

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
#include <cstdint>
3030
#include <type_traits>
3131

32+
#include "utils/math_utils.hpp"
3233
#include "utils/offset_utils.hpp"
3334
#include "utils/type_dispatch.hpp"
3435
#include "utils/type_utils.hpp"
@@ -65,16 +66,8 @@ template <typename argT1, typename argT2, typename resT> struct MinimumFunctor
6566
tu_ns::is_complex<argT2>::value)
6667
{
6768
static_assert(std::is_same_v<argT1, argT2>);
68-
using realT = typename argT1::value_type;
69-
realT real1 = std::real(in1);
70-
realT real2 = std::real(in2);
71-
realT imag1 = std::imag(in1);
72-
realT imag2 = std::imag(in2);
73-
74-
bool lt = (real1 == real2) ? (imag1 < imag2)
75-
: (real1 < real2 && !std::isnan(imag1) &&
76-
!std::isnan(imag2));
77-
return (std::isnan(real1) || std::isnan(imag1) || lt) ? in1 : in2;
69+
using dpctl::tensor::math_utils::min_complex;
70+
return min_complex<argT1>(in1, in2);
7871
}
7972
else if constexpr (std::is_floating_point_v<argT1> ||
8073
std::is_same_v<argT1, sycl::half>)
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
//===------- math_utils.hpp - Implementation of math utils -------*-C++-*/===//
2+
//
3+
// Data Parallel Control (dpctl)
4+
//
5+
// Copyright 2020-2023 Intel Corporation
6+
//
7+
// Licensed under the Apache License, Version 2.0 (the "License");
8+
// you may not use this file except in compliance with the License.
9+
// You may obtain a copy of the License at
10+
//
11+
// http://www.apache.org/licenses/LICENSE-2.0
12+
//
13+
// Unless required by applicable law or agreed to in writing, software
14+
// distributed under the License is distributed on an "AS IS" BASIS,
15+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16+
// See the License for the specific language governing permissions and
17+
// limitations under the License.
18+
//
19+
//===----------------------------------------------------------------------===//
20+
///
21+
/// \file
22+
/// This file defines math utility functions.
23+
//===----------------------------------------------------------------------===//
24+
25+
#pragma once
26+
#include <cmath>
27+
#include <complex>
28+
29+
namespace dpctl
30+
{
31+
namespace tensor
32+
{
33+
namespace math_utils
34+
{
35+
36+
template <typename T> bool less_complex(const T &x1, const T &x2)
37+
{
38+
using realT = typename T::value_type;
39+
realT real1 = std::real(x1);
40+
realT real2 = std::real(x2);
41+
realT imag1 = std::imag(x1);
42+
realT imag2 = std::imag(x2);
43+
44+
return (real1 == real2)
45+
? (imag1 < imag2)
46+
: (real1 < real2 && !std::isnan(imag1) && !std::isnan(imag2));
47+
}
48+
49+
template <typename T> bool greater_complex(const T &x1, const T &x2)
50+
{
51+
using realT = typename T::value_type;
52+
realT real1 = std::real(x1);
53+
realT real2 = std::real(x2);
54+
realT imag1 = std::imag(x1);
55+
realT imag2 = std::imag(x2);
56+
57+
return (real1 == real2)
58+
? (imag1 > imag2)
59+
: (real1 > real2 && !std::isnan(imag1) && !std::isnan(imag2));
60+
}
61+
62+
template <typename T> bool less_equal_complex(const T &x1, const T &x2)
63+
{
64+
using realT = typename T::value_type;
65+
realT real1 = std::real(x1);
66+
realT real2 = std::real(x2);
67+
realT imag1 = std::imag(x1);
68+
realT imag2 = std::imag(x2);
69+
70+
return (real1 == real2)
71+
? (imag1 <= imag2)
72+
: (real1 < real2 && !std::isnan(imag1) && !std::isnan(imag2));
73+
}
74+
75+
template <typename T> bool greater_equal_complex(const T &x1, const T &x2)
76+
{
77+
using realT = typename T::value_type;
78+
realT real1 = std::real(x1);
79+
realT real2 = std::real(x2);
80+
realT imag1 = std::imag(x1);
81+
realT imag2 = std::imag(x2);
82+
83+
return (real1 == real2)
84+
? (imag1 >= imag2)
85+
: (real1 > real2 && !std::isnan(imag1) && !std::isnan(imag2));
86+
}
87+
88+
template <typename T> T max_complex(const T &x1, const T &x2)
89+
{
90+
using realT = typename T::value_type;
91+
realT real1 = std::real(x1);
92+
realT real2 = std::real(x2);
93+
realT imag1 = std::imag(x1);
94+
realT imag2 = std::imag(x2);
95+
96+
bool isnan_imag1 = std::isnan(imag1);
97+
bool gt = (real1 == real2)
98+
? (imag1 > imag2)
99+
: (real1 > real2 && !isnan_imag1 && !std::isnan(imag2));
100+
return (std::isnan(real1) || isnan_imag1 || gt) ? x1 : x2;
101+
}
102+
103+
template <typename T> T min_complex(const T &x1, const T &x2)
104+
{
105+
using realT = typename T::value_type;
106+
realT real1 = std::real(x1);
107+
realT real2 = std::real(x2);
108+
realT imag1 = std::imag(x1);
109+
realT imag2 = std::imag(x2);
110+
111+
bool isnan_imag1 = std::isnan(imag1);
112+
bool lt = (real1 == real2)
113+
? (imag1 < imag2)
114+
: (real1 < real2 && !isnan_imag1 && !std::isnan(imag2));
115+
return (std::isnan(real1) || isnan_imag1 || lt) ? x1 : x2;
116+
}
117+
118+
} // namespace math_utils
119+
} // namespace tensor
120+
} // namespace dpctl

0 commit comments

Comments
 (0)