Skip to content

Commit 6c2036f

Browse files
authored
update dpnp.linalg.multi_dot implementation (#1729)
* update dpnp.multi_dot * fix pre-commit * update check limitation calls * address comments * use sycl_queue and usm_type for m and s
1 parent 4b6650f commit 6c2036f

15 files changed

+584
-216
lines changed

dpnp/dpnp_array.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -540,7 +540,7 @@ def astype(self, dtype, order="K", casting="unsafe", subok=True, copy=True):
540540
Array data type casting.
541541
dtype : dtype
542542
Target data type.
543-
order : {'C', 'F', 'A', 'K'}
543+
order : {"C", "F", "A", "K"}, optional
544544
Row-major (C-style) or column-major (Fortran-style) order.
545545
When ``order`` is 'A', it uses 'F' if ``a`` is column-major and uses 'C' otherwise.
546546
And when ``order`` is 'K', it keeps strides as closely as possible.

dpnp/dpnp_iface.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
"array_equal",
5757
"asnumpy",
5858
"astype",
59+
"check_limitations",
5960
"check_supported_arrays_type",
6061
"convert_single_elem_array_to_scalar",
6162
"default_float_type",
@@ -232,6 +233,58 @@ def astype(x1, dtype, order="K", casting="unsafe", copy=True):
232233
return dpnp_array._create_from_usm_ndarray(array_obj)
233234

234235

236+
def check_limitations(
237+
order=None, subok=False, like=None, initial=None, where=True
238+
):
239+
"""
240+
Checking limitation kwargs for their supported values.
241+
242+
Parameter `order` is only supported with values ``"C"``, ``"F"``
243+
and ``None``.
244+
Parameter `subok` is only supported with default value ``False``.
245+
Parameter `like` is only supported with default value ``None``.
246+
Parameter `initial` is only supported with default value ``None``.
247+
Parameter `where` is only supported with default value ``True``.
248+
249+
Raises
250+
------
251+
NotImplementedError
252+
If any input kwargs is of unsupported value.
253+
254+
"""
255+
256+
if order in ("A", "a", "K", "k"):
257+
raise NotImplementedError(
258+
"Keyword argument `order` is supported only with "
259+
f"values ``'C'`` and ``'F'``, but got {order}"
260+
)
261+
if order not in ("C", "c", "F", "f", None):
262+
raise ValueError(
263+
"Unrecognized `order` keyword value, expecting "
264+
f"``'C'`` or ``'F'``, but got {order}"
265+
)
266+
if like is not None:
267+
raise NotImplementedError(
268+
"Keyword argument `like` is supported only with "
269+
f"default value ``None``, but got {like}"
270+
)
271+
if subok is not False:
272+
raise NotImplementedError(
273+
"Keyword argument `subok` is supported only with "
274+
f"default value ``False``, but got {subok}"
275+
)
276+
if initial is not None:
277+
raise NotImplementedError(
278+
"Keyword argument `initial` is only supported with "
279+
f"default value ``None``, but got {initial}"
280+
)
281+
if where is not True:
282+
raise NotImplementedError(
283+
"Keyword argument `where` is supported only with "
284+
f"default value ``True``, but got {where}"
285+
)
286+
287+
235288
def check_supported_arrays_type(*arrays, scalar_type=False, all_scalars=False):
236289
"""
237290
Return ``True`` if each array has either type of scalar,

dpnp/dpnp_iface_arraycreation.py

Lines changed: 20 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -99,43 +99,6 @@
9999
]
100100

101101

102-
def _check_limitations(order=None, subok=False, like=None):
103-
"""
104-
Checking limitation kwargs for their supported values.
105-
106-
Parameter `order` is supported only with values ``C``, ``F`` and ``None``.
107-
Parameter `subok` is supported only with default value ``False``.
108-
Parameter `like` is supported only with default value ``None``.
109-
110-
Raises
111-
------
112-
NotImplementedError
113-
If any input kwargs is of unsupported value.
114-
115-
"""
116-
117-
if order in ("A", "a", "K", "k"):
118-
raise NotImplementedError(
119-
"Keyword argument `order` is supported only with "
120-
f"values ``'C'`` and ``'F'``, but got {order}"
121-
)
122-
if order not in ("C", "c", "F", "f", None):
123-
raise ValueError(
124-
"Unrecognized `order` keyword value, expecting "
125-
f"``'C'`` or ``'F'``, but got {order}"
126-
)
127-
if like is not None:
128-
raise NotImplementedError(
129-
"Keyword argument `like` is supported only with "
130-
f"default value ``None``, but got {like}"
131-
)
132-
if subok is not False:
133-
raise NotImplementedError(
134-
"Keyword argument `subok` is supported only with "
135-
f"default value ``False``, but got {subok}"
136-
)
137-
138-
139102
def arange(
140103
start,
141104
/,
@@ -223,7 +186,7 @@ def arange(
223186
224187
"""
225188

226-
_check_limitations(like=like)
189+
dpnp.check_limitations(like=like)
227190

228191
return dpnp_container.arange(
229192
start,
@@ -343,7 +306,7 @@ def array(
343306
344307
"""
345308

346-
_check_limitations(subok=subok, like=like)
309+
dpnp.check_limitations(subok=subok, like=like)
347310
if ndmin != 0:
348311
raise NotImplementedError(
349312
"Keyword argument `ndmin` is supported only with "
@@ -451,7 +414,7 @@ def asanyarray(
451414
452415
"""
453416

454-
_check_limitations(like=like)
417+
dpnp.check_limitations(like=like)
455418

456419
return asarray(
457420
a,
@@ -548,7 +511,7 @@ def asarray(
548511
549512
"""
550513

551-
_check_limitations(like=like)
514+
dpnp.check_limitations(like=like)
552515

553516
return dpnp_container.asarray(
554517
a,
@@ -654,7 +617,7 @@ def ascontiguousarray(
654617
655618
"""
656619

657-
_check_limitations(like=like)
620+
dpnp.check_limitations(like=like)
658621

659622
# at least 1-d array has to be returned
660623
if dpnp.isscalar(a) or hasattr(a, "ndim") and a.ndim == 0:
@@ -768,7 +731,7 @@ def asfortranarray(
768731
769732
"""
770733

771-
_check_limitations(like=like)
734+
dpnp.check_limitations(like=like)
772735

773736
# at least 1-d array has to be returned
774737
if dpnp.isscalar(a) or hasattr(a, "ndim") and a.ndim == 0:
@@ -867,7 +830,7 @@ def copy(
867830
868831
"""
869832

870-
_check_limitations(subok=subok)
833+
dpnp.check_limitations(subok=subok)
871834

872835
if dpnp.is_supported_array_type(a):
873836
sycl_queue_normalized = dpnp.get_normalized_queue_device(
@@ -1176,7 +1139,7 @@ def empty(
11761139
11771140
"""
11781141

1179-
_check_limitations(order=order, like=like)
1142+
dpnp.check_limitations(order=order, like=like)
11801143
return dpnp_container.empty(
11811144
shape,
11821145
dtype=dtype,
@@ -1276,7 +1239,7 @@ def empty_like(
12761239
"""
12771240

12781241
dpnp.check_supported_arrays_type(a)
1279-
_check_limitations(order=order, subok=subok)
1242+
dpnp.check_limitations(order=order, subok=subok)
12801243

12811244
_shape = a.shape if shape is None else shape
12821245
_dtype = a.dtype if dtype is None else dtype
@@ -1385,7 +1348,7 @@ def eye(
13851348
13861349
"""
13871350

1388-
_check_limitations(order=order, like=like)
1351+
dpnp.check_limitations(order=order, like=like)
13891352

13901353
return dpnp_container.eye(
13911354
N,
@@ -1485,7 +1448,7 @@ def frombuffer(
14851448
14861449
"""
14871450

1488-
_check_limitations(like=like)
1451+
dpnp.check_limitations(like=like)
14891452
return asarray(
14901453
numpy.frombuffer(buffer, dtype=dtype, count=count, offset=offset),
14911454
device=device,
@@ -1609,7 +1572,7 @@ def fromfile(
16091572
16101573
"""
16111574

1612-
_check_limitations(like=like)
1575+
dpnp.check_limitations(like=like)
16131576
return asarray(
16141577
numpy.fromfile(file, dtype=dtype, count=count, sep=sep, offset=offset),
16151578
device=device,
@@ -1725,7 +1688,7 @@ def fromstring(
17251688
17261689
"""
17271690

1728-
_check_limitations(like=like)
1691+
dpnp.check_limitations(like=like)
17291692
return asarray(
17301693
numpy.fromstring(string, dtype=dtype, count=count, sep=sep),
17311694
device=device,
@@ -1819,7 +1782,7 @@ def full(
18191782
18201783
"""
18211784

1822-
_check_limitations(order=order, like=like)
1785+
dpnp.check_limitations(order=order, like=like)
18231786

18241787
return dpnp_container.full(
18251788
shape,
@@ -1926,7 +1889,7 @@ def full_like(
19261889
"""
19271890

19281891
dpnp.check_supported_arrays_type(a)
1929-
_check_limitations(order=order, subok=subok)
1892+
dpnp.check_limitations(order=order, subok=subok)
19301893

19311894
_shape = a.shape if shape is None else shape
19321895
_dtype = a.dtype if dtype is None else dtype
@@ -2155,7 +2118,7 @@ def identity(
21552118
if n < 0:
21562119
raise ValueError("negative dimensions are not allowed")
21572120

2158-
_check_limitations(like=like)
2121+
dpnp.check_limitations(like=like)
21592122

21602123
_dtype = dpnp.default_float_type() if dtype is None else dtype
21612124
return dpnp.eye(
@@ -2759,7 +2722,7 @@ def ones(
27592722
27602723
"""
27612724

2762-
_check_limitations(order=order, like=like)
2725+
dpnp.check_limitations(order=order, like=like)
27632726

27642727
return dpnp_container.ones(
27652728
shape,
@@ -2861,7 +2824,7 @@ def ones_like(
28612824
28622825
"""
28632826
dpnp.check_supported_arrays_type(a)
2864-
_check_limitations(order=order, subok=subok)
2827+
dpnp.check_limitations(order=order, subok=subok)
28652828

28662829
_shape = a.shape if shape is None else shape
28672830
_dtype = a.dtype if dtype is None else dtype
@@ -3347,7 +3310,7 @@ def zeros(
33473310
33483311
"""
33493312

3350-
_check_limitations(order=order, like=like)
3313+
dpnp.check_limitations(order=order, like=like)
33513314

33523315
return dpnp_container.zeros(
33533316
shape,
@@ -3450,7 +3413,7 @@ def zeros_like(
34503413
"""
34513414

34523415
dpnp.check_supported_arrays_type(a)
3453-
_check_limitations(order=order, subok=subok)
3416+
dpnp.check_limitations(order=order, subok=subok)
34543417

34553418
_shape = a.shape if shape is None else shape
34563419
_dtype = a.dtype if dtype is None else dtype

dpnp/dpnp_iface_linearalgebra.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def dot(a, b, out=None):
8282
b : {dpnp.ndarray, usm_ndarray, scalar}
8383
Second input array. Both inputs `a` and `b` can not be scalars
8484
at the same time.
85-
out : {dpnp.ndarray, usm_ndarray}, optional
85+
out : {None, dpnp.ndarray, usm_ndarray}, optional
8686
Alternative output array in which to place the result. It must have
8787
the same shape and data type as the expected output and should be
8888
C-contiguous. If these conditions are not met, an exception is
@@ -345,11 +345,11 @@ def matmul(
345345
346346
Parameters
347347
----------
348-
x1 : {dpnp_array, usm_ndarray}
348+
x1 : {dpnp.ndarray, usm_ndarray}
349349
First input array.
350-
x2 : {dpnp_array, usm_ndarray}
350+
x2 : {dpnp.ndarray, usm_ndarray}
351351
Second input array.
352-
out : {dpnp.ndarray, usm_ndarray}, optional
352+
out : {None, dpnp.ndarray, usm_ndarray}, optional
353353
Alternative output array in which to place the result. It must have
354354
a shape that matches the signature `(n,k),(k,m)->(n,m)` but the type
355355
(of the calculated values) will be cast if necessary. Default: ``None``.

dpnp/dpnp_iface_manipulation.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1202,14 +1202,13 @@ def ravel(a, order="C"):
12021202
x : {dpnp.ndarray, usm_ndarray}
12031203
Input array. The elements in `a` are read in the order specified by
12041204
order, and packed as a 1-D array.
1205-
order : {'C', 'F'}, optional
1206-
The elements of `a` are read using this index order. ``C`` means to
1205+
order : {"C", "F"}, optional
1206+
The elements of `a` are read using this index order. ``"C"`` means to
12071207
index the elements in row-major, C-style order, with the last axis
12081208
index changing fastest, back to the first axis index changing slowest.
1209-
``F`` means to index the elements in column-major, Fortran-style order,
1210-
with the first index changing fastest, and the last index changing
1211-
slowest.
1212-
By default, ``C`` index order is used.
1209+
``"F"`` means to index the elements in column-major, Fortran-style
1210+
order, with the first index changing fastest, and the last index
1211+
changing slowest. By default, ``"C"`` index order is used.
12131212
12141213
Returns
12151214
-------
@@ -1313,15 +1312,15 @@ def reshape(a, /, newshape, order="C", copy=None):
13131312
an integer, then the result will be a 1-D array of that length.
13141313
One shape dimension can be -1. In this case, the value is
13151314
inferred from the length of the array and remaining dimensions.
1316-
order : {'C', 'F'}, optional
1315+
order : {"C", "F"}, optional
13171316
Read the elements of `a` using this index order, and place the
1318-
elements into the reshaped array using this index order. 'C'
1317+
elements into the reshaped array using this index order. ``"C"``
13191318
means to read / write the elements using C-like index order,
13201319
with the last axis index changing fastest, back to the first
1321-
axis index changing slowest. 'F' means to read / write the
1320+
axis index changing slowest. ``"F"`` means to read / write the
13221321
elements using Fortran-like index order, with the first index
13231322
changing fastest, and the last index changing slowest. Note that
1324-
the 'C' and 'F' options take no account of the memory layout of
1323+
the ``"C"`` and ``"F"`` options take no account of the memory layout of
13251324
the underlying array, and only refer to the order of indexing.
13261325
copy : bool, optional
13271326
Boolean indicating whether or not to copy the input array.

dpnp/dpnp_iface_mathematical.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -449,7 +449,7 @@ def clip(a, a_min, a_max, *, out=None, order="K", **kwargs):
449449
a_min, a_max : {dpnp.ndarray, usm_ndarray, None}
450450
Minimum and maximum value. If ``None``, clipping is not performed on the corresponding edge.
451451
Only one of `a_min` and `a_max` may be ``None``. Both are broadcast against `a`.
452-
out : {dpnp.ndarray, usm_ndarray}, optional
452+
out : {None, dpnp.ndarray, usm_ndarray}, optional
453453
The results will be placed in this array. It may be the input array for in-place clipping.
454454
`out` must be of the right shape to hold the output. Its type is preserved.
455455
order : {"C", "F", "A", "K", None}, optional
@@ -614,8 +614,8 @@ def copysign(
614614
out : ({None, dpnp.ndarray, usm_ndarray}, optional):
615615
Output array to populate.
616616
Array must have the correct shape and the expected data type.
617-
order : ({'C', 'F', 'A', 'K'}, optional):
618-
Memory layout of the newly output array, if parameter `out` is `None`.
617+
order : {"C", "F", "A", "K"}, optional
618+
Memory layout of the newly output array, if parameter `out` is ``None``.
619619
Default: "K".
620620
621621
Returns
@@ -2848,7 +2848,7 @@ def sum(
28482848
data type of `a`, the input array elements are cast to the
28492849
specified data type before computing the sum.
28502850
Default: ``None``.
2851-
out : {dpnp.ndarray, usm_ndarray}, optional
2851+
out : {None, dpnp.ndarray, usm_ndarray}, optional
28522852
Alternative output array in which to place the result. It must
28532853
have the same shape as the expected output, but the type of
28542854
the output values will be cast if necessary.

0 commit comments

Comments
 (0)