Skip to content

Commit 3fccd62

Browse files
committed
Apply temporary python solution for use vm extension
1 parent c905507 commit 3fccd62

File tree

3 files changed

+93
-92
lines changed

3 files changed

+93
-92
lines changed

dpnp/dpnp_algo/dpnp_elementwise_common.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
)
3333

3434
import dpnp
35+
import dpnp.backend.extensions.vm._vm_impl as vmi
3536
from dpnp.dpnp_array import dpnp_array
3637

3738
__all__ = [
@@ -111,11 +112,12 @@ def _call_func(src, dst, sycl_queue, depends=None):
111112
if depends is None:
112113
depends = []
113114

114-
if mkl_fn_to_call is not None and mkl_fn_to_call(
115-
sycl_queue, src, dst
116-
):
117-
# call pybind11 extension for unary function from OneMKL VM
118-
return mkl_impl_fn(sycl_queue, src, dst, depends)
115+
if vmi.mkl_vm_is_defined() and mkl_fn_to_call is not None:
116+
if getattr(vmi, mkl_fn_to_call)(sycl_queue, src, dst):
117+
# call pybind11 extension for unary function from OneMKL VM
118+
return getattr(vmi, mkl_impl_fn)(
119+
sycl_queue, src, dst, depends
120+
)
119121
return unary_dp_impl_fn(src, dst, sycl_queue, depends)
120122

121123
super().__init__(
@@ -264,11 +266,12 @@ def _call_func(src1, src2, dst, sycl_queue, depends=None):
264266
if depends is None:
265267
depends = []
266268

267-
if mkl_fn_to_call is not None and mkl_fn_to_call(
268-
sycl_queue, src1, src2, dst
269-
):
270-
# call pybind11 extension for binary function from OneMKL VM
271-
return mkl_impl_fn(sycl_queue, src1, src2, dst, depends)
269+
if vmi.mkl_vm_is_defined() and mkl_fn_to_call is not None:
270+
if getattr(vmi, mkl_fn_to_call)(sycl_queue, src1, src2, dst):
271+
# call pybind11 extension for binary function from OneMKL VM
272+
return getattr(vmi, mkl_impl_fn)(
273+
sycl_queue, src1, src2, dst, depends
274+
)
272275
return binary_dp_impl_fn(src1, src2, dst, sycl_queue, depends)
273276

274277
super().__init__(

dpnp/dpnp_iface_mathematical.py

Lines changed: 32 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,6 @@
5656

5757
import dpnp
5858
import dpnp.backend.extensions.ufunc._ufunc_impl as ufi
59-
import dpnp.backend.extensions.vm._vm_impl as vmi
6059

6160
from .backend.extensions.sycl_ext import _sycl_ext_impl
6261
from .dpnp_algo import (
@@ -385,8 +384,8 @@ def _gradient_num_diff_edges(
385384
ti._abs_result_type,
386385
ti._abs,
387386
_ABS_DOCSTRING,
388-
mkl_fn_to_call=vmi._mkl_abs_to_call,
389-
mkl_impl_fn=vmi._abs,
387+
mkl_fn_to_call="_mkl_abs_to_call",
388+
mkl_impl_fn="_abs",
390389
)
391390

392391

@@ -461,8 +460,8 @@ def _gradient_num_diff_edges(
461460
ti._add_result_type,
462461
ti._add,
463462
_ADD_DOCSTRING,
464-
mkl_fn_to_call=vmi._mkl_add_to_call,
465-
mkl_impl_fn=vmi._add,
463+
mkl_fn_to_call="_mkl_add_to_call",
464+
mkl_impl_fn="_add",
466465
binary_inplace_fn=ti._add_inplace,
467466
)
468467

@@ -609,8 +608,8 @@ def around(x, /, decimals=0, out=None):
609608
ti._ceil_result_type,
610609
ti._ceil,
611610
_CEIL_DOCSTRING,
612-
mkl_fn_to_call=vmi._mkl_ceil_to_call,
613-
mkl_impl_fn=vmi._ceil,
611+
mkl_fn_to_call="_mkl_ceil_to_call",
612+
mkl_impl_fn="_ceil",
614613
)
615614

616615

@@ -735,8 +734,8 @@ def clip(a, a_min, a_max, *, out=None, order="K", **kwargs):
735734
ti._conj_result_type,
736735
ti._conj,
737736
_CONJ_DOCSTRING,
738-
mkl_fn_to_call=vmi._mkl_conj_to_call,
739-
mkl_impl_fn=vmi._conj,
737+
mkl_fn_to_call="_mkl_conj_to_call",
738+
mkl_impl_fn="_conj",
740739
)
741740

742741
conj = conjugate
@@ -1310,8 +1309,8 @@ def diff(a, n=1, axis=-1, prepend=None, append=None):
13101309
ti._divide_result_type,
13111310
ti._divide,
13121311
_DIVIDE_DOCSTRING,
1313-
mkl_fn_to_call=vmi._mkl_div_to_call,
1314-
mkl_impl_fn=vmi._div,
1312+
mkl_fn_to_call="_mkl_div_to_call",
1313+
mkl_impl_fn="_div",
13151314
binary_inplace_fn=ti._divide_inplace,
13161315
acceptance_fn=_acceptance_fn_divide,
13171316
)
@@ -1408,8 +1407,8 @@ def ediff1d(x1, to_end=None, to_begin=None):
14081407
ufi._fabs_result_type,
14091408
ufi._fabs,
14101409
_FABS_DOCSTRING,
1411-
mkl_fn_to_call=vmi._mkl_abs_to_call,
1412-
mkl_impl_fn=vmi._abs,
1410+
mkl_fn_to_call="_mkl_abs_to_call",
1411+
mkl_impl_fn="_abs",
14131412
)
14141413

14151414

@@ -1466,8 +1465,8 @@ def ediff1d(x1, to_end=None, to_begin=None):
14661465
ti._floor_result_type,
14671466
ti._floor,
14681467
_FLOOR_DOCSTRING,
1469-
mkl_fn_to_call=vmi._mkl_floor_to_call,
1470-
mkl_impl_fn=vmi._floor,
1468+
mkl_fn_to_call="_mkl_floor_to_call",
1469+
mkl_impl_fn="_floor",
14711470
)
14721471

14731472

@@ -1620,8 +1619,8 @@ def ediff1d(x1, to_end=None, to_begin=None):
16201619
ufi._fmax_result_type,
16211620
ufi._fmax,
16221621
_FMAX_DOCSTRING,
1623-
mkl_fn_to_call=vmi._mkl_fmax_to_call,
1624-
mkl_impl_fn=vmi._fmax,
1622+
mkl_fn_to_call="_mkl_fmax_to_call",
1623+
mkl_impl_fn="_fmax",
16251624
)
16261625

16271626

@@ -1705,8 +1704,8 @@ def ediff1d(x1, to_end=None, to_begin=None):
17051704
ufi._fmin_result_type,
17061705
ufi._fmin,
17071706
_FMIN_DOCSTRING,
1708-
mkl_fn_to_call=vmi._mkl_fmin_to_call,
1709-
mkl_impl_fn=vmi._fmin,
1707+
mkl_fn_to_call="_mkl_fmin_to_call",
1708+
mkl_impl_fn="_fmin",
17101709
)
17111710

17121711

@@ -1779,8 +1778,8 @@ def ediff1d(x1, to_end=None, to_begin=None):
17791778
ufi._fmod_result_type,
17801779
ufi._fmod,
17811780
_FMOD_DOCSTRING,
1782-
mkl_fn_to_call=vmi._mkl_fmod_to_call,
1783-
mkl_impl_fn=vmi._fmod,
1781+
mkl_fn_to_call="_mkl_fmod_to_call",
1782+
mkl_impl_fn="_fmod",
17841783
)
17851784

17861785

@@ -2297,8 +2296,8 @@ def modf(x1, **kwargs):
22972296
ti._multiply_result_type,
22982297
ti._multiply,
22992298
_MULTIPLY_DOCSTRING,
2300-
mkl_fn_to_call=vmi._mkl_mul_to_call,
2301-
mkl_impl_fn=vmi._mul,
2299+
mkl_fn_to_call="_mkl_mul_to_call",
2300+
mkl_impl_fn="_mul",
23022301
binary_inplace_fn=ti._multiply_inplace,
23032302
)
23042303

@@ -2500,8 +2499,8 @@ def modf(x1, **kwargs):
25002499
ti._pow_result_type,
25012500
ti._pow,
25022501
_POWER_DOCSTRING,
2503-
mkl_fn_to_call=vmi._mkl_pow_to_call,
2504-
mkl_impl_fn=vmi._pow,
2502+
mkl_fn_to_call="_mkl_pow_to_call",
2503+
mkl_impl_fn="_pow",
25052504
binary_inplace_fn=ti._pow_inplace,
25062505
)
25072506

@@ -2815,8 +2814,8 @@ def prod(
28152814
ti._round_result_type,
28162815
ti._round,
28172816
_RINT_DOCSTRING,
2818-
mkl_fn_to_call=vmi._mkl_round_to_call,
2819-
mkl_impl_fn=vmi._round,
2817+
mkl_fn_to_call="_mkl_round_to_call",
2818+
mkl_impl_fn="'_round'",
28202819
)
28212820

28222821

@@ -2875,8 +2874,8 @@ def prod(
28752874
ti._round_result_type,
28762875
ti._round,
28772876
_ROUND_DOCSTRING,
2878-
mkl_fn_to_call=vmi._mkl_round_to_call,
2879-
mkl_impl_fn=vmi._round,
2877+
mkl_fn_to_call="_mkl_round_to_call",
2878+
mkl_impl_fn="_round",
28802879
)
28812880

28822881

@@ -3055,8 +3054,8 @@ def prod(
30553054
ti._subtract_result_type,
30563055
ti._subtract,
30573056
_SUBTRACT_DOCSTRING,
3058-
mkl_fn_to_call=vmi._mkl_sub_to_call,
3059-
mkl_impl_fn=vmi._sub,
3057+
mkl_fn_to_call="_mkl_sub_to_call",
3058+
mkl_impl_fn="_sub",
30603059
binary_inplace_fn=ti._subtract_inplace,
30613060
acceptance_fn=acceptance_fn_subtract,
30623061
)
@@ -3299,6 +3298,6 @@ def trapz(y1, x1=None, dx=1.0, axis=-1):
32993298
ti._trunc_result_type,
33003299
ti._trunc,
33013300
_TRUNC_DOCSTRING,
3302-
mkl_fn_to_call=vmi._mkl_trunc_to_call,
3303-
mkl_impl_fn=vmi._trunc,
3301+
mkl_fn_to_call="_mkl_trunc_to_call",
3302+
mkl_impl_fn="_trunc",
33043303
)

0 commit comments

Comments
 (0)