Skip to content

Commit 071a618

Browse files
committed
impl_real_imag_conj
1 parent ebdc96b commit 071a618

File tree

7 files changed

+809
-9
lines changed

7 files changed

+809
-9
lines changed

dpctl/tensor/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,14 +94,17 @@
9494
from ._elementwise_funcs import (
9595
abs,
9696
add,
97+
conj,
9798
cos,
9899
divide,
99100
equal,
100101
exp,
102+
imag,
101103
isfinite,
102104
isinf,
103105
isnan,
104106
multiply,
107+
real,
105108
sin,
106109
sqrt,
107110
subtract,
@@ -195,4 +198,7 @@
195198
"multiply",
196199
"subtract",
197200
"equal",
201+
"real",
202+
"imag",
203+
"conj",
198204
]

dpctl/tensor/_elementwise_funcs.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,15 @@
9090
# FIXME: implement U09
9191

9292
# U10: ==== CONJ (x)
93-
# FIXME: implement U10
93+
_conj_docstring = """
94+
conj(x, order='K')
95+
96+
Computes conjugate of each element `x_i` for input array `x`.
97+
"""
98+
99+
conj = UnaryElementwiseFunc(
100+
"conj", ti._conj_result_type, ti._conj, _conj_docstring
101+
)
94102

95103
# U11: ==== COS (x)
96104
_cos_docstring = """
@@ -174,7 +182,15 @@
174182
# FIXME: implement B12
175183

176184
# U16: ==== IMAG (x)
177-
# FIXME: implement U16
185+
_imag_docstring = """
186+
imag(x, order='K')
187+
188+
Computes imaginary part of each element `x_i` for input array `x`.
189+
"""
190+
191+
imag = UnaryElementwiseFunc(
192+
"imag", ti._imag_result_type, ti._imag, _imag_docstring
193+
)
178194

179195
# U17: ==== ISFINITE (x)
180196
_isfinite_docstring_ = """
@@ -276,7 +292,15 @@
276292
# FIXME: implement B21
277293

278294
# U27: ==== REAL (x)
279-
# FIXME: implement U27
295+
_real_docstring = """
296+
real(x, order='K')
297+
298+
Computes real part of each element `x_i` for input array `x`.
299+
"""
300+
301+
real = UnaryElementwiseFunc(
302+
"real", ti._real_result_type, ti._real, _real_docstring
303+
)
280304

281305
# B22: ==== REMAINDER (x1, x2)
282306
# FIXME: implement B22
Lines changed: 194 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,194 @@
1+
//=== conj.hpp - Unary function CONJ ------
2+
//*-C++-*--/===//
3+
//
4+
// Data Parallel Control (dpctl)
5+
//
6+
// Copyright 2020-2023 Intel Corporation
7+
//
8+
// Licensed under the Apache License, Version 2.0 (the "License");
9+
// you may not use this file except in compliance with the License.
10+
// You may obtain a copy of the License at
11+
//
12+
// http://www.apache.org/licenses/LICENSE-2.0
13+
//
14+
// Unless required by applicable law or agreed to in writing, software
15+
// distributed under the License is distributed on an "AS IS" BASIS,
16+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17+
// See the License for the specific language governing permissions and
18+
// limitations under the License.
19+
//
20+
//===---------------------------------------------------------------------===//
21+
///
22+
/// \file
23+
/// This file defines kernels for elementwise evaluation of CONJ(x) function.
24+
//===---------------------------------------------------------------------===//
25+
26+
#pragma once
27+
#include <CL/sycl.hpp>
28+
#include <cmath>
29+
#include <complex>
30+
#include <cstddef>
31+
#include <cstdint>
32+
#include <type_traits>
33+
34+
#include "kernels/elementwise_functions/common.hpp"
35+
36+
#include "utils/offset_utils.hpp"
37+
#include "utils/type_dispatch.hpp"
38+
#include "utils/type_utils.hpp"
39+
#include <pybind11/pybind11.h>
40+
41+
namespace dpctl
42+
{
43+
namespace tensor
44+
{
45+
namespace kernels
46+
{
47+
namespace conj
48+
{
49+
50+
namespace py = pybind11;
51+
namespace td_ns = dpctl::tensor::type_dispatch;
52+
53+
using dpctl::tensor::type_utils::is_complex;
54+
55+
template <typename argT, typename resT> struct ConjFunctor
56+
{
57+
58+
// is function constant for given argT
59+
using is_constant = typename std::false_type;
60+
// constant value, if constant
61+
// constexpr resT constant_value = resT{};
62+
// is function defined for sycl::vec
63+
using supports_vec = typename std::false_type;
64+
// do both argTy and resTy support sugroup store/load operation
65+
using supports_sg_loadstore = typename std::negation<
66+
std::disjunction<is_complex<resT>, is_complex<argT>>>;
67+
68+
resT operator()(const argT &in)
69+
{
70+
if constexpr (is_complex<argT>::value) {
71+
return std::conj(in);
72+
}
73+
else {
74+
if constexpr (!std::is_same_v<argT, bool>)
75+
static_assert(std::is_same_v<resT, argT>);
76+
return in;
77+
}
78+
}
79+
};
80+
81+
template <typename argTy,
82+
typename resTy = argTy,
83+
unsigned int vec_sz = 4,
84+
unsigned int n_vecs = 2>
85+
using ConjContigFunctor = elementwise_common::
86+
UnaryContigFunctor<argTy, resTy, ConjFunctor<argTy, resTy>, vec_sz, n_vecs>;
87+
88+
template <typename argTy, typename resTy, typename IndexerT>
89+
using ConjStridedFunctor = elementwise_common::
90+
UnaryStridedFunctor<argTy, resTy, IndexerT, ConjFunctor<argTy, resTy>>;
91+
92+
template <typename T> struct ConjOutputType
93+
{
94+
using value_type = typename std::disjunction< // disjunction is C++17
95+
// feature, supported by DPC++
96+
td_ns::TypeMapResultEntry<T, bool, int8_t>,
97+
td_ns::TypeMapResultEntry<T, std::uint8_t>,
98+
td_ns::TypeMapResultEntry<T, std::uint16_t>,
99+
td_ns::TypeMapResultEntry<T, std::uint32_t>,
100+
td_ns::TypeMapResultEntry<T, std::uint64_t>,
101+
td_ns::TypeMapResultEntry<T, std::int8_t>,
102+
td_ns::TypeMapResultEntry<T, std::int16_t>,
103+
td_ns::TypeMapResultEntry<T, std::int32_t>,
104+
td_ns::TypeMapResultEntry<T, std::int64_t>,
105+
td_ns::TypeMapResultEntry<T, sycl::half>,
106+
td_ns::TypeMapResultEntry<T, float>,
107+
td_ns::TypeMapResultEntry<T, double>,
108+
td_ns::TypeMapResultEntry<T, std::complex<float>>,
109+
td_ns::TypeMapResultEntry<T, std::complex<double>>,
110+
td_ns::DefaultResultEntry<void>>::result_type;
111+
};
112+
113+
template <typename T1, typename T2, unsigned int vec_sz, unsigned int n_vecs>
114+
class conj_contig_kernel;
115+
116+
template <typename argTy>
117+
sycl::event conj_contig_impl(sycl::queue exec_q,
118+
size_t nelems,
119+
const char *arg_p,
120+
char *res_p,
121+
const std::vector<sycl::event> &depends = {})
122+
{
123+
return elementwise_common::unary_contig_impl<
124+
argTy, ConjOutputType, ConjContigFunctor, conj_contig_kernel>(
125+
exec_q, nelems, arg_p, res_p, depends);
126+
}
127+
128+
template <typename fnT, typename T> struct ConjContigFactory
129+
{
130+
fnT get()
131+
{
132+
if constexpr (std::is_same_v<typename ConjOutputType<T>::value_type,
133+
void>) {
134+
fnT fn = nullptr;
135+
return fn;
136+
}
137+
else {
138+
fnT fn = conj_contig_impl<T>;
139+
return fn;
140+
}
141+
}
142+
};
143+
144+
template <typename fnT, typename T> struct ConjTypeMapFactory
145+
{
146+
/*! @brief get typeid for output type of std::conj(T x) */
147+
std::enable_if_t<std::is_same<fnT, int>::value, int> get()
148+
{
149+
using rT = typename ConjOutputType<T>::value_type;
150+
return td_ns::GetTypeid<rT>{}.get();
151+
}
152+
};
153+
154+
template <typename T1, typename T2, typename T3> class conj_strided_kernel;
155+
156+
template <typename argTy>
157+
sycl::event
158+
conj_strided_impl(sycl::queue exec_q,
159+
size_t nelems,
160+
int nd,
161+
const py::ssize_t *shape_and_strides,
162+
const char *arg_p,
163+
py::ssize_t arg_offset,
164+
char *res_p,
165+
py::ssize_t res_offset,
166+
const std::vector<sycl::event> &depends,
167+
const std::vector<sycl::event> &additional_depends)
168+
{
169+
return elementwise_common::unary_strided_impl<
170+
argTy, ConjOutputType, ConjStridedFunctor, conj_strided_kernel>(
171+
exec_q, nelems, nd, shape_and_strides, arg_p, arg_offset, res_p,
172+
res_offset, depends, additional_depends);
173+
}
174+
175+
template <typename fnT, typename T> struct ConjStridedFactory
176+
{
177+
fnT get()
178+
{
179+
if constexpr (std::is_same_v<typename ConjOutputType<T>::value_type,
180+
void>) {
181+
fnT fn = nullptr;
182+
return fn;
183+
}
184+
else {
185+
fnT fn = conj_strided_impl<T>;
186+
return fn;
187+
}
188+
}
189+
};
190+
191+
} // namespace conj
192+
} // namespace kernels
193+
} // namespace tensor
194+
} // namespace dpctl

0 commit comments

Comments
 (0)