Skip to content

Commit 6220539

Browse files
Rename check functions in dpnp_utils_linalg.py (#1807)
* Add check_2d to __all__ for dpnp_utils_linalg.py * Rename check_2d to assert_2d * Rename check_stacked_2d to assert_stacked_2d * Rename check_stacked_square to assert_stacked_square --------- Co-authored-by: Anton <[email protected]>
1 parent c348bf4 commit 6220539

File tree

2 files changed

+39
-53
lines changed

2 files changed

+39
-53
lines changed

dpnp/linalg/dpnp_iface_linalg.py

Lines changed: 27 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,9 @@
4242
import dpnp
4343

4444
from .dpnp_utils_linalg import (
45-
check_2d,
46-
check_stacked_2d,
47-
check_stacked_square,
45+
assert_2d,
46+
assert_stacked_2d,
47+
assert_stacked_square,
4848
dpnp_cholesky,
4949
dpnp_cond,
5050
dpnp_det,
@@ -140,8 +140,8 @@ def cholesky(a, upper=False):
140140
"""
141141

142142
dpnp.check_supported_arrays_type(a)
143-
check_stacked_2d(a)
144-
check_stacked_square(a)
143+
assert_stacked_2d(a)
144+
assert_stacked_square(a)
145145

146146
return dpnp_cholesky(a, upper=upper)
147147

@@ -243,8 +243,8 @@ def det(a):
243243
"""
244244

245245
dpnp.check_supported_arrays_type(a)
246-
check_stacked_2d(a)
247-
check_stacked_square(a)
246+
assert_stacked_2d(a)
247+
assert_stacked_square(a)
248248

249249
return dpnp_det(a)
250250

@@ -334,8 +334,8 @@ def eig(a):
334334
"""
335335

336336
dpnp.check_supported_arrays_type(a)
337-
check_stacked_2d(a)
338-
check_stacked_square(a)
337+
assert_stacked_2d(a)
338+
assert_stacked_square(a)
339339

340340
a_sycl_queue = a.sycl_queue
341341
a_usm_type = a.usm_type
@@ -408,8 +408,8 @@ def eigh(a, UPLO="L"):
408408
"""
409409

410410
dpnp.check_supported_arrays_type(a)
411-
check_stacked_2d(a)
412-
check_stacked_square(a)
411+
assert_stacked_2d(a)
412+
assert_stacked_square(a)
413413

414414
UPLO = UPLO.upper()
415415
if UPLO not in ("L", "U"):
@@ -478,8 +478,8 @@ def eigvals(a):
478478
"""
479479

480480
dpnp.check_supported_arrays_type(a)
481-
check_stacked_2d(a)
482-
check_stacked_square(a)
481+
assert_stacked_2d(a)
482+
assert_stacked_square(a)
483483

484484
# Since geev function from OneMKL LAPACK is not implemented yet,
485485
# use NumPy for this calculation.
@@ -535,8 +535,8 @@ def eigvalsh(a, UPLO="L"):
535535
"""
536536

537537
dpnp.check_supported_arrays_type(a)
538-
check_stacked_2d(a)
539-
check_stacked_square(a)
538+
assert_stacked_2d(a)
539+
assert_stacked_square(a)
540540

541541
UPLO = UPLO.upper()
542542
if UPLO not in ("L", "U"):
@@ -591,8 +591,8 @@ def inv(a):
591591
"""
592592

593593
dpnp.check_supported_arrays_type(a)
594-
check_stacked_2d(a)
595-
check_stacked_square(a)
594+
assert_stacked_2d(a)
595+
assert_stacked_square(a)
596596

597597
return dpnp_inv(a)
598598

@@ -663,7 +663,7 @@ def lstsq(a, b, rcond=None):
663663
"""
664664

665665
dpnp.check_supported_arrays_type(a, b)
666-
check_2d(a)
666+
assert_2d(a)
667667
if rcond is not None and not isinstance(rcond, (int, float)):
668668
raise TypeError("rcond must be integer, floating type, or None")
669669

@@ -724,8 +724,8 @@ def matrix_power(a, n):
724724
"""
725725

726726
dpnp.check_supported_arrays_type(a)
727-
check_stacked_2d(a)
728-
check_stacked_square(a)
727+
assert_stacked_2d(a)
728+
assert_stacked_square(a)
729729

730730
if not isinstance(n, int):
731731
raise TypeError("exponent must be an integer")
@@ -896,7 +896,7 @@ def pinv(a, rcond=1e-15, hermitian=False):
896896

897897
dpnp.check_supported_arrays_type(a)
898898
dpnp.check_supported_arrays_type(rcond, scalar_type=True)
899-
check_stacked_2d(a)
899+
assert_stacked_2d(a)
900900

901901
return dpnp_pinv(a, rcond=rcond, hermitian=hermitian)
902902

@@ -1067,7 +1067,7 @@ def qr(a, mode="reduced"):
10671067
"""
10681068

10691069
dpnp.check_supported_arrays_type(a)
1070-
check_stacked_2d(a)
1070+
assert_stacked_2d(a)
10711071

10721072
if mode not in ("reduced", "complete", "r", "raw"):
10731073
raise ValueError(f"Unrecognized mode {mode}")
@@ -1114,8 +1114,8 @@ def solve(a, b):
11141114
"""
11151115

11161116
dpnp.check_supported_arrays_type(a, b)
1117-
check_stacked_2d(a)
1118-
check_stacked_square(a)
1117+
assert_stacked_2d(a)
1118+
assert_stacked_square(a)
11191119

11201120
if not (
11211121
(a.ndim == b.ndim or a.ndim == b.ndim + 1)
@@ -1222,7 +1222,7 @@ def svd(a, full_matrices=True, compute_uv=True, hermitian=False):
12221222
"""
12231223

12241224
dpnp.check_supported_arrays_type(a)
1225-
check_stacked_2d(a)
1225+
assert_stacked_2d(a)
12261226

12271227
return dpnp_svd(a, full_matrices, compute_uv, hermitian)
12281228

@@ -1277,8 +1277,8 @@ def slogdet(a):
12771277
"""
12781278

12791279
dpnp.check_supported_arrays_type(a)
1280-
check_stacked_2d(a)
1281-
check_stacked_square(a)
1280+
assert_stacked_2d(a)
1281+
assert_stacked_square(a)
12821282

12831283
return dpnp_slogdet(a)
12841284

dpnp/linalg/dpnp_utils_linalg.py

Lines changed: 12 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,9 @@
3535
from dpnp.dpnp_utils import get_usm_allocations
3636

3737
__all__ = [
38-
"check_stacked_2d",
39-
"check_stacked_square",
38+
"assert_2d",
39+
"assert_stacked_2d",
40+
"assert_stacked_square",
4041
"dpnp_cholesky",
4142
"dpnp_cond",
4243
"dpnp_det",
@@ -683,9 +684,9 @@ def _triu_inplace(a, host_tasks, depends=None):
683684
return out
684685

685686

686-
def check_2d(*arrays):
687+
def assert_2d(*arrays):
687688
"""
688-
Return ``True`` if each array in `arrays` is exactly two dimensions.
689+
Check that each array in `arrays` is exactly two-dimensional.
689690
690691
If any array is not two-dimensional, `dpnp.linalg.LinAlgError` will be raised.
691692
@@ -694,11 +695,6 @@ def check_2d(*arrays):
694695
arrays : {dpnp.ndarray, usm_ndarray}
695696
A sequence of input arrays to check for dimensionality.
696697
697-
Returns
698-
-------
699-
out : bool
700-
``True`` if each array in `arrays` is exactly two-dimensional.
701-
702698
Raises
703699
------
704700
dpnp.linalg.LinAlgError
@@ -714,9 +710,9 @@ def check_2d(*arrays):
714710
)
715711

716712

717-
def check_stacked_2d(*arrays):
713+
def assert_stacked_2d(*arrays):
718714
"""
719-
Return ``True`` if each array in `arrays` has at least two dimensions.
715+
Check that each array in `arrays` has at least two dimensions.
720716
721717
If any array is less than two-dimensional, `dpnp.linalg.LinAlgError` will be raised.
722718
@@ -725,11 +721,6 @@ def check_stacked_2d(*arrays):
725721
arrays : {dpnp.ndarray, usm_ndarray}
726722
A sequence of input arrays to check for dimensionality.
727723
728-
Returns
729-
-------
730-
out : bool
731-
``True`` if each array in `arrays` is at least two-dimensional.
732-
733724
Raises
734725
------
735726
dpnp.linalg.LinAlgError
@@ -745,30 +736,25 @@ def check_stacked_2d(*arrays):
745736
)
746737

747738

748-
def check_stacked_square(*arrays):
739+
def assert_stacked_square(*arrays):
749740
"""
750-
Return ``True`` if each array in `arrays` is a square matrix.
741+
Check that each array in `arrays` is a square matrix.
751742
752743
If any array does not form a square matrix, `dpnp.linalg.LinAlgError` will be raised.
753744
754745
Precondition: `arrays` are at least 2d. The caller should assert it
755746
beforehand. For example,
756747
757748
>>> def solve(a):
758-
... check_stacked_2d(a)
759-
... check_stacked_square(a)
749+
... assert_stacked_2d(a)
750+
... assert_stacked_square(a)
760751
... ...
761752
762753
Parameters
763754
----------
764755
arrays : {dpnp.ndarray, usm_ndarray}
765756
A sequence of input arrays to check for square matrix shape.
766757
767-
Returns
768-
-------
769-
out : bool
770-
``True`` if each array in `arrays` forms a square matrix.
771-
772758
Raises
773759
------
774760
dpnp.linalg.LinAlgError
@@ -2309,7 +2295,7 @@ def dpnp_svd(
23092295
"""
23102296

23112297
if hermitian:
2312-
check_stacked_square(a)
2298+
assert_stacked_square(a)
23132299

23142300
# _gesvd returns eigenvalues with s ** 2 sorted descending,
23152301
# but dpnp.linalg.eigh returns s sorted ascending so we re-order the eigenvalues

0 commit comments

Comments
 (0)