Skip to content

Commit ed2815d

Browse files
committed
address comments
1 parent 0af2080 commit ed2815d

File tree

5 files changed

+79
-62
lines changed

5 files changed

+79
-62
lines changed

dpnp/backend/extensions/vm/vm_py.cpp

Lines changed: 28 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -90,34 +90,6 @@ PYBIND11_MODULE(_vm_impl, m)
9090
py::arg("dst"));
9191
}
9292

93-
// UnaryUfunc: ==== Cos(x) ====
94-
{
95-
vm_ext::init_ufunc_dispatch_vector<unary_impl_fn_ptr_t,
96-
vm_ext::CosContigFactory>(
97-
cos_dispatch_vector);
98-
99-
auto cos_pyapi = [&](sycl::queue exec_q, arrayT src, arrayT dst,
100-
const event_vecT &depends = {}) {
101-
return vm_ext::unary_ufunc(exec_q, src, dst, depends,
102-
cos_dispatch_vector);
103-
};
104-
m.def("_cos", cos_pyapi,
105-
"Call `cos` function from OneMKL VM library to compute "
106-
"cosine of vector elements",
107-
py::arg("sycl_queue"), py::arg("src"), py::arg("dst"),
108-
py::arg("depends") = py::list());
109-
110-
auto cos_need_to_call_pyapi = [&](sycl::queue exec_q, arrayT src,
111-
arrayT dst) {
112-
return vm_ext::need_to_call_unary_ufunc(exec_q, src, dst,
113-
cos_dispatch_vector);
114-
};
115-
m.def("_mkl_cos_to_call", cos_need_to_call_pyapi,
116-
"Check input arguments to answer if `cos` function from "
117-
"OneMKL VM library can be used",
118-
py::arg("sycl_queue"), py::arg("src"), py::arg("dst"));
119-
}
120-
12193
// UnaryUfunc: ==== Conj(x) ====
12294
{
12395
vm_ext::init_ufunc_dispatch_vector<unary_impl_fn_ptr_t,
@@ -146,6 +118,34 @@ PYBIND11_MODULE(_vm_impl, m)
146118
py::arg("sycl_queue"), py::arg("src"), py::arg("dst"));
147119
}
148120

121+
// UnaryUfunc: ==== Cos(x) ====
122+
{
123+
vm_ext::init_ufunc_dispatch_vector<unary_impl_fn_ptr_t,
124+
vm_ext::CosContigFactory>(
125+
cos_dispatch_vector);
126+
127+
auto cos_pyapi = [&](sycl::queue exec_q, arrayT src, arrayT dst,
128+
const event_vecT &depends = {}) {
129+
return vm_ext::unary_ufunc(exec_q, src, dst, depends,
130+
cos_dispatch_vector);
131+
};
132+
m.def("_cos", cos_pyapi,
133+
"Call `cos` function from OneMKL VM library to compute "
134+
"cosine of vector elements",
135+
py::arg("sycl_queue"), py::arg("src"), py::arg("dst"),
136+
py::arg("depends") = py::list());
137+
138+
auto cos_need_to_call_pyapi = [&](sycl::queue exec_q, arrayT src,
139+
arrayT dst) {
140+
return vm_ext::need_to_call_unary_ufunc(exec_q, src, dst,
141+
cos_dispatch_vector);
142+
};
143+
m.def("_mkl_cos_to_call", cos_need_to_call_pyapi,
144+
"Check input arguments to answer if `cos` function from "
145+
"OneMKL VM library can be used",
146+
py::arg("sycl_queue"), py::arg("src"), py::arg("dst"));
147+
}
148+
149149
// UnaryUfunc: ==== Ln(x) ====
150150
{
151151
vm_ext::init_ufunc_dispatch_vector<unary_impl_fn_ptr_t,

dpnp/dpnp_algo/dpnp_elementwise_common.py

Lines changed: 30 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -338,36 +338,6 @@ def dpnp_bitwise_xor(x1, x2, out=None, order="K"):
338338
"""
339339

340340

341-
def dpnp_cos(x, out=None, order="K"):
342-
"""
343-
Invokes cos() function from pybind11 extension of OneMKL VM if possible.
344-
345-
Otherwise fully relies on dpctl.tensor implementation for cos() function.
346-
347-
"""
348-
349-
def _call_cos(src, dst, sycl_queue, depends=None):
350-
"""A callback to register in UnaryElementwiseFunc class of dpctl.tensor"""
351-
352-
if depends is None:
353-
depends = []
354-
355-
if vmi._mkl_cos_to_call(sycl_queue, src, dst):
356-
# call pybind11 extension for cos() function from OneMKL VM
357-
return vmi._cos(sycl_queue, src, dst, depends)
358-
return ti._cos(src, dst, sycl_queue, depends)
359-
360-
# dpctl.tensor only works with usm_ndarray
361-
x1_usm = dpnp.get_usm_ndarray(x)
362-
out_usm = None if out is None else dpnp.get_usm_ndarray(out)
363-
364-
func = UnaryElementwiseFunc(
365-
"cos", ti._cos_result_type, _call_cos, _cos_docstring
366-
)
367-
res_usm = func(x1_usm, out=out_usm, order=order)
368-
return dpnp_array._create_from_usm_ndarray(res_usm)
369-
370-
371341
_conj_docstring = """
372342
conj(x, out=None, order='K')
373343
@@ -406,6 +376,36 @@ def _call_conj(src, dst, sycl_queue, depends=None):
406376
)
407377

408378

379+
def dpnp_cos(x, out=None, order="K"):
380+
"""
381+
Invokes cos() function from pybind11 extension of OneMKL VM if possible.
382+
383+
Otherwise fully relies on dpctl.tensor implementation for cos() function.
384+
385+
"""
386+
387+
def _call_cos(src, dst, sycl_queue, depends=None):
388+
"""A callback to register in UnaryElementwiseFunc class of dpctl.tensor"""
389+
390+
if depends is None:
391+
depends = []
392+
393+
if vmi._mkl_cos_to_call(sycl_queue, src, dst):
394+
# call pybind11 extension for cos() function from OneMKL VM
395+
return vmi._cos(sycl_queue, src, dst, depends)
396+
return ti._cos(src, dst, sycl_queue, depends)
397+
398+
# dpctl.tensor only works with usm_ndarray
399+
x1_usm = dpnp.get_usm_ndarray(x)
400+
out_usm = None if out is None else dpnp.get_usm_ndarray(out)
401+
402+
func = UnaryElementwiseFunc(
403+
"cos", ti._cos_result_type, _call_cos, _cos_docstring
404+
)
405+
res_usm = func(x1_usm, out=out_usm, order=order)
406+
return dpnp_array._create_from_usm_ndarray(res_usm)
407+
408+
409409
def dpnp_conj(x, out=None, order="K"):
410410
"""
411411
Invokes conj() function from pybind11 extension of OneMKL VM if possible.

dpnp/dpnp_array.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -622,7 +622,7 @@ def conj(self):
622622
623623
"""
624624

625-
if not dpnp.issubsctype(self.dtype, dpnp.complex_):
625+
if not dpnp.issubsctype(self.dtype, dpnp.complexfloating):
626626
return self
627627
else:
628628
return dpnp.conjugate(self)
@@ -635,7 +635,7 @@ def conjugate(self):
635635
636636
"""
637637

638-
if not dpnp.issubsctype(self.dtype, dpnp.complex_):
638+
if not dpnp.issubsctype(self.dtype, dpnp.complexfloating):
639639
return self
640640
else:
641641
return dpnp.conjugate(self)

dpnp/dpnp_iface_mathematical.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -380,7 +380,7 @@ def conjugate(
380380
Examples
381381
--------
382382
>>> import dpnp as np
383-
>>> np.conjugate(1+2j)
383+
>>> np.conjugate(np.array(1+2j))
384384
(1-2j)
385385
386386
>>> x = np.eye(2) + 1j * np.eye(2)

tests/third_party/cupy/math_tests/test_arithmetic.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,12 @@
2626
testing.product(
2727
{
2828
"nargs": [1],
29-
"name": ["reciprocal", "angle"],
29+
"name": [
30+
"reciprocal",
31+
"conj",
32+
"conjugate",
33+
"angle",
34+
],
3035
}
3136
)
3237
+ testing.product(
@@ -68,6 +73,18 @@ def test_raises_with_numpy_input(self):
6873
@testing.parameterize(
6974
*(
7075
testing.product(
76+
{
77+
"arg1": (
78+
[
79+
testing.shaped_arange((2, 3), numpy, dtype=d)
80+
for d in all_types
81+
]
82+
+ [0, 0.0j, 0j, 2, 2.0, 2j, True, False]
83+
),
84+
"name": ["conj", "conjugate"],
85+
}
86+
)
87+
+ testing.product(
7188
{
7289
"arg1": (
7390
[

0 commit comments

Comments
 (0)