Skip to content

Commit 2beca15

Browse files
committed
Implements elementwise reciprocal
1 parent 5ec9fd5 commit 2beca15

File tree

7 files changed

+404
-1
lines changed

7 files changed

+404
-1
lines changed

dpctl/tensor/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ set(_elementwise_sources
8787
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/pow.cpp
8888
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/proj.cpp
8989
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/real.cpp
90+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/reciprocal.cpp
9091
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/remainder.cpp
9192
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/round.cpp
9293
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/rsqrt.cpp

dpctl/tensor/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,7 @@
153153
pow,
154154
proj,
155155
real,
156+
reciprocal,
156157
remainder,
157158
round,
158159
rsqrt,
@@ -342,4 +343,5 @@
342343
"var",
343344
"__array_api_version__",
344345
"__array_namespace_info__",
346+
"reciprocal",
345347
]

dpctl/tensor/_elementwise_funcs.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1880,10 +1880,40 @@
18801880
Returns:
18811881
usm_ndarray:
18821882
An array containing the element-wise reciprocal square-root.
1883-
The data type of the returned array is determined by
1883+
The data type of the returned array is determined by
18841884
the Type Promotion Rules.
18851885
"""
18861886

18871887
rsqrt = UnaryElementwiseFunc(
18881888
"rsqrt", ti._rsqrt_result_type, ti._rsqrt, _rsqrt_docstring_
18891889
)
1890+
1891+
1892+
# U42: ==== RECIPROCAL (x)
1893+
_reciprocal_docstring = """
1894+
reciprocal(x, out=None, order='K')
1895+
1896+
Computes the reciprocal of each element `x_i` for input array `x`.
1897+
1898+
Args:
1899+
x (usm_ndarray):
1900+
Input array, expected to have a real- or complex-valued floating-point data type.
1901+
out ({None, usm_ndarray}, optional):
1902+
Output array to populate.
1903+
Array must have the correct shape and the expected data type.
1904+
order ("C","F","A","K", optional):
1905+
Memory layout of the newly allocated output array, if parameter `out` is `None`.
1906+
Default: "K".
1907+
Returns:
1908+
usm_ndarray:
1909+
An array containing the element-wise reciprocals.
1910+
The returned array has a floating-point data type determined
1911+
by the Type Promotion Rules.
1912+
"""
1913+
1914+
reciprocal = UnaryElementwiseFunc(
1915+
"reciprocal",
1916+
ti._reciprocal_result_type,
1917+
ti._reciprocal,
1918+
_reciprocal_docstring,
1919+
)
Lines changed: 203 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,203 @@
1+
//=== reciprocal.hpp - Unary function RECIPROCAL ------
2+
//*-C++-*--/===//
3+
//
4+
// Data Parallel Control (dpctl)
5+
//
6+
// Copyright 2020-2023 Intel Corporation
7+
//
8+
// Licensed under the Apache License, Version 2.0 (the "License");
9+
// you may not use this file except in compliance with the License.
10+
// You may obtain a copy of the License at
11+
//
12+
// http://www.apache.org/licenses/LICENSE-2.0
13+
//
14+
// Unless required by applicable law or agreed to in writing, software
15+
// distributed under the License is distributed on an "AS IS" BASIS,
16+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17+
// See the License for the specific language governing permissions and
18+
// limitations under the License.
19+
//
20+
//===---------------------------------------------------------------------===//
21+
///
22+
/// \file
23+
/// This file defines kernels for elementwise evaluation of RECIPROCAL(x)
24+
/// function.
25+
//===---------------------------------------------------------------------===//
26+
27+
#pragma once
28+
#include <cmath>
29+
#include <complex>
30+
#include <cstddef>
31+
#include <cstdint>
32+
#include <sycl/sycl.hpp>
33+
#include <type_traits>
34+
35+
#include "sycl_complex.hpp"
36+
#include "utils/offset_utils.hpp"
37+
#include "utils/type_dispatch.hpp"
38+
#include "utils/type_utils.hpp"
39+
40+
#include "kernels/elementwise_functions/common.hpp"
41+
#include <pybind11/pybind11.h>
42+
43+
namespace dpctl
44+
{
45+
namespace tensor
46+
{
47+
namespace kernels
48+
{
49+
namespace reciprocal
50+
{
51+
52+
namespace py = pybind11;
53+
namespace td_ns = dpctl::tensor::type_dispatch;
54+
55+
using dpctl::tensor::type_utils::is_complex;
56+
57+
template <typename argT, typename resT> struct ReciprocalFunctor
58+
{
59+
// is function constant for given argT
60+
using is_constant = typename std::false_type;
61+
// constant value, if constant
62+
// constexpr resT constant_value = resT{};
63+
// is function defined for sycl::vec
64+
using supports_vec = typename std::false_type;
65+
// do both argTy and resTy support subgroup store/load operations
66+
using supports_sg_loadstore = typename std::negation<
67+
std::disjunction<is_complex<resT>, is_complex<argT>>>;
68+
69+
resT operator()(const argT &in) const
70+
{
71+
if constexpr (is_complex<argT>::value) {
72+
73+
using realT = typename argT::value_type;
74+
#ifdef USE_SYCL_FOR_COMPLEX_TYPES
75+
76+
return realT(1) / exprm_ns::complex<realT>(in);
77+
#else
78+
return realT(1) / in;
79+
#endif
80+
}
81+
else {
82+
return argT(1) / in;
83+
}
84+
}
85+
};
86+
87+
template <typename argTy,
88+
typename resTy = argTy,
89+
unsigned int vec_sz = 4,
90+
unsigned int n_vecs = 2>
91+
using ReciprocalContigFunctor =
92+
elementwise_common::UnaryContigFunctor<argTy,
93+
resTy,
94+
ReciprocalFunctor<argTy, resTy>,
95+
vec_sz,
96+
n_vecs>;
97+
98+
template <typename argTy, typename resTy, typename IndexerT>
99+
using ReciprocalStridedFunctor =
100+
elementwise_common::UnaryStridedFunctor<argTy,
101+
resTy,
102+
IndexerT,
103+
ReciprocalFunctor<argTy, resTy>>;
104+
105+
template <typename T> struct ReciprocalOutputType
106+
{
107+
using value_type = typename std::disjunction< // disjunction is C++17
108+
// feature, supported by DPC++
109+
td_ns::TypeMapResultEntry<T, sycl::half>,
110+
td_ns::TypeMapResultEntry<T, float>,
111+
td_ns::TypeMapResultEntry<T, double>,
112+
td_ns::TypeMapResultEntry<T, std::complex<float>>,
113+
td_ns::TypeMapResultEntry<T, std::complex<double>>,
114+
td_ns::DefaultResultEntry<void>>::result_type;
115+
};
116+
117+
template <typename T1, typename T2, unsigned int vec_sz, unsigned int n_vecs>
118+
class reciprocal_contig_kernel;
119+
120+
template <typename argTy>
121+
sycl::event reciprocal_contig_impl(sycl::queue &exec_q,
122+
size_t nelems,
123+
const char *arg_p,
124+
char *res_p,
125+
const std::vector<sycl::event> &depends = {})
126+
{
127+
return elementwise_common::unary_contig_impl<argTy, ReciprocalOutputType,
128+
ReciprocalContigFunctor,
129+
reciprocal_contig_kernel>(
130+
exec_q, nelems, arg_p, res_p, depends);
131+
}
132+
133+
template <typename fnT, typename T> struct ReciprocalContigFactory
134+
{
135+
fnT get()
136+
{
137+
if constexpr (std::is_same_v<
138+
typename ReciprocalOutputType<T>::value_type, void>)
139+
{
140+
fnT fn = nullptr;
141+
return fn;
142+
}
143+
else {
144+
fnT fn = reciprocal_contig_impl<T>;
145+
return fn;
146+
}
147+
}
148+
};
149+
150+
template <typename fnT, typename T> struct ReciprocalTypeMapFactory
151+
{
152+
/*! @brief get typeid for output type of 1 / x */
153+
std::enable_if_t<std::is_same<fnT, int>::value, int> get()
154+
{
155+
using rT = typename ReciprocalOutputType<T>::value_type;
156+
return td_ns::GetTypeid<rT>{}.get();
157+
}
158+
};
159+
160+
template <typename T1, typename T2, typename T3>
161+
class reciprocal_strided_kernel;
162+
163+
template <typename argTy>
164+
sycl::event
165+
reciprocal_strided_impl(sycl::queue &exec_q,
166+
size_t nelems,
167+
int nd,
168+
const py::ssize_t *shape_and_strides,
169+
const char *arg_p,
170+
py::ssize_t arg_offset,
171+
char *res_p,
172+
py::ssize_t res_offset,
173+
const std::vector<sycl::event> &depends,
174+
const std::vector<sycl::event> &additional_depends)
175+
{
176+
return elementwise_common::unary_strided_impl<argTy, ReciprocalOutputType,
177+
ReciprocalStridedFunctor,
178+
reciprocal_strided_kernel>(
179+
exec_q, nelems, nd, shape_and_strides, arg_p, arg_offset, res_p,
180+
res_offset, depends, additional_depends);
181+
}
182+
183+
template <typename fnT, typename T> struct ReciprocalStridedFactory
184+
{
185+
fnT get()
186+
{
187+
if constexpr (std::is_same_v<
188+
typename ReciprocalOutputType<T>::value_type, void>)
189+
{
190+
fnT fn = nullptr;
191+
return fn;
192+
}
193+
else {
194+
fnT fn = reciprocal_strided_impl<T>;
195+
return fn;
196+
}
197+
}
198+
};
199+
200+
} // namespace reciprocal
201+
} // namespace kernels
202+
} // namespace tensor
203+
} // namespace dpctl

dpctl/tensor/libtensor/source/elementwise_functions/elementwise_common.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@
7979
#include "pow.hpp"
8080
#include "proj.hpp"
8181
#include "real.hpp"
82+
#include "reciprocal.hpp"
8283
#include "remainder.hpp"
8384
#include "round.hpp"
8485
#include "rsqrt.hpp"
@@ -161,6 +162,7 @@ void init_elementwise_functions(py::module_ m)
161162
init_pow(m);
162163
init_proj(m);
163164
init_real(m);
165+
init_reciprocal(m);
164166
init_remainder(m);
165167
init_round(m);
166168
init_rsqrt(m);

0 commit comments

Comments
 (0)