Skip to content

Commit 821a888

Browse files
committed
Implements expm1, log, and log1p
- Also adds tests for real and complex variants where appropriate
1 parent 521867b commit 821a888

File tree

9 files changed

+1327
-9
lines changed

9 files changed

+1327
-9
lines changed

dpctl/tensor/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,9 +97,12 @@
9797
cos,
9898
divide,
9999
equal,
100+
expm1,
100101
isfinite,
101102
isinf,
102103
isnan,
104+
log,
105+
log1p,
103106
multiply,
104107
sqrt,
105108
subtract,
@@ -184,9 +187,12 @@
184187
"abs",
185188
"add",
186189
"cos",
190+
"expm1",
187191
"isinf",
188192
"isnan",
189193
"isfinite",
194+
"log",
195+
"log1p",
190196
"sqrt",
191197
"divide",
192198
"multiply",

dpctl/tensor/_elementwise_funcs.py

Lines changed: 58 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,26 @@
153153
# FIXME: implement U13
154154

155155
# U14: ==== EXPM1 (x)
156-
# FIXME: implement U14
156+
_expm1_docstring = """
157+
expm1(x, out=None, order='K')
158+
Computes an approximation of exp(x)-1 element-wise.
159+
Args:
160+
x (usm_ndarray):
161+
Input array, expected to have numeric data type.
162+
out (usm_ndarray):
163+
Output array to populate. Array must have the correct
164+
shape and the expected data type.
165+
order ("C","F","A","K", optional): memory layout of the new
166+
output array, if parameter `out` is `None`.
167+
Default: "K".
168+
Return:
169+
usm_ndarray:
170+
An array containing the element-wise exp(x)-1 values.
171+
"""
172+
173+
expm1 = UnaryElementwiseFunc(
174+
"expm1", ti._expm1_result_type, ti._expm1, _expm1_docstring
175+
)
157176

158177
# U15: ==== FLOOR (x)
159178
# FIXME: implement U15
@@ -210,10 +229,46 @@
210229
# FIXME: implement B14
211230

212231
# U20: ==== LOG (x)
213-
# FIXME: implement U20
232+
_log_docstring = """
233+
log(x, out=None, order='K')
234+
Computes the natural logarithm element-wise.
235+
Args:
236+
x (usm_ndarray):
237+
Input array, expected to have numeric data type.
238+
out (usm_ndarray):
239+
Output array to populate. Array must have the correct
240+
shape and the expected data type.
241+
order ("C","F","A","K", optional): memory layout of the new
242+
output array, if parameter `out` is `None`.
243+
Default: "K".
244+
Return:
245+
usm_ndarray:
246+
An array containing the element-wise natural logarithm values.
247+
"""
248+
249+
log = UnaryElementwiseFunc("log", ti._log_result_type, ti._log, _log_docstring)
214250

215251
# U21: ==== LOG1P (x)
216-
# FIXME: implement U21
252+
_log1p_docstring = """
253+
log1p(x, out=None, order='K')
254+
Computes an approximation of log(1+x) element-wise.
255+
Args:
256+
x (usm_ndarray):
257+
Input array, expected to have numeric data type.
258+
out (usm_ndarray):
259+
Output array to populate. Array must have the correct
260+
shape and the expected data type.
261+
order ("C","F","A","K", optional): memory layout of the new
262+
output array, if parameter `out` is `None`.
263+
Default: "K".
264+
Return:
265+
usm_ndarray:
266+
An array containing the element-wise log(1+x) values.
267+
"""
268+
269+
log1p = UnaryElementwiseFunc(
270+
"log1p", ti._log1p_result_type, ti._log1p, _log1p_docstring
271+
)
217272

218273
# U22: ==== LOG2 (x)
219274
# FIXME: implement U22
Lines changed: 259 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,259 @@
1+
//=== expm1.hpp - Unary function EXPM1 ------
2+
//*-C++-*--/===//
3+
//
4+
// Data Parallel Control (dpctl)
5+
//
6+
// Copyright 2020-2023 Intel Corporation
7+
//
8+
// Licensed under the Apache License, Version 2.0 (the "License");
9+
// you may not use this file except in compliance with the License.
10+
// You may obtain a copy of the License at
11+
//
12+
// http://www.apache.org/licenses/LICENSE-2.0
13+
//
14+
// Unless required by applicable law or agreed to in writing, software
15+
// distributed under the License is distributed on an "AS IS" BASIS,
16+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17+
// See the License for the specific language governing permissions and
18+
// limitations under the License.
19+
//
20+
//===---------------------------------------------------------------------===//
21+
///
22+
/// \file
23+
/// This file defines kernels for elementwise evaluation of EXPM1(x) function.
24+
//===---------------------------------------------------------------------===//
25+
26+
#pragma once
27+
#include <CL/sycl.hpp>
28+
#include <cmath>
29+
#include <cstddef>
30+
#include <cstdint>
31+
#include <type_traits>
32+
33+
#include "kernels/elementwise_functions/common.hpp"
34+
35+
#include "utils/offset_utils.hpp"
36+
#include "utils/type_dispatch.hpp"
37+
#include "utils/type_utils.hpp"
38+
#include <pybind11/pybind11.h>
39+
40+
namespace dpctl
41+
{
42+
namespace tensor
43+
{
44+
namespace kernels
45+
{
46+
namespace expm1
47+
{
48+
49+
namespace py = pybind11;
50+
namespace td_ns = dpctl::tensor::type_dispatch;
51+
52+
using dpctl::tensor::type_utils::is_complex;
53+
54+
template <typename argT, typename resT> struct Expm1Functor
55+
{
56+
57+
// is function constant for given argT
58+
using is_constant = typename std::false_type;
59+
// constant value, if constant
60+
// constexpr resT constant_value = resT{};
61+
// is function defined for sycl::vec
62+
using supports_vec = typename std::false_type;
63+
// do both argTy and resTy support sugroup store/load operation
64+
using supports_sg_loadstore = typename std::negation<
65+
std::disjunction<is_complex<resT>, is_complex<argT>>>;
66+
67+
resT operator()(const argT &in)
68+
{
69+
if constexpr (is_complex<argT>::value) {
70+
using realT = typename argT::value_type;
71+
// expm1(x + I*y) = expm1(x)*cos(y) - 2*sin(y / 2)^2 +
72+
// I*exp(x)*sin(y)
73+
auto x = std::real(in);
74+
const realT expm1X_val = std::expm1(x);
75+
const realT expX_val = std::exp(x);
76+
77+
x = std::imag(in);
78+
realT cosY_val;
79+
const realT sinY_val = sycl::sincos(x, &cosY_val);
80+
const realT sinhalfY_val = std::sin(x / realT{2});
81+
82+
const realT res_re =
83+
expm1X_val * cosY_val - realT{2} * sinhalfY_val * sinhalfY_val;
84+
const realT res_im = expX_val * sinY_val;
85+
return resT{res_re, res_im};
86+
}
87+
else {
88+
static_assert(std::is_floating_point_v<argT> ||
89+
std::is_same_v<argT, sycl::half>);
90+
return std::expm1(in);
91+
}
92+
}
93+
};
94+
95+
template <typename argTy,
96+
typename resTy = argTy,
97+
unsigned int vec_sz = 4,
98+
unsigned int n_vecs = 2>
99+
using Expm1ContigFunctor =
100+
elementwise_common::UnaryContigFunctor<argTy,
101+
resTy,
102+
Expm1Functor<argTy, resTy>,
103+
vec_sz,
104+
n_vecs>;
105+
106+
template <typename argTy, typename resTy, typename IndexerT>
107+
using Expm1StridedFunctor = elementwise_common::
108+
UnaryStridedFunctor<argTy, resTy, IndexerT, Expm1Functor<argTy, resTy>>;
109+
110+
template <typename T> struct Expm1OutputType
111+
{
112+
using value_type = typename std::disjunction< // disjunction is C++17
113+
// feature, supported by DPC++
114+
td_ns::TypeMapResultEntry<T, sycl::half, sycl::half>,
115+
td_ns::TypeMapResultEntry<T, float, float>,
116+
td_ns::TypeMapResultEntry<T, double, double>,
117+
td_ns::TypeMapResultEntry<T, std::complex<float>, std::complex<float>>,
118+
td_ns::
119+
TypeMapResultEntry<T, std::complex<double>, std::complex<double>>,
120+
td_ns::DefaultResultEntry<void>>::result_type;
121+
};
122+
123+
typedef sycl::event (*expm1_contig_impl_fn_ptr_t)(
124+
sycl::queue,
125+
size_t,
126+
const char *,
127+
char *,
128+
const std::vector<sycl::event> &);
129+
130+
template <typename T1, typename T2, unsigned int vec_sz, unsigned int n_vecs>
131+
class expm1_contig_kernel;
132+
133+
template <typename argTy>
134+
sycl::event expm1_contig_impl(sycl::queue exec_q,
135+
size_t nelems,
136+
const char *arg_p,
137+
char *res_p,
138+
const std::vector<sycl::event> &depends = {})
139+
{
140+
sycl::event expm1_ev = exec_q.submit([&](sycl::handler &cgh) {
141+
cgh.depends_on(depends);
142+
constexpr size_t lws = 64;
143+
constexpr unsigned int vec_sz = 4;
144+
constexpr unsigned int n_vecs = 2;
145+
static_assert(lws % vec_sz == 0);
146+
auto gws_range = sycl::range<1>(
147+
((nelems + n_vecs * lws * vec_sz - 1) / (lws * n_vecs * vec_sz)) *
148+
lws);
149+
auto lws_range = sycl::range<1>(lws);
150+
151+
using resTy = typename Expm1OutputType<argTy>::value_type;
152+
const argTy *arg_tp = reinterpret_cast<const argTy *>(arg_p);
153+
resTy *res_tp = reinterpret_cast<resTy *>(res_p);
154+
155+
cgh.parallel_for<
156+
class expm1_contig_kernel<argTy, resTy, vec_sz, n_vecs>>(
157+
sycl::nd_range<1>(gws_range, lws_range),
158+
Expm1ContigFunctor<argTy, resTy, vec_sz, n_vecs>(arg_tp, res_tp,
159+
nelems));
160+
});
161+
return expm1_ev;
162+
}
163+
164+
template <typename fnT, typename T> struct Expm1ContigFactory
165+
{
166+
fnT get()
167+
{
168+
if constexpr (std::is_same_v<typename Expm1OutputType<T>::value_type,
169+
void>) {
170+
fnT fn = nullptr;
171+
return fn;
172+
}
173+
else {
174+
fnT fn = expm1_contig_impl<T>;
175+
return fn;
176+
}
177+
}
178+
};
179+
180+
template <typename fnT, typename T> struct Expm1TypeMapFactory
181+
{
182+
/*! @brief get typeid for output type of std::expm1(T x) */
183+
std::enable_if_t<std::is_same<fnT, int>::value, int> get()
184+
{
185+
using rT = typename Expm1OutputType<T>::value_type;
186+
;
187+
return td_ns::GetTypeid<rT>{}.get();
188+
}
189+
};
190+
191+
template <typename T1, typename T2, typename T3> class expm1_strided_kernel;
192+
193+
typedef sycl::event (*expm1_strided_impl_fn_ptr_t)(
194+
sycl::queue,
195+
size_t,
196+
int,
197+
const py::ssize_t *,
198+
const char *,
199+
py::ssize_t,
200+
char *,
201+
py::ssize_t,
202+
const std::vector<sycl::event> &,
203+
const std::vector<sycl::event> &);
204+
205+
template <typename argTy>
206+
sycl::event
207+
expm1_strided_impl(sycl::queue exec_q,
208+
size_t nelems,
209+
int nd,
210+
const py::ssize_t *shape_and_strides,
211+
const char *arg_p,
212+
py::ssize_t arg_offset,
213+
char *res_p,
214+
py::ssize_t res_offset,
215+
const std::vector<sycl::event> &depends,
216+
const std::vector<sycl::event> &additional_depends)
217+
{
218+
sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
219+
cgh.depends_on(depends);
220+
cgh.depends_on(additional_depends);
221+
222+
using resTy = typename Expm1OutputType<argTy>::value_type;
223+
using IndexerT =
224+
typename dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer;
225+
226+
IndexerT arg_res_indexer(nd, arg_offset, res_offset, shape_and_strides);
227+
228+
const argTy *arg_tp = reinterpret_cast<const argTy *>(arg_p);
229+
resTy *res_tp = reinterpret_cast<resTy *>(res_p);
230+
231+
sycl::range<1> gRange{nelems};
232+
233+
cgh.parallel_for<expm1_strided_kernel<argTy, resTy, IndexerT>>(
234+
gRange, Expm1StridedFunctor<argTy, resTy, IndexerT>(
235+
arg_tp, res_tp, arg_res_indexer));
236+
});
237+
return comp_ev;
238+
}
239+
240+
template <typename fnT, typename T> struct Expm1StridedFactory
241+
{
242+
fnT get()
243+
{
244+
if constexpr (std::is_same_v<typename Expm1OutputType<T>::value_type,
245+
void>) {
246+
fnT fn = nullptr;
247+
return fn;
248+
}
249+
else {
250+
fnT fn = expm1_strided_impl<T>;
251+
return fn;
252+
}
253+
}
254+
};
255+
256+
} // namespace expm1
257+
} // namespace kernels
258+
} // namespace tensor
259+
} // namespace dpctl

0 commit comments

Comments
 (0)