Skip to content

Commit 0695fde

Browse files
committed
Implements rsqrt and tests for rsqrt
1 parent 64eb98c commit 0695fde

File tree

5 files changed

+346
-3
lines changed

5 files changed

+346
-3
lines changed

dpctl/tensor/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,7 @@
152152
real,
153153
remainder,
154154
round,
155+
rsqrt,
155156
sign,
156157
signbit,
157158
sin,
@@ -320,4 +321,5 @@
320321
"cbrt",
321322
"exp2",
322323
"copysign",
324+
"rsqrt",
323325
]

dpctl/tensor/_elementwise_funcs.py

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1763,7 +1763,7 @@
17631763
)
17641764

17651765

1766-
# U33: ==== CBRT (x)
1766+
# U37: ==== CBRT (x)
17671767
_cbrt_docstring_ = """
17681768
cbrt(x, out=None, order='K')
17691769
@@ -1790,7 +1790,7 @@
17901790
)
17911791

17921792

1793-
# U34: ==== EXP2 (x)
1793+
# U38: ==== EXP2 (x)
17941794
_exp2_docstring_ = """
17951795
exp2(x, out=None, order='K')
17961796
@@ -1817,7 +1817,7 @@
18171817
)
18181818

18191819

1820-
# B23: ==== COPYSIGN (x1, x2)
1820+
# B25: ==== COPYSIGN (x1, x2)
18211821
_copysign_docstring_ = """
18221822
copysign(x1, x2, out=None, order='K')
18231823
@@ -1847,3 +1847,30 @@
18471847
ti._copysign,
18481848
_copysign_docstring_,
18491849
)
1850+
1851+
1852+
# U39: ==== RSQRT (x)
1853+
_rsqrt_docstring_ = """
1854+
rsqrt(x, out=None, order='K')
1855+
1856+
Computes the reciprocal square-root for each element `x_i` of input array `x`.
1857+
1858+
Args:
1859+
x (usm_ndarray):
1860+
Input array, expected to have a real floating-point data type.
1861+
out ({None, usm_ndarray}, optional):
1862+
Output array to populate.
1863+
Array must have the correct shape and the expected data type.
1864+
order ("C","F","A","K", optional):
1865+
Memory layout of the newly allocated output array, if parameter `out` is `None`.
1866+
Default: "K".
1867+
Returns:
1868+
usm_ndarray:
1869+
An array containing the element-wise reciprocal square-root.
1870+
The data type of the returned array is determined by
1871+
the Type Promotion Rules.
1872+
"""
1873+
1874+
rsqrt = UnaryElementwiseFunc(
1875+
"rsqrt", ti._rsqrt_result_type, ti._rsqrt, _rsqrt_docstring_
1876+
)
Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
//=== rsqrt.hpp - Unary function RSQRT ------
2+
//*-C++-*--/===//
3+
//
4+
// Data Parallel Control (dpctl)
5+
//
6+
// Copyright 2020-2023 Intel Corporation
7+
//
8+
// Licensed under the Apache License, Version 2.0 (the "License");
9+
// you may not use this file except in compliance with the License.
10+
// You may obtain a copy of the License at
11+
//
12+
// http://www.apache.org/licenses/LICENSE-2.0
13+
//
14+
// Unless required by applicable law or agreed to in writing, software
15+
// distributed under the License is distributed on an "AS IS" BASIS,
16+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17+
// See the License for the specific language governing permissions and
18+
// limitations under the License.
19+
//
20+
//===---------------------------------------------------------------------===//
21+
///
22+
/// \file
23+
/// This file defines kernels for elementwise evaluation of RSQRT(x)
24+
/// function that computes the reciprocal square root.
25+
//===---------------------------------------------------------------------===//
26+
27+
#pragma once
28+
#include <CL/sycl.hpp>
29+
#include <cmath>
30+
#include <complex>
31+
#include <cstddef>
32+
#include <cstdint>
33+
#include <limits>
34+
#include <type_traits>
35+
36+
#include "kernels/elementwise_functions/common.hpp"
37+
38+
#include "utils/offset_utils.hpp"
39+
#include "utils/type_dispatch.hpp"
40+
#include "utils/type_utils.hpp"
41+
#include <pybind11/pybind11.h>
42+
43+
namespace dpctl
44+
{
45+
namespace tensor
46+
{
47+
namespace kernels
48+
{
49+
namespace rsqrt
50+
{
51+
52+
namespace py = pybind11;
53+
namespace td_ns = dpctl::tensor::type_dispatch;
54+
55+
template <typename argT, typename resT> struct RsqrtFunctor
56+
{
57+
58+
// is function constant for given argT
59+
using is_constant = typename std::false_type;
60+
// constant value, if constant
61+
// constexpr resT constant_value = resT{};
62+
// is function defined for sycl::vec
63+
using supports_vec = typename std::false_type;
64+
// do both argT and resT support subgroup store/load operation
65+
using supports_sg_loadstore = typename std::true_type;
66+
67+
resT operator()(const argT &in) const
68+
{
69+
return sycl::rsqrt(in);
70+
}
71+
};
72+
73+
template <typename argTy,
74+
typename resTy = argTy,
75+
unsigned int vec_sz = 4,
76+
unsigned int n_vecs = 2>
77+
using RsqrtContigFunctor =
78+
elementwise_common::UnaryContigFunctor<argTy,
79+
resTy,
80+
RsqrtFunctor<argTy, resTy>,
81+
vec_sz,
82+
n_vecs>;
83+
84+
template <typename argTy, typename resTy, typename IndexerT>
85+
using RsqrtStridedFunctor = elementwise_common::
86+
UnaryStridedFunctor<argTy, resTy, IndexerT, RsqrtFunctor<argTy, resTy>>;
87+
88+
template <typename T> struct RsqrtOutputType
89+
{
90+
using value_type = typename std::disjunction< // disjunction is C++17
91+
// feature, supported by DPC++
92+
td_ns::TypeMapResultEntry<T, sycl::half, sycl::half>,
93+
td_ns::TypeMapResultEntry<T, float, float>,
94+
td_ns::TypeMapResultEntry<T, double, double>,
95+
td_ns::DefaultResultEntry<void>>::result_type;
96+
};
97+
98+
template <typename T1, typename T2, unsigned int vec_sz, unsigned int n_vecs>
99+
class rsqrt_contig_kernel;
100+
101+
template <typename argTy>
102+
sycl::event rsqrt_contig_impl(sycl::queue &exec_q,
103+
size_t nelems,
104+
const char *arg_p,
105+
char *res_p,
106+
const std::vector<sycl::event> &depends = {})
107+
{
108+
return elementwise_common::unary_contig_impl<
109+
argTy, RsqrtOutputType, RsqrtContigFunctor, rsqrt_contig_kernel>(
110+
exec_q, nelems, arg_p, res_p, depends);
111+
}
112+
113+
template <typename fnT, typename T> struct RsqrtContigFactory
114+
{
115+
fnT get()
116+
{
117+
if constexpr (std::is_same_v<typename RsqrtOutputType<T>::value_type,
118+
void>) {
119+
fnT fn = nullptr;
120+
return fn;
121+
}
122+
else {
123+
fnT fn = rsqrt_contig_impl<T>;
124+
return fn;
125+
}
126+
}
127+
};
128+
129+
template <typename fnT, typename T> struct RsqrtTypeMapFactory
130+
{
131+
/*! @brief get typeid for output type of sycl::rsqrt(T x) */
132+
std::enable_if_t<std::is_same<fnT, int>::value, int> get()
133+
{
134+
using rT = typename RsqrtOutputType<T>::value_type;
135+
return td_ns::GetTypeid<rT>{}.get();
136+
}
137+
};
138+
139+
template <typename T1, typename T2, typename T3> class rsqrt_strided_kernel;
140+
141+
template <typename argTy>
142+
sycl::event
143+
rsqrt_strided_impl(sycl::queue &exec_q,
144+
size_t nelems,
145+
int nd,
146+
const py::ssize_t *shape_and_strides,
147+
const char *arg_p,
148+
py::ssize_t arg_offset,
149+
char *res_p,
150+
py::ssize_t res_offset,
151+
const std::vector<sycl::event> &depends,
152+
const std::vector<sycl::event> &additional_depends)
153+
{
154+
return elementwise_common::unary_strided_impl<
155+
argTy, RsqrtOutputType, RsqrtStridedFunctor, rsqrt_strided_kernel>(
156+
exec_q, nelems, nd, shape_and_strides, arg_p, arg_offset, res_p,
157+
res_offset, depends, additional_depends);
158+
}
159+
160+
template <typename fnT, typename T> struct RsqrtStridedFactory
161+
{
162+
fnT get()
163+
{
164+
if constexpr (std::is_same_v<typename RsqrtOutputType<T>::value_type,
165+
void>) {
166+
fnT fn = nullptr;
167+
return fn;
168+
}
169+
else {
170+
fnT fn = rsqrt_strided_impl<T>;
171+
return fn;
172+
}
173+
}
174+
};
175+
176+
} // namespace rsqrt
177+
} // namespace kernels
178+
} // namespace tensor
179+
} // namespace dpctl

dpctl/tensor/libtensor/source/elementwise_functions.cpp

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@
8989
#include "kernels/elementwise_functions/real.hpp"
9090
#include "kernels/elementwise_functions/remainder.hpp"
9191
#include "kernels/elementwise_functions/round.hpp"
92+
#include "kernels/elementwise_functions/rsqrt.hpp"
9293
#include "kernels/elementwise_functions/sign.hpp"
9394
#include "kernels/elementwise_functions/signbit.hpp"
9495
#include "kernels/elementwise_functions/sin.hpp"
@@ -2899,6 +2900,42 @@ void populate_exp2_dispatch_vectors(void)
28992900

29002901
} // namespace impl
29012902

2903+
// U39: ==== RSQRT (x)
2904+
namespace impl
2905+
{
2906+
2907+
namespace rsqrt_fn_ns = dpctl::tensor::kernels::rsqrt;
2908+
2909+
static unary_contig_impl_fn_ptr_t
2910+
rsqrt_contig_dispatch_vector[td_ns::num_types];
2911+
static int rsqrt_output_typeid_vector[td_ns::num_types];
2912+
static unary_strided_impl_fn_ptr_t
2913+
rsqrt_strided_dispatch_vector[td_ns::num_types];
2914+
2915+
void populate_rsqrt_dispatch_vectors(void)
2916+
{
2917+
using namespace td_ns;
2918+
namespace fn_ns = rsqrt_fn_ns;
2919+
2920+
using fn_ns::RsqrtContigFactory;
2921+
DispatchVectorBuilder<unary_contig_impl_fn_ptr_t, RsqrtContigFactory,
2922+
num_types>
2923+
dvb1;
2924+
dvb1.populate_dispatch_vector(rsqrt_contig_dispatch_vector);
2925+
2926+
using fn_ns::RsqrtStridedFactory;
2927+
DispatchVectorBuilder<unary_strided_impl_fn_ptr_t, RsqrtStridedFactory,
2928+
num_types>
2929+
dvb2;
2930+
dvb2.populate_dispatch_vector(rsqrt_strided_dispatch_vector);
2931+
2932+
using fn_ns::RsqrtTypeMapFactory;
2933+
DispatchVectorBuilder<int, RsqrtTypeMapFactory, num_types> dvb3;
2934+
dvb3.populate_dispatch_vector(rsqrt_output_typeid_vector);
2935+
}
2936+
2937+
} // namespace impl
2938+
29022939
// ==========================================================================================
29032940
// //
29042941

@@ -5087,6 +5124,30 @@ void init_elementwise_functions(py::module_ m)
50875124
};
50885125
m.def("_exp2_result_type", exp2_result_type_pyapi);
50895126
}
5127+
5128+
// U39: ==== RSQRT (x)
5129+
{
5130+
impl::populate_rsqrt_dispatch_vectors();
5131+
using impl::rsqrt_contig_dispatch_vector;
5132+
using impl::rsqrt_output_typeid_vector;
5133+
using impl::rsqrt_strided_dispatch_vector;
5134+
5135+
auto rsqrt_pyapi = [&](const arrayT &src, const arrayT &dst,
5136+
sycl::queue &exec_q,
5137+
const event_vecT &depends = {}) {
5138+
return py_unary_ufunc(
5139+
src, dst, exec_q, depends, rsqrt_output_typeid_vector,
5140+
rsqrt_contig_dispatch_vector, rsqrt_strided_dispatch_vector);
5141+
};
5142+
m.def("_rsqrt", rsqrt_pyapi, "", py::arg("src"), py::arg("dst"),
5143+
py::arg("sycl_queue"), py::arg("depends") = py::list());
5144+
5145+
auto rsqrt_result_type_pyapi = [&](const py::dtype &dtype) {
5146+
return py_unary_ufunc_result_type(dtype,
5147+
rsqrt_output_typeid_vector);
5148+
};
5149+
m.def("_rsqrt_result_type", rsqrt_result_type_pyapi);
5150+
}
50905151
}
50915152

50925153
} // namespace py_internal

0 commit comments

Comments
 (0)