Skip to content

Commit 4e54ec7

Browse files
Rename check_stacked_square to assert_stacked_square
1 parent bdc0b20 commit 4e54ec7

File tree

2 files changed

+30
-35
lines changed

2 files changed

+30
-35
lines changed

dpnp/linalg/dpnp_iface_linalg.py

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,8 @@
4343

4444
from .dpnp_utils_linalg import (
4545
assert_2d,
46-
check_stacked_2d,
47-
check_stacked_square,
46+
assert_stacked_2d,
47+
assert_stacked_square,
4848
dpnp_cholesky,
4949
dpnp_cond,
5050
dpnp_det,
@@ -140,8 +140,8 @@ def cholesky(a, upper=False):
140140
"""
141141

142142
dpnp.check_supported_arrays_type(a)
143-
check_stacked_2d(a)
144-
check_stacked_square(a)
143+
assert_stacked_2d(a)
144+
assert_stacked_square(a)
145145

146146
return dpnp_cholesky(a, upper=upper)
147147

@@ -243,8 +243,8 @@ def det(a):
243243
"""
244244

245245
dpnp.check_supported_arrays_type(a)
246-
check_stacked_2d(a)
247-
check_stacked_square(a)
246+
assert_stacked_2d(a)
247+
assert_stacked_square(a)
248248

249249
return dpnp_det(a)
250250

@@ -334,8 +334,8 @@ def eig(a):
334334
"""
335335

336336
dpnp.check_supported_arrays_type(a)
337-
check_stacked_2d(a)
338-
check_stacked_square(a)
337+
assert_stacked_2d(a)
338+
assert_stacked_square(a)
339339

340340
a_sycl_queue = a.sycl_queue
341341
a_usm_type = a.usm_type
@@ -408,8 +408,8 @@ def eigh(a, UPLO="L"):
408408
"""
409409

410410
dpnp.check_supported_arrays_type(a)
411-
check_stacked_2d(a)
412-
check_stacked_square(a)
411+
assert_stacked_2d(a)
412+
assert_stacked_square(a)
413413

414414
UPLO = UPLO.upper()
415415
if UPLO not in ("L", "U"):
@@ -478,8 +478,8 @@ def eigvals(a):
478478
"""
479479

480480
dpnp.check_supported_arrays_type(a)
481-
check_stacked_2d(a)
482-
check_stacked_square(a)
481+
assert_stacked_2d(a)
482+
assert_stacked_square(a)
483483

484484
# Since geev function from OneMKL LAPACK is not implemented yet,
485485
# use NumPy for this calculation.
@@ -535,8 +535,8 @@ def eigvalsh(a, UPLO="L"):
535535
"""
536536

537537
dpnp.check_supported_arrays_type(a)
538-
check_stacked_2d(a)
539-
check_stacked_square(a)
538+
assert_stacked_2d(a)
539+
assert_stacked_square(a)
540540

541541
UPLO = UPLO.upper()
542542
if UPLO not in ("L", "U"):
@@ -591,8 +591,8 @@ def inv(a):
591591
"""
592592

593593
dpnp.check_supported_arrays_type(a)
594-
check_stacked_2d(a)
595-
check_stacked_square(a)
594+
assert_stacked_2d(a)
595+
assert_stacked_square(a)
596596

597597
return dpnp_inv(a)
598598

@@ -724,8 +724,8 @@ def matrix_power(a, n):
724724
"""
725725

726726
dpnp.check_supported_arrays_type(a)
727-
check_stacked_2d(a)
728-
check_stacked_square(a)
727+
assert_stacked_2d(a)
728+
assert_stacked_square(a)
729729

730730
if not isinstance(n, int):
731731
raise TypeError("exponent must be an integer")
@@ -896,7 +896,7 @@ def pinv(a, rcond=1e-15, hermitian=False):
896896

897897
dpnp.check_supported_arrays_type(a)
898898
dpnp.check_supported_arrays_type(rcond, scalar_type=True)
899-
check_stacked_2d(a)
899+
assert_stacked_2d(a)
900900

901901
return dpnp_pinv(a, rcond=rcond, hermitian=hermitian)
902902

@@ -1067,7 +1067,7 @@ def qr(a, mode="reduced"):
10671067
"""
10681068

10691069
dpnp.check_supported_arrays_type(a)
1070-
check_stacked_2d(a)
1070+
assert_stacked_2d(a)
10711071

10721072
if mode not in ("reduced", "complete", "r", "raw"):
10731073
raise ValueError(f"Unrecognized mode {mode}")
@@ -1114,8 +1114,8 @@ def solve(a, b):
11141114
"""
11151115

11161116
dpnp.check_supported_arrays_type(a, b)
1117-
check_stacked_2d(a)
1118-
check_stacked_square(a)
1117+
assert_stacked_2d(a)
1118+
assert_stacked_square(a)
11191119

11201120
if not (
11211121
(a.ndim == b.ndim or a.ndim == b.ndim + 1)
@@ -1222,7 +1222,7 @@ def svd(a, full_matrices=True, compute_uv=True, hermitian=False):
12221222
"""
12231223

12241224
dpnp.check_supported_arrays_type(a)
1225-
check_stacked_2d(a)
1225+
assert_stacked_2d(a)
12261226

12271227
return dpnp_svd(a, full_matrices, compute_uv, hermitian)
12281228

@@ -1277,8 +1277,8 @@ def slogdet(a):
12771277
"""
12781278

12791279
dpnp.check_supported_arrays_type(a)
1280-
check_stacked_2d(a)
1281-
check_stacked_square(a)
1280+
assert_stacked_2d(a)
1281+
assert_stacked_square(a)
12821282

12831283
return dpnp_slogdet(a)
12841284

dpnp/linalg/dpnp_utils_linalg.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
__all__ = [
3838
"assert_2d",
3939
"assert_stacked_2d",
40-
"check_stacked_square",
40+
"assert_stacked_square",
4141
"dpnp_cholesky",
4242
"dpnp_cond",
4343
"dpnp_det",
@@ -736,9 +736,9 @@ def assert_stacked_2d(*arrays):
736736
)
737737

738738

739-
def check_stacked_square(*arrays):
739+
def assert_stacked_square(*arrays):
740740
"""
741-
Return ``True`` if each array in `arrays` is a square matrix.
741+
Check that each array in `arrays` is a square matrix.
742742
743743
If any array does not form a square matrix, `dpnp.linalg.LinAlgError` will be raised.
744744
@@ -747,19 +747,14 @@ def check_stacked_square(*arrays):
747747
748748
>>> def solve(a):
749749
... assert_stacked_2d(a)
750-
... check_stacked_square(a)
750+
... assert_stacked_square(a)
751751
... ...
752752
753753
Parameters
754754
----------
755755
arrays : {dpnp.ndarray, usm_ndarray}
756756
A sequence of input arrays to check for square matrix shape.
757757
758-
Returns
759-
-------
760-
out : bool
761-
``True`` if each array in `arrays` forms a square matrix.
762-
763758
Raises
764759
------
765760
dpnp.linalg.LinAlgError
@@ -2300,7 +2295,7 @@ def dpnp_svd(
23002295
"""
23012296

23022297
if hermitian:
2303-
check_stacked_square(a)
2298+
assert_stacked_square(a)
23042299

23052300
# _gesvd returns eigenvalues with s ** 2 sorted descending,
23062301
# but dpnp.linalg.eigh returns s sorted ascending so we re-order the eigenvalues

0 commit comments

Comments
 (0)