Skip to content

Commit 0b5f940

Browse files
committed
Implements dpctl.tensor.cbrt
1 parent af04d34 commit 0b5f940

File tree

4 files changed

+260
-0
lines changed

4 files changed

+260
-0
lines changed

dpctl/tensor/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@
110110
bitwise_or,
111111
bitwise_right_shift,
112112
bitwise_xor,
113+
cbrt,
113114
ceil,
114115
conj,
115116
cos,
@@ -314,4 +315,5 @@
314315
"argmax",
315316
"argmin",
316317
"prod",
318+
"cbrt",
317319
]

dpctl/tensor/_elementwise_funcs.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1761,3 +1761,30 @@
17611761
hypot = BinaryElementwiseFunc(
17621762
"hypot", ti._hypot_result_type, ti._hypot, _hypot_docstring_
17631763
)
1764+
1765+
1766+
# U37: ==== CBRT (x)
1767+
_cbrt_docstring_ = """
cbrt(x, out=None, order='K')

Computes the cube-root for each element `x_i` of input array `x`.

Args:
    x (usm_ndarray):
        Input array, expected to have a real floating-point data type.
    out ({None, usm_ndarray}, optional):
        Output array to populate.
        Array must have the correct shape and the expected data type.
    order ("C","F","A","K", optional):
        Memory layout of the newly created output array, if parameter
        `out` is `None`.
        Default: "K".
Returns:
    usm_ndarray:
        An array containing the element-wise cube-root.
        The data type of the returned array is determined by
        the Type Promotion Rules.
"""

# Element-wise cube root: the result-type query and the computation are
# delegated to `ti._cbrt_result_type` and `ti._cbrt` respectively.
cbrt = UnaryElementwiseFunc(
    "cbrt", ti._cbrt_result_type, ti._cbrt, _cbrt_docstring_
)
Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
//=== cbrt.hpp - Unary function CBRT ------ *-C++-*--/===//
2+
//
3+
// Data Parallel Control (dpctl)
4+
//
5+
// Copyright 2020-2023 Intel Corporation
6+
//
7+
// Licensed under the Apache License, Version 2.0 (the "License");
8+
// you may not use this file except in compliance with the License.
9+
// You may obtain a copy of the License at
10+
//
11+
// http://www.apache.org/licenses/LICENSE-2.0
12+
//
13+
// Unless required by applicable law or agreed to in writing, software
14+
// distributed under the License is distributed on an "AS IS" BASIS,
15+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16+
// See the License for the specific language governing permissions and
17+
// limitations under the License.
18+
//
19+
//===---------------------------------------------------------------------===//
20+
///
21+
/// \file
22+
/// This file defines kernels for elementwise evaluation of CBRT(x)
23+
/// function that computes the cube root.
24+
//===---------------------------------------------------------------------===//
25+
26+
#pragma once
27+
#include <CL/sycl.hpp>
28+
#include <cmath>
29+
#include <cstddef>
30+
#include <cstdint>
31+
#include <type_traits>
32+
33+
#include "kernels/elementwise_functions/common.hpp"
34+
35+
#include "utils/offset_utils.hpp"
36+
#include "utils/type_dispatch.hpp"
37+
#include "utils/type_utils.hpp"
38+
#include <pybind11/pybind11.h>
39+
40+
namespace dpctl
41+
{
42+
namespace tensor
43+
{
44+
namespace kernels
45+
{
46+
namespace cbrt
47+
{
48+
49+
namespace py = pybind11;
50+
namespace td_ns = dpctl::tensor::type_dispatch;
51+
52+
template <typename argT, typename resT> struct CbrtFunctor
53+
{
54+
55+
// is function constant for given argT
56+
using is_constant = typename std::false_type;
57+
// constant value, if constant
58+
// constexpr resT constant_value = resT{};
59+
// is function defined for sycl::vec
60+
using supports_vec = typename std::false_type;
61+
// do both argTy and resTy support sugroup store/load operation
62+
using supports_sg_loadstore = typename std::true_type;
63+
64+
resT operator()(const argT &in) const
65+
{
66+
return sycl::cbrt(in);
67+
}
68+
};
69+
70+
template <typename argTy,
71+
typename resTy = argTy,
72+
unsigned int vec_sz = 4,
73+
unsigned int n_vecs = 2>
74+
using CbrtContigFunctor = elementwise_common::
75+
UnaryContigFunctor<argTy, resTy, CbrtFunctor<argTy, resTy>, vec_sz, n_vecs>;
76+
77+
template <typename argTy, typename resTy, typename IndexerT>
78+
using CbrtStridedFunctor = elementwise_common::
79+
UnaryStridedFunctor<argTy, resTy, IndexerT, CbrtFunctor<argTy, resTy>>;
80+
81+
template <typename T> struct CbrtOutputType
82+
{
83+
using value_type = typename std::disjunction< // disjunction is C++17
84+
// feature, supported by DPC++
85+
td_ns::TypeMapResultEntry<T, sycl::half, sycl::half>,
86+
td_ns::TypeMapResultEntry<T, float, float>,
87+
td_ns::TypeMapResultEntry<T, double, double>,
88+
td_ns::DefaultResultEntry<void>>::result_type;
89+
};
90+
91+
template <typename T1, typename T2, unsigned int vec_sz, unsigned int n_vecs>
92+
class cbrt_contig_kernel;
93+
94+
template <typename argTy>
95+
sycl::event cbrt_contig_impl(sycl::queue &exec_q,
96+
size_t nelems,
97+
const char *arg_p,
98+
char *res_p,
99+
const std::vector<sycl::event> &depends = {})
100+
{
101+
return elementwise_common::unary_contig_impl<
102+
argTy, CbrtOutputType, CbrtContigFunctor, cbrt_contig_kernel>(
103+
exec_q, nelems, arg_p, res_p, depends);
104+
}
105+
106+
template <typename fnT, typename T> struct CbrtContigFactory
107+
{
108+
fnT get()
109+
{
110+
if constexpr (std::is_same_v<typename CbrtOutputType<T>::value_type,
111+
void>) {
112+
fnT fn = nullptr;
113+
return fn;
114+
}
115+
else {
116+
fnT fn = cbrt_contig_impl<T>;
117+
return fn;
118+
}
119+
}
120+
};
121+
122+
template <typename fnT, typename T> struct CbrtTypeMapFactory
123+
{
124+
/*! @brief get typeid for output type of std::cbrt(T x) */
125+
std::enable_if_t<std::is_same<fnT, int>::value, int> get()
126+
{
127+
using rT = typename CbrtOutputType<T>::value_type;
128+
return td_ns::GetTypeid<rT>{}.get();
129+
}
130+
};
131+
132+
template <typename T1, typename T2, typename T3> class cbrt_strided_kernel;
133+
134+
template <typename argTy>
135+
sycl::event
136+
cbrt_strided_impl(sycl::queue &exec_q,
137+
size_t nelems,
138+
int nd,
139+
const py::ssize_t *shape_and_strides,
140+
const char *arg_p,
141+
py::ssize_t arg_offset,
142+
char *res_p,
143+
py::ssize_t res_offset,
144+
const std::vector<sycl::event> &depends,
145+
const std::vector<sycl::event> &additional_depends)
146+
{
147+
return elementwise_common::unary_strided_impl<
148+
argTy, CbrtOutputType, CbrtStridedFunctor, cbrt_strided_kernel>(
149+
exec_q, nelems, nd, shape_and_strides, arg_p, arg_offset, res_p,
150+
res_offset, depends, additional_depends);
151+
}
152+
153+
template <typename fnT, typename T> struct CbrtStridedFactory
154+
{
155+
fnT get()
156+
{
157+
if constexpr (std::is_same_v<typename CbrtOutputType<T>::value_type,
158+
void>) {
159+
fnT fn = nullptr;
160+
return fn;
161+
}
162+
else {
163+
fnT fn = cbrt_strided_impl<T>;
164+
return fn;
165+
}
166+
}
167+
};
168+
169+
} // namespace cbrt
170+
} // namespace kernels
171+
} // namespace tensor
172+
} // namespace dpctl

dpctl/tensor/libtensor/source/elementwise_functions.cpp

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
#include "kernels/elementwise_functions/bitwise_or.hpp"
4949
#include "kernels/elementwise_functions/bitwise_right_shift.hpp"
5050
#include "kernels/elementwise_functions/bitwise_xor.hpp"
51+
#include "kernels/elementwise_functions/cbrt.hpp"
5152
#include "kernels/elementwise_functions/ceil.hpp"
5253
#include "kernels/elementwise_functions/conj.hpp"
5354
#include "kernels/elementwise_functions/cos.hpp"
@@ -2788,6 +2789,41 @@ void populate_hypot_dispatch_tables(void)
27882789

27892790
} // namespace impl
27902791

2792+
// U37: ==== CBRT (x)
2793+
namespace impl
2794+
{
2795+
2796+
namespace cbrt_fn_ns = dpctl::tensor::kernels::cbrt;
2797+
2798+
static unary_contig_impl_fn_ptr_t cbrt_contig_dispatch_vector[td_ns::num_types];
2799+
static int cbrt_output_typeid_vector[td_ns::num_types];
2800+
static unary_strided_impl_fn_ptr_t
2801+
cbrt_strided_dispatch_vector[td_ns::num_types];
2802+
2803+
void populate_cbrt_dispatch_vectors(void)
2804+
{
2805+
using namespace td_ns;
2806+
namespace fn_ns = cbrt_fn_ns;
2807+
2808+
using fn_ns::CbrtContigFactory;
2809+
DispatchVectorBuilder<unary_contig_impl_fn_ptr_t, CbrtContigFactory,
2810+
num_types>
2811+
dvb1;
2812+
dvb1.populate_dispatch_vector(cbrt_contig_dispatch_vector);
2813+
2814+
using fn_ns::CbrtStridedFactory;
2815+
DispatchVectorBuilder<unary_strided_impl_fn_ptr_t, CbrtStridedFactory,
2816+
num_types>
2817+
dvb2;
2818+
dvb2.populate_dispatch_vector(cbrt_strided_dispatch_vector);
2819+
2820+
using fn_ns::CbrtTypeMapFactory;
2821+
DispatchVectorBuilder<int, CbrtTypeMapFactory, num_types> dvb3;
2822+
dvb3.populate_dispatch_vector(cbrt_output_typeid_vector);
2823+
}
2824+
2825+
} // namespace impl
2826+
27912827
// ==========================================================================================
27922828
// //
27932829

@@ -4889,6 +4925,29 @@ void init_elementwise_functions(py::module_ m)
48894925
py::arg("depends") = py::list());
48904926
m.def("_hypot_result_type", hypot_result_type_pyapi, "");
48914927
}
4928+
4929+
// U37: ==== CBRT (x)
4930+
{
    // Fill the cbrt dispatch tables before registering the bindings that
    // read from them.
    impl::populate_cbrt_dispatch_vectors();

    // `_cbrt`: forwards to py_unary_ufunc with the cbrt dispatch tables.
    auto cbrt_pyapi = [&](const arrayT &src, const arrayT &dst,
                          sycl::queue &exec_q,
                          const event_vecT &depends = {}) {
        return py_unary_ufunc(src, dst, exec_q, depends,
                              impl::cbrt_output_typeid_vector,
                              impl::cbrt_contig_dispatch_vector,
                              impl::cbrt_strided_dispatch_vector);
    };
    m.def("_cbrt", cbrt_pyapi, "", py::arg("src"), py::arg("dst"),
          py::arg("sycl_queue"), py::arg("depends") = py::list());

    // `_cbrt_result_type`: forwards to py_unary_ufunc_result_type with the
    // cbrt output type-id table.
    auto cbrt_result_type_pyapi = [&](const py::dtype &dtype) {
        return py_unary_ufunc_result_type(dtype,
                                          impl::cbrt_output_typeid_vector);
    };
    m.def("_cbrt_result_type", cbrt_result_type_pyapi);
}
48924951
}
48934952

48944953
} // namespace py_internal

0 commit comments

Comments
 (0)