Skip to content

Commit 1253889

Browse files
committed
address comments
1 parent 295e335 commit 1253889

File tree

5 files changed

+55
-37
lines changed

5 files changed

+55
-37
lines changed

dpnp/backend/extensions/vm/vm_py.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,6 @@ PYBIND11_MODULE(_vm_impl, m)
152152
py::arg("sycl_queue"), py::arg("src"), py::arg("dst"));
153153
}
154154

155-
<<<<<<< HEAD
156155
// UnaryUfunc: ==== Floor(x) ====
157156
{
158157
vm_ext::init_ufunc_dispatch_vector<unary_impl_fn_ptr_t,
@@ -177,7 +176,10 @@ PYBIND11_MODULE(_vm_impl, m)
177176
};
178177
m.def("_mkl_floor_to_call", floor_need_to_call_pyapi,
179178
"Check input arguments to answer if `floor` function from "
180-
=======
179+
"OneMKL VM library can be used",
180+
py::arg("sycl_queue"), py::arg("src"), py::arg("dst"));
181+
}
182+
181183
// UnaryUfunc: ==== Conj(x) ====
182184
{
183185
vm_ext::init_ufunc_dispatch_vector<unary_impl_fn_ptr_t,
@@ -202,7 +204,6 @@ PYBIND11_MODULE(_vm_impl, m)
202204
};
203205
m.def("_mkl_conj_to_call", conj_need_to_call_pyapi,
204206
"Check input arguments to answer if `conj` function from "
205-
>>>>>>> use_dpctl_conj_for_dpnp
206207
"OneMKL VM library can be used",
207208
py::arg("sycl_queue"), py::arg("src"), py::arg("dst"));
208209
}

dpnp/dpnp_algo/dpnp_elementwise_common.py

Lines changed: 30 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -394,36 +394,6 @@ def dpnp_ceil(x, out=None, order="K"):
394394
"""
395395

396396

397-
def dpnp_cos(x, out=None, order="K"):
398-
"""
399-
Invokes cos() function from pybind11 extension of OneMKL VM if possible.
400-
401-
Otherwise fully relies on dpctl.tensor implementation for cos() function.
402-
403-
"""
404-
405-
def _call_cos(src, dst, sycl_queue, depends=None):
406-
"""A callback to register in UnaryElementwiseFunc class of dpctl.tensor"""
407-
408-
if depends is None:
409-
depends = []
410-
411-
if vmi._mkl_cos_to_call(sycl_queue, src, dst):
412-
# call pybind11 extension for cos() function from OneMKL VM
413-
return vmi._cos(sycl_queue, src, dst, depends)
414-
return ti._cos(src, dst, sycl_queue, depends)
415-
416-
# dpctl.tensor only works with usm_ndarray
417-
x1_usm = dpnp.get_usm_ndarray(x)
418-
out_usm = None if out is None else dpnp.get_usm_ndarray(out)
419-
420-
func = UnaryElementwiseFunc(
421-
"cos", ti._cos_result_type, _call_cos, _cos_docstring
422-
)
423-
res_usm = func(x1_usm, out=out_usm, order=order)
424-
return dpnp_array._create_from_usm_ndarray(res_usm)
425-
426-
427397
_conj_docstring = """
428398
conj(x, out=None, order='K')
429399
@@ -462,6 +432,36 @@ def _call_conj(src, dst, sycl_queue, depends=None):
462432
)
463433

464434

435+
def dpnp_cos(x, out=None, order="K"):
436+
"""
437+
Invokes cos() function from pybind11 extension of OneMKL VM if possible.
438+
439+
Otherwise fully relies on dpctl.tensor implementation for cos() function.
440+
441+
"""
442+
443+
def _call_cos(src, dst, sycl_queue, depends=None):
444+
"""A callback to register in UnaryElementwiseFunc class of dpctl.tensor"""
445+
446+
if depends is None:
447+
depends = []
448+
449+
if vmi._mkl_cos_to_call(sycl_queue, src, dst):
450+
# call pybind11 extension for cos() function from OneMKL VM
451+
return vmi._cos(sycl_queue, src, dst, depends)
452+
return ti._cos(src, dst, sycl_queue, depends)
453+
454+
# dpctl.tensor only works with usm_ndarray
455+
x1_usm = dpnp.get_usm_ndarray(x)
456+
out_usm = None if out is None else dpnp.get_usm_ndarray(out)
457+
458+
func = UnaryElementwiseFunc(
459+
"cos", ti._cos_result_type, _call_cos, _cos_docstring
460+
)
461+
res_usm = func(x1_usm, out=out_usm, order=order)
462+
return dpnp_array._create_from_usm_ndarray(res_usm)
463+
464+
465465
def dpnp_conj(x, out=None, order="K"):
466466
"""
467467
Invokes conj() function from pybind11 extension of OneMKL VM if possible.

dpnp/dpnp_array.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -622,7 +622,7 @@ def conj(self):
622622
623623
"""
624624

625-
if not dpnp.issubsctype(self.dtype, dpnp.complex_):
625+
if not dpnp.issubsctype(self.dtype, dpnp.complexfloating):
626626
return self
627627
else:
628628
return dpnp.conjugate(self)
@@ -635,7 +635,7 @@ def conjugate(self):
635635
636636
"""
637637

638-
if not dpnp.issubsctype(self.dtype, dpnp.complex_):
638+
if not dpnp.issubsctype(self.dtype, dpnp.complexfloating):
639639
return self
640640
else:
641641
return dpnp.conjugate(self)

dpnp/dpnp_iface_mathematical.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -397,7 +397,7 @@ def conjugate(
397397
Examples
398398
--------
399399
>>> import dpnp as np
400-
>>> np.conjugate(1+2j)
400+
>>> np.conjugate(np.array(1+2j))
401401
(1-2j)
402402
403403
>>> x = np.eye(2) + 1j * np.eye(2)

tests/third_party/cupy/math_tests/test_arithmetic.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,12 @@
2626
testing.product(
2727
{
2828
"nargs": [1],
29-
"name": ["reciprocal", "angle"],
29+
"name": [
30+
"reciprocal",
31+
"conj",
32+
"conjugate",
33+
"angle",
34+
],
3035
}
3136
)
3237
+ testing.product(
@@ -68,6 +73,18 @@ def test_raises_with_numpy_input(self):
6873
@testing.parameterize(
6974
*(
7075
testing.product(
76+
{
77+
"arg1": (
78+
[
79+
testing.shaped_arange((2, 3), numpy, dtype=d)
80+
for d in all_types
81+
]
82+
+ [0, 0.0j, 0j, 2, 2.0, 2j, True, False]
83+
),
84+
"name": ["conj", "conjugate"],
85+
}
86+
)
87+
+ testing.product(
7188
{
7289
"arg1": (
7390
[

0 commit comments

Comments
 (0)