Skip to content

Commit 4e2440e

Browse files
committed
Implements elementwise angle
1 parent 7217ffc commit 4e2440e

File tree

7 files changed

+385
-0
lines changed

7 files changed

+385
-0
lines changed

dpctl/tensor/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ set(_elementwise_sources
3737
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/acos.cpp
3838
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/acosh.cpp
3939
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/add.cpp
40+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/angle.cpp
4041
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/asin.cpp
4142
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/asinh.cpp
4243
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/atan.cpp

dpctl/tensor/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@
102102
acos,
103103
acosh,
104104
add,
105+
angle,
105106
asin,
106107
asinh,
107108
atan,
@@ -344,4 +345,5 @@
344345
"__array_api_version__",
345346
"__array_namespace_info__",
346347
"reciprocal",
348+
"angle",
347349
]

dpctl/tensor/_elementwise_funcs.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1917,3 +1917,34 @@
19171917
ti._reciprocal,
19181918
_reciprocal_docstring,
19191919
)
1920+
1921+
1922+
# U43: ==== ANGLE (x)
1923+
_angle_docstring = """
1924+
angle(x, out=None, order='K')
1925+
1926+
Computes the phase angle (also called the argument) of each element `x_i` for
1927+
input array `x`.
1928+
1929+
Args:
1930+
x (usm_ndarray):
1931+
Input array, expected to have a complex-valued floating-point data type.
1932+
out ({None, usm_ndarray}, optional):
1933+
Output array to populate.
1934+
Array have the correct shape and the expected data type.
1935+
order ("C","F","A","K", optional):
1936+
Memory layout of the newly output array, if parameter `out` is `None`.
1937+
Default: "K".
1938+
Returns:
1939+
usm_narray:
1940+
An array containing the element-wise phase angles.
1941+
The returned array has a floating-point data type determined
1942+
by the Type Promotion Rules.
1943+
"""
1944+
1945+
angle = UnaryElementwiseFunc(
1946+
"angle",
1947+
ti._angle_result_type,
1948+
ti._angle,
1949+
_angle_docstring,
1950+
)
Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
//=== angle.hpp - Unary function ANGLE ------
2+
//*-C++-*--/===//
3+
//
4+
// Data Parallel Control (dpctl)
5+
//
6+
// Copyright 2020-2023 Intel Corporation
7+
//
8+
// Licensed under the Apache License, Version 2.0 (the "License");
9+
// you may not use this file except in compliance with the License.
10+
// You may obtain a copy of the License at
11+
//
12+
// http://www.apache.org/licenses/LICENSE-2.0
13+
//
14+
// Unless required by applicable law or agreed to in writing, software
15+
// distributed under the License is distributed on an "AS IS" BASIS,
16+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17+
// See the License for the specific language governing permissions and
18+
// limitations under the License.
19+
//
20+
//===---------------------------------------------------------------------===//
21+
///
22+
/// \file
23+
/// This file defines kernels for elementwise evaluation of ANGLE(x) function.
24+
//===---------------------------------------------------------------------===//
25+
26+
#pragma once
27+
#include <cmath>
28+
#include <complex>
29+
#include <cstddef>
30+
#include <cstdint>
31+
#include <sycl/sycl.hpp>
32+
#include <type_traits>
33+
34+
#include "kernels/elementwise_functions/common.hpp"
35+
#include "sycl_complex.hpp"
36+
37+
#include "utils/offset_utils.hpp"
38+
#include "utils/type_dispatch.hpp"
39+
#include "utils/type_utils.hpp"
40+
#include <pybind11/pybind11.h>
41+
42+
namespace dpctl
43+
{
44+
namespace tensor
45+
{
46+
namespace kernels
47+
{
48+
namespace angle
49+
{
50+
51+
namespace py = pybind11;
52+
namespace td_ns = dpctl::tensor::type_dispatch;
53+
54+
using dpctl::tensor::type_utils::is_complex;
55+
56+
template <typename argT, typename resT> struct AngleFunctor
57+
{
58+
59+
// is function constant for given argT
60+
using is_constant = typename std::false_type;
61+
// constant value, if constant
62+
// constexpr resT constant_value = resT{};
63+
// is function defined for sycl::vec
64+
using supports_vec = typename std::false_type;
65+
// do both argTy and resTy support sugroup store/load operation
66+
using supports_sg_loadstore = typename std::negation<
67+
std::disjunction<is_complex<resT>, is_complex<argT>>>;
68+
69+
resT operator()(const argT &in) const
70+
{
71+
#ifdef USE_SYCL_FOR_COMPLEX_TYPES
72+
using rT = typename argT::value_type;
73+
74+
return exprm_ns::arg(exprm_ns::complex<rT>(in)); // std::arg(in);
75+
#else
76+
return std::arg(in);
77+
#endif
78+
}
79+
};
80+
81+
template <typename argTy,
82+
typename resTy = argTy,
83+
unsigned int vec_sz = 4,
84+
unsigned int n_vecs = 2>
85+
using AngleContigFunctor =
86+
elementwise_common::UnaryContigFunctor<argTy,
87+
resTy,
88+
AngleFunctor<argTy, resTy>,
89+
vec_sz,
90+
n_vecs>;
91+
92+
template <typename argTy, typename resTy, typename IndexerT>
93+
using AngleStridedFunctor = elementwise_common::
94+
UnaryStridedFunctor<argTy, resTy, IndexerT, AngleFunctor<argTy, resTy>>;
95+
96+
template <typename T> struct AngleOutputType
97+
{
98+
using value_type = typename std::disjunction< // disjunction is C++17
99+
// feature, supported by DPC++
100+
td_ns::TypeMapResultEntry<T, std::complex<float>, float>,
101+
td_ns::TypeMapResultEntry<T, std::complex<double>, double>,
102+
td_ns::DefaultResultEntry<void>>::result_type;
103+
};
104+
105+
template <typename T1, typename T2, unsigned int vec_sz, unsigned int n_vecs>
106+
class angle_contig_kernel;
107+
108+
template <typename argTy>
109+
sycl::event angle_contig_impl(sycl::queue &exec_q,
110+
size_t nelems,
111+
const char *arg_p,
112+
char *res_p,
113+
const std::vector<sycl::event> &depends = {})
114+
{
115+
return elementwise_common::unary_contig_impl<
116+
argTy, AngleOutputType, AngleContigFunctor, angle_contig_kernel>(
117+
exec_q, nelems, arg_p, res_p, depends);
118+
}
119+
120+
template <typename fnT, typename T> struct AngleContigFactory
121+
{
122+
fnT get()
123+
{
124+
if constexpr (std::is_same_v<typename AngleOutputType<T>::value_type,
125+
void>) {
126+
fnT fn = nullptr;
127+
return fn;
128+
}
129+
else {
130+
fnT fn = angle_contig_impl<T>;
131+
return fn;
132+
}
133+
}
134+
};
135+
136+
template <typename fnT, typename T> struct AngleTypeMapFactory
137+
{
138+
/*! @brief get typeid for output type of std::arg(T x) */
139+
std::enable_if_t<std::is_same<fnT, int>::value, int> get()
140+
{
141+
using rT = typename AngleOutputType<T>::value_type;
142+
return td_ns::GetTypeid<rT>{}.get();
143+
}
144+
};
145+
146+
template <typename T1, typename T2, typename T3> class angle_strided_kernel;
147+
148+
template <typename argTy>
149+
sycl::event
150+
angle_strided_impl(sycl::queue &exec_q,
151+
size_t nelems,
152+
int nd,
153+
const py::ssize_t *shape_and_strides,
154+
const char *arg_p,
155+
py::ssize_t arg_offset,
156+
char *res_p,
157+
py::ssize_t res_offset,
158+
const std::vector<sycl::event> &depends,
159+
const std::vector<sycl::event> &additional_depends)
160+
{
161+
return elementwise_common::unary_strided_impl<
162+
argTy, AngleOutputType, AngleStridedFunctor, angle_strided_kernel>(
163+
exec_q, nelems, nd, shape_and_strides, arg_p, arg_offset, res_p,
164+
res_offset, depends, additional_depends);
165+
}
166+
167+
template <typename fnT, typename T> struct AngleStridedFactory
168+
{
169+
fnT get()
170+
{
171+
if constexpr (std::is_same_v<typename AngleOutputType<T>::value_type,
172+
void>) {
173+
fnT fn = nullptr;
174+
return fn;
175+
}
176+
else {
177+
fnT fn = angle_strided_impl<T>;
178+
return fn;
179+
}
180+
}
181+
};
182+
183+
} // namespace angle
184+
} // namespace kernels
185+
} // namespace tensor
186+
} // namespace dpctl
Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
//===----------- Implementation of _tensor_impl module ---------*-C++-*-/===//
2+
//
3+
// Data Parallel Control (dpctl)
4+
//
5+
// Copyright 2020-2023 Intel Corporation
6+
//
7+
// Licensed under the Apache License, Version 2.0 (the "License");
8+
// you may not use this file except in compliance with the License.
9+
// You may obtain a copy of the License at
10+
//
11+
// http://www.apache.org/licenses/LICENSE-2.0
12+
//
13+
// Unless required by applicable law or agreed to in writing, software
14+
// distributed under the License is distributed on an "AS IS" BASIS,
15+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16+
// See the License for the specific language governing permissions and
17+
// limitations under the License.
18+
//
19+
//===----------------------------------------------------------------------===//
20+
///
21+
/// \file
22+
/// This file defines functions of dpctl.tensor._tensor_impl extensions,
23+
/// specifically functions for elementwise operations.
24+
//===----------------------------------------------------------------------===//
25+
26+
#include "dpctl4pybind11.hpp"
27+
#include <CL/sycl.hpp>
28+
#include <pybind11/numpy.h>
29+
#include <pybind11/pybind11.h>
30+
#include <pybind11/stl.h>
31+
#include <vector>
32+
33+
#include "angle.hpp"
34+
#include "elementwise_functions.hpp"
35+
#include "utils/type_dispatch.hpp"
36+
37+
#include "kernels/elementwise_functions/angle.hpp"
38+
#include "kernels/elementwise_functions/common.hpp"
39+
40+
namespace py = pybind11;
41+
42+
namespace dpctl
43+
{
44+
namespace tensor
45+
{
46+
namespace py_internal
47+
{
48+
49+
namespace td_ns = dpctl::tensor::type_dispatch;
50+
51+
namespace ew_cmn_ns = dpctl::tensor::kernels::elementwise_common;
52+
using ew_cmn_ns::unary_contig_impl_fn_ptr_t;
53+
using ew_cmn_ns::unary_strided_impl_fn_ptr_t;
54+
55+
// U43: ==== ANGLE (x)
56+
namespace impl
57+
{
58+
59+
namespace angle_fn_ns = dpctl::tensor::kernels::angle;
60+
61+
static unary_contig_impl_fn_ptr_t
62+
angle_contig_dispatch_vector[td_ns::num_types];
63+
static int angle_output_typeid_vector[td_ns::num_types];
64+
static unary_strided_impl_fn_ptr_t
65+
angle_strided_dispatch_vector[td_ns::num_types];
66+
67+
void populate_angle_dispatch_vectors(void)
68+
{
69+
using namespace td_ns;
70+
namespace fn_ns = angle_fn_ns;
71+
72+
using fn_ns::AngleContigFactory;
73+
DispatchVectorBuilder<unary_contig_impl_fn_ptr_t, AngleContigFactory,
74+
num_types>
75+
dvb1;
76+
dvb1.populate_dispatch_vector(angle_contig_dispatch_vector);
77+
78+
using fn_ns::AngleStridedFactory;
79+
DispatchVectorBuilder<unary_strided_impl_fn_ptr_t, AngleStridedFactory,
80+
num_types>
81+
dvb2;
82+
dvb2.populate_dispatch_vector(angle_strided_dispatch_vector);
83+
84+
using fn_ns::AngleTypeMapFactory;
85+
DispatchVectorBuilder<int, AngleTypeMapFactory, num_types> dvb3;
86+
dvb3.populate_dispatch_vector(angle_output_typeid_vector);
87+
};
88+
89+
} // namespace impl
90+
91+
void init_angle(py::module_ m)
92+
{
93+
using arrayT = dpctl::tensor::usm_ndarray;
94+
using event_vecT = std::vector<sycl::event>;
95+
{
96+
impl::populate_angle_dispatch_vectors();
97+
using impl::angle_contig_dispatch_vector;
98+
using impl::angle_output_typeid_vector;
99+
using impl::angle_strided_dispatch_vector;
100+
101+
auto angle_pyapi = [&](const arrayT &src, const arrayT &dst,
102+
sycl::queue &exec_q,
103+
const event_vecT &depends = {}) {
104+
return py_unary_ufunc(
105+
src, dst, exec_q, depends, angle_output_typeid_vector,
106+
angle_contig_dispatch_vector, angle_strided_dispatch_vector);
107+
};
108+
m.def("_angle", angle_pyapi, "", py::arg("src"), py::arg("dst"),
109+
py::arg("sycl_queue"), py::arg("depends") = py::list());
110+
111+
auto angle_result_type_pyapi = [&](const py::dtype &dtype) {
112+
return py_unary_ufunc_result_type(dtype,
113+
angle_output_typeid_vector);
114+
};
115+
m.def("_angle_result_type", angle_result_type_pyapi);
116+
}
117+
}
118+
119+
} // namespace py_internal
120+
} // namespace tensor
121+
} // namespace dpctl

0 commit comments

Comments
 (0)