Skip to content

Commit f82cdc4

Browse files
Improve dpnp.cos() and dpnp.sin() implementations (#1471)
* Improve dpnp.cos() and dpnp.sin() implementations * Update dpnp/backend/extensions/vm/vm_py.cpp Co-authored-by: vlad-perevezentsev <[email protected]> * Update dpnp/backend/extensions/vm/vm_py.cpp Co-authored-by: vlad-perevezentsev <[email protected]> --------- Co-authored-by: vlad-perevezentsev <[email protected]>
1 parent f49da0e commit f82cdc4

File tree

15 files changed

+580
-105
lines changed

15 files changed

+580
-105
lines changed

.github/workflows/conda-package.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ env:
2323
test_umath.py
2424
test_usm_type.py
2525
third_party/cupy/math_tests/test_explog.py
26+
third_party/cupy/math_tests/test_trigonometric.py
2627
third_party/cupy/sorting_tests/test_sort.py
2728
VER_JSON_NAME: 'version.json'
2829
VER_SCRIPT1: "import json; f = open('version.json', 'r'); j = json.load(f); f.close(); "

dpnp/backend/extensions/vm/cos.hpp

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2023, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#pragma once
27+
28+
#include <CL/sycl.hpp>
29+
30+
#include "common.hpp"
31+
#include "types_matrix.hpp"
32+
33+
namespace dpnp
34+
{
35+
namespace backend
36+
{
37+
namespace ext
38+
{
39+
namespace vm
40+
{
41+
template <typename T>
42+
sycl::event cos_contig_impl(sycl::queue exec_q,
43+
const std::int64_t n,
44+
const char *in_a,
45+
char *out_y,
46+
const std::vector<sycl::event> &depends)
47+
{
48+
type_utils::validate_type_for_device<T>(exec_q);
49+
50+
const T *a = reinterpret_cast<const T *>(in_a);
51+
T *y = reinterpret_cast<T *>(out_y);
52+
53+
return mkl_vm::cos(exec_q,
54+
n, // number of elements to be calculated
55+
a, // pointer `a` containing input vector of size n
56+
y, // pointer `y` to the output vector of size n
57+
depends);
58+
}
59+
60+
template <typename fnT, typename T>
61+
struct CosContigFactory
62+
{
63+
fnT get()
64+
{
65+
if constexpr (std::is_same_v<
66+
typename types::CosOutputType<T>::value_type, void>)
67+
{
68+
return nullptr;
69+
}
70+
else {
71+
return cos_contig_impl<T>;
72+
}
73+
}
74+
};
75+
} // namespace vm
76+
} // namespace ext
77+
} // namespace backend
78+
} // namespace dpnp

dpnp/backend/extensions/vm/sin.hpp

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2023, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#pragma once
27+
28+
#include <CL/sycl.hpp>
29+
30+
#include "common.hpp"
31+
#include "types_matrix.hpp"
32+
33+
namespace dpnp
34+
{
35+
namespace backend
36+
{
37+
namespace ext
38+
{
39+
namespace vm
40+
{
41+
template <typename T>
42+
sycl::event sin_contig_impl(sycl::queue exec_q,
43+
const std::int64_t n,
44+
const char *in_a,
45+
char *out_y,
46+
const std::vector<sycl::event> &depends)
47+
{
48+
type_utils::validate_type_for_device<T>(exec_q);
49+
50+
const T *a = reinterpret_cast<const T *>(in_a);
51+
T *y = reinterpret_cast<T *>(out_y);
52+
53+
return mkl_vm::sin(exec_q,
54+
n, // number of elements to be calculated
55+
a, // pointer `a` containing input vector of size n
56+
y, // pointer `y` to the output vector of size n
57+
depends);
58+
}
59+
60+
template <typename fnT, typename T>
61+
struct SinContigFactory
62+
{
63+
fnT get()
64+
{
65+
if constexpr (std::is_same_v<
66+
typename types::SinOutputType<T>::value_type, void>)
67+
{
68+
return nullptr;
69+
}
70+
else {
71+
return sin_contig_impl<T>;
72+
}
73+
}
74+
};
75+
} // namespace vm
76+
} // namespace ext
77+
} // namespace backend
78+
} // namespace dpnp

dpnp/backend/extensions/vm/types_matrix.hpp

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,25 @@ struct DivOutputType
6868
dpctl_td_ns::DefaultResultEntry<void>>::result_type;
6969
};
7070

71+
/**
72+
* @brief A factory to define pairs of supported types for which
73+
* MKL VM library provides support in oneapi::mkl::vm::cos<T> function.
74+
*
75+
* @tparam T Type of input vector `a` and of result vector `y`.
76+
*/
77+
template <typename T>
78+
struct CosOutputType
79+
{
80+
using value_type = typename std::disjunction<
81+
dpctl_td_ns::
82+
TypeMapResultEntry<T, std::complex<double>, std::complex<double>>,
83+
dpctl_td_ns::
84+
TypeMapResultEntry<T, std::complex<float>, std::complex<float>>,
85+
dpctl_td_ns::TypeMapResultEntry<T, double, double>,
86+
dpctl_td_ns::TypeMapResultEntry<T, float, float>,
87+
dpctl_td_ns::DefaultResultEntry<void>>::result_type;
88+
};
89+
7190
/**
7291
* @brief A factory to define pairs of supported types for which
7392
* MKL VM library provides support in oneapi::mkl::vm::ln<T> function.
@@ -86,6 +105,25 @@ struct LnOutputType
86105
dpctl_td_ns::TypeMapResultEntry<T, float, float>,
87106
dpctl_td_ns::DefaultResultEntry<void>>::result_type;
88107
};
108+
109+
/**
110+
* @brief A factory to define pairs of supported types for which
111+
* MKL VM library provides support in oneapi::mkl::vm::sin<T> function.
112+
*
113+
* @tparam T Type of input vector `a` and of result vector `y`.
114+
*/
115+
template <typename T>
116+
struct SinOutputType
117+
{
118+
using value_type = typename std::disjunction<
119+
dpctl_td_ns::
120+
TypeMapResultEntry<T, std::complex<double>, std::complex<double>>,
121+
dpctl_td_ns::
122+
TypeMapResultEntry<T, std::complex<float>, std::complex<float>>,
123+
dpctl_td_ns::TypeMapResultEntry<T, double, double>,
124+
dpctl_td_ns::TypeMapResultEntry<T, float, float>,
125+
dpctl_td_ns::DefaultResultEntry<void>>::result_type;
126+
};
89127
} // namespace types
90128
} // namespace vm
91129
} // namespace ext

dpnp/backend/extensions/vm/vm_py.cpp

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,10 @@
3131
#include <pybind11/stl.h>
3232

3333
#include "common.hpp"
34+
#include "cos.hpp"
3435
#include "div.hpp"
3536
#include "ln.hpp"
37+
#include "sin.hpp"
3638
#include "types_matrix.hpp"
3739

3840
namespace py = pybind11;
@@ -43,7 +45,9 @@ using vm_ext::unary_impl_fn_ptr_t;
4345

4446
static binary_impl_fn_ptr_t div_dispatch_vector[dpctl_td_ns::num_types];
4547

48+
static unary_impl_fn_ptr_t cos_dispatch_vector[dpctl_td_ns::num_types];
4649
static unary_impl_fn_ptr_t ln_dispatch_vector[dpctl_td_ns::num_types];
50+
static unary_impl_fn_ptr_t sin_dispatch_vector[dpctl_td_ns::num_types];
4751

4852
PYBIND11_MODULE(_vm_impl, m)
4953
{
@@ -80,6 +84,34 @@ PYBIND11_MODULE(_vm_impl, m)
8084
py::arg("dst"));
8185
}
8286

87+
// UnaryUfunc: ==== Cos(x) ====
88+
{
89+
vm_ext::init_ufunc_dispatch_vector<unary_impl_fn_ptr_t,
90+
vm_ext::CosContigFactory>(
91+
cos_dispatch_vector);
92+
93+
auto cos_pyapi = [&](sycl::queue exec_q, arrayT src, arrayT dst,
94+
const event_vecT &depends = {}) {
95+
return vm_ext::unary_ufunc(exec_q, src, dst, depends,
96+
cos_dispatch_vector);
97+
};
98+
m.def("_cos", cos_pyapi,
99+
"Call `cos` function from OneMKL VM library to compute "
100+
"cosine of vector elements",
101+
py::arg("sycl_queue"), py::arg("src"), py::arg("dst"),
102+
py::arg("depends") = py::list());
103+
104+
auto cos_need_to_call_pyapi = [&](sycl::queue exec_q, arrayT src,
105+
arrayT dst) {
106+
return vm_ext::need_to_call_unary_ufunc(exec_q, src, dst,
107+
cos_dispatch_vector);
108+
};
109+
m.def("_mkl_cos_to_call", cos_need_to_call_pyapi,
110+
"Check input arguments to answer if `cos` function from "
111+
"OneMKL VM library can be used",
112+
py::arg("sycl_queue"), py::arg("src"), py::arg("dst"));
113+
}
114+
83115
// UnaryUfunc: ==== Ln(x) ====
84116
{
85117
vm_ext::init_ufunc_dispatch_vector<unary_impl_fn_ptr_t,
@@ -107,4 +139,32 @@ PYBIND11_MODULE(_vm_impl, m)
107139
"OneMKL VM library can be used",
108140
py::arg("sycl_queue"), py::arg("src"), py::arg("dst"));
109141
}
142+
143+
// UnaryUfunc: ==== Sin(x) ====
144+
{
145+
vm_ext::init_ufunc_dispatch_vector<unary_impl_fn_ptr_t,
146+
vm_ext::SinContigFactory>(
147+
sin_dispatch_vector);
148+
149+
auto sin_pyapi = [&](sycl::queue exec_q, arrayT src, arrayT dst,
150+
const event_vecT &depends = {}) {
151+
return vm_ext::unary_ufunc(exec_q, src, dst, depends,
152+
sin_dispatch_vector);
153+
};
154+
m.def("_sin", sin_pyapi,
155+
"Call `sin` function from OneMKL VM library to compute "
156+
"sine of vector elements",
157+
py::arg("sycl_queue"), py::arg("src"), py::arg("dst"),
158+
py::arg("depends") = py::list());
159+
160+
auto sin_need_to_call_pyapi = [&](sycl::queue exec_q, arrayT src,
161+
arrayT dst) {
162+
return vm_ext::need_to_call_unary_ufunc(exec_q, src, dst,
163+
sin_dispatch_vector);
164+
};
165+
m.def("_mkl_sin_to_call", sin_need_to_call_pyapi,
166+
"Check input arguments to answer if `sin` function from "
167+
"OneMKL VM library can be used",
168+
py::arg("sycl_queue"), py::arg("src"), py::arg("dst"));
169+
}
110170
}

dpnp/backend/include/dpnp_iface_fptr.hpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -145,8 +145,7 @@ enum class DPNPFuncName : size_t
145145
DPNP_FN_CORRELATE_EXT, /**< Used in numpy.correlate() impl, requires extra
146146
parameters */
147147
DPNP_FN_COS, /**< Used in numpy.cos() impl */
148-
DPNP_FN_COS_EXT, /**< Used in numpy.cos() impl, requires extra parameters */
149-
DPNP_FN_COSH, /**< Used in numpy.cosh() impl */
148+
DPNP_FN_COSH, /**< Used in numpy.cosh() impl */
150149
DPNP_FN_COSH_EXT, /**< Used in numpy.cosh() impl, requires extra parameters
151150
*/
152151
DPNP_FN_COUNT_NONZERO, /**< Used in numpy.count_nonzero() impl */
@@ -475,8 +474,7 @@ enum class DPNPFuncName : size_t
475474
DPNP_FN_SIGN_EXT, /**< Used in numpy.sign() impl, requires extra parameters
476475
*/
477476
DPNP_FN_SIN, /**< Used in numpy.sin() impl */
478-
DPNP_FN_SIN_EXT, /**< Used in numpy.sin() impl, requires extra parameters */
479-
DPNP_FN_SINH, /**< Used in numpy.sinh() impl */
477+
DPNP_FN_SINH, /**< Used in numpy.sinh() impl */
480478
DPNP_FN_SINH_EXT, /**< Used in numpy.sinh() impl, requires extra parameters
481479
*/
482480
DPNP_FN_SORT, /**< Used in numpy.sort() impl */

dpnp/backend/kernels/dpnp_krnl_elemwise.cpp

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -486,15 +486,6 @@ static void func_map_init_elemwise_1arg_2type(func_map_t &fmap)
486486
fmap[DPNPFuncName::DPNP_FN_COS][eft_DBL][eft_DBL] = {
487487
eft_DBL, (void *)dpnp_cos_c_default<double, double>};
488488

489-
fmap[DPNPFuncName::DPNP_FN_COS_EXT][eft_INT][eft_INT] = {
490-
eft_DBL, (void *)dpnp_cos_c_ext<int32_t, double>};
491-
fmap[DPNPFuncName::DPNP_FN_COS_EXT][eft_LNG][eft_LNG] = {
492-
eft_DBL, (void *)dpnp_cos_c_ext<int64_t, double>};
493-
fmap[DPNPFuncName::DPNP_FN_COS_EXT][eft_FLT][eft_FLT] = {
494-
eft_FLT, (void *)dpnp_cos_c_ext<float, float>};
495-
fmap[DPNPFuncName::DPNP_FN_COS_EXT][eft_DBL][eft_DBL] = {
496-
eft_DBL, (void *)dpnp_cos_c_ext<double, double>};
497-
498489
fmap[DPNPFuncName::DPNP_FN_COSH][eft_INT][eft_INT] = {
499490
eft_DBL, (void *)dpnp_cosh_c_default<int32_t, double>};
500491
fmap[DPNPFuncName::DPNP_FN_COSH][eft_LNG][eft_LNG] = {
@@ -711,15 +702,6 @@ static void func_map_init_elemwise_1arg_2type(func_map_t &fmap)
711702
fmap[DPNPFuncName::DPNP_FN_SIN][eft_DBL][eft_DBL] = {
712703
eft_DBL, (void *)dpnp_sin_c_default<double, double>};
713704

714-
fmap[DPNPFuncName::DPNP_FN_SIN_EXT][eft_INT][eft_INT] = {
715-
eft_DBL, (void *)dpnp_sin_c_ext<int32_t, double>};
716-
fmap[DPNPFuncName::DPNP_FN_SIN_EXT][eft_LNG][eft_LNG] = {
717-
eft_DBL, (void *)dpnp_sin_c_ext<int64_t, double>};
718-
fmap[DPNPFuncName::DPNP_FN_SIN_EXT][eft_FLT][eft_FLT] = {
719-
eft_FLT, (void *)dpnp_sin_c_ext<float, float>};
720-
fmap[DPNPFuncName::DPNP_FN_SIN_EXT][eft_DBL][eft_DBL] = {
721-
eft_DBL, (void *)dpnp_sin_c_ext<double, double>};
722-
723705
fmap[DPNPFuncName::DPNP_FN_SINH][eft_INT][eft_INT] = {
724706
eft_DBL, (void *)dpnp_sinh_c_default<int32_t, double>};
725707
fmap[DPNPFuncName::DPNP_FN_SINH][eft_LNG][eft_LNG] = {

dpnp/dpnp_algo/dpnp_algo.pxd

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -90,8 +90,6 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na
9090
DPNP_FN_COPYTO_EXT
9191
DPNP_FN_CORRELATE
9292
DPNP_FN_CORRELATE_EXT
93-
DPNP_FN_COS
94-
DPNP_FN_COS_EXT
9593
DPNP_FN_COSH
9694
DPNP_FN_COSH_EXT
9795
DPNP_FN_COUNT_NONZERO
@@ -293,8 +291,6 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na
293291
DPNP_FN_SEARCHSORTED_EXT
294292
DPNP_FN_SIGN
295293
DPNP_FN_SIGN_EXT
296-
DPNP_FN_SIN
297-
DPNP_FN_SIN_EXT
298294
DPNP_FN_SINH
299295
DPNP_FN_SINH_EXT
300296
DPNP_FN_SORT
@@ -546,7 +542,6 @@ cpdef dpnp_descriptor dpnp_arcsinh(dpnp_descriptor array1)
546542
cpdef dpnp_descriptor dpnp_arctan(dpnp_descriptor array1, dpnp_descriptor out)
547543
cpdef dpnp_descriptor dpnp_arctanh(dpnp_descriptor array1)
548544
cpdef dpnp_descriptor dpnp_cbrt(dpnp_descriptor array1)
549-
cpdef dpnp_descriptor dpnp_cos(dpnp_descriptor array1, dpnp_descriptor out)
550545
cpdef dpnp_descriptor dpnp_cosh(dpnp_descriptor array1)
551546
cpdef dpnp_descriptor dpnp_degrees(dpnp_descriptor array1)
552547
cpdef dpnp_descriptor dpnp_exp(dpnp_descriptor array1, dpnp_descriptor out)
@@ -557,7 +552,6 @@ cpdef dpnp_descriptor dpnp_log1p(dpnp_descriptor array1)
557552
cpdef dpnp_descriptor dpnp_log2(dpnp_descriptor array1)
558553
cpdef dpnp_descriptor dpnp_radians(dpnp_descriptor array1)
559554
cpdef dpnp_descriptor dpnp_recip(dpnp_descriptor array1)
560-
cpdef dpnp_descriptor dpnp_sin(dpnp_descriptor array1, dpnp_descriptor out)
561555
cpdef dpnp_descriptor dpnp_sinh(dpnp_descriptor array1)
562556
cpdef dpnp_descriptor dpnp_sqrt(dpnp_descriptor array1, dpnp_descriptor out)
563557
cpdef dpnp_descriptor dpnp_square(dpnp_descriptor array1)

0 commit comments

Comments
 (0)