|
42 | 42 | import dpnp
|
43 | 43 |
|
44 | 44 | from .dpnp_utils_linalg import (
|
45 |
| - check_2d, |
46 |
| - check_stacked_2d, |
47 |
| - check_stacked_square, |
| 45 | + assert_2d, |
| 46 | + assert_stacked_2d, |
| 47 | + assert_stacked_square, |
48 | 48 | dpnp_cholesky,
|
49 | 49 | dpnp_cond,
|
50 | 50 | dpnp_det,
|
@@ -140,8 +140,8 @@ def cholesky(a, upper=False):
|
140 | 140 | """
|
141 | 141 |
|
142 | 142 | dpnp.check_supported_arrays_type(a)
|
143 |
| - check_stacked_2d(a) |
144 |
| - check_stacked_square(a) |
| 143 | + assert_stacked_2d(a) |
| 144 | + assert_stacked_square(a) |
145 | 145 |
|
146 | 146 | return dpnp_cholesky(a, upper=upper)
|
147 | 147 |
|
@@ -243,8 +243,8 @@ def det(a):
|
243 | 243 | """
|
244 | 244 |
|
245 | 245 | dpnp.check_supported_arrays_type(a)
|
246 |
| - check_stacked_2d(a) |
247 |
| - check_stacked_square(a) |
| 246 | + assert_stacked_2d(a) |
| 247 | + assert_stacked_square(a) |
248 | 248 |
|
249 | 249 | return dpnp_det(a)
|
250 | 250 |
|
@@ -334,8 +334,8 @@ def eig(a):
|
334 | 334 | """
|
335 | 335 |
|
336 | 336 | dpnp.check_supported_arrays_type(a)
|
337 |
| - check_stacked_2d(a) |
338 |
| - check_stacked_square(a) |
| 337 | + assert_stacked_2d(a) |
| 338 | + assert_stacked_square(a) |
339 | 339 |
|
340 | 340 | a_sycl_queue = a.sycl_queue
|
341 | 341 | a_usm_type = a.usm_type
|
@@ -408,8 +408,8 @@ def eigh(a, UPLO="L"):
|
408 | 408 | """
|
409 | 409 |
|
410 | 410 | dpnp.check_supported_arrays_type(a)
|
411 |
| - check_stacked_2d(a) |
412 |
| - check_stacked_square(a) |
| 411 | + assert_stacked_2d(a) |
| 412 | + assert_stacked_square(a) |
413 | 413 |
|
414 | 414 | UPLO = UPLO.upper()
|
415 | 415 | if UPLO not in ("L", "U"):
|
@@ -478,8 +478,8 @@ def eigvals(a):
|
478 | 478 | """
|
479 | 479 |
|
480 | 480 | dpnp.check_supported_arrays_type(a)
|
481 |
| - check_stacked_2d(a) |
482 |
| - check_stacked_square(a) |
| 481 | + assert_stacked_2d(a) |
| 482 | + assert_stacked_square(a) |
483 | 483 |
|
484 | 484 | # Since geev function from OneMKL LAPACK is not implemented yet,
|
485 | 485 | # use NumPy for this calculation.
|
@@ -535,8 +535,8 @@ def eigvalsh(a, UPLO="L"):
|
535 | 535 | """
|
536 | 536 |
|
537 | 537 | dpnp.check_supported_arrays_type(a)
|
538 |
| - check_stacked_2d(a) |
539 |
| - check_stacked_square(a) |
| 538 | + assert_stacked_2d(a) |
| 539 | + assert_stacked_square(a) |
540 | 540 |
|
541 | 541 | UPLO = UPLO.upper()
|
542 | 542 | if UPLO not in ("L", "U"):
|
@@ -591,8 +591,8 @@ def inv(a):
|
591 | 591 | """
|
592 | 592 |
|
593 | 593 | dpnp.check_supported_arrays_type(a)
|
594 |
| - check_stacked_2d(a) |
595 |
| - check_stacked_square(a) |
| 594 | + assert_stacked_2d(a) |
| 595 | + assert_stacked_square(a) |
596 | 596 |
|
597 | 597 | return dpnp_inv(a)
|
598 | 598 |
|
@@ -663,7 +663,7 @@ def lstsq(a, b, rcond=None):
|
663 | 663 | """
|
664 | 664 |
|
665 | 665 | dpnp.check_supported_arrays_type(a, b)
|
666 |
| - check_2d(a) |
| 666 | + assert_2d(a) |
667 | 667 | if rcond is not None and not isinstance(rcond, (int, float)):
|
668 | 668 | raise TypeError("rcond must be integer, floating type, or None")
|
669 | 669 |
|
@@ -724,8 +724,8 @@ def matrix_power(a, n):
|
724 | 724 | """
|
725 | 725 |
|
726 | 726 | dpnp.check_supported_arrays_type(a)
|
727 |
| - check_stacked_2d(a) |
728 |
| - check_stacked_square(a) |
| 727 | + assert_stacked_2d(a) |
| 728 | + assert_stacked_square(a) |
729 | 729 |
|
730 | 730 | if not isinstance(n, int):
|
731 | 731 | raise TypeError("exponent must be an integer")
|
@@ -896,7 +896,7 @@ def pinv(a, rcond=1e-15, hermitian=False):
|
896 | 896 |
|
897 | 897 | dpnp.check_supported_arrays_type(a)
|
898 | 898 | dpnp.check_supported_arrays_type(rcond, scalar_type=True)
|
899 |
| - check_stacked_2d(a) |
| 899 | + assert_stacked_2d(a) |
900 | 900 |
|
901 | 901 | return dpnp_pinv(a, rcond=rcond, hermitian=hermitian)
|
902 | 902 |
|
@@ -1067,7 +1067,7 @@ def qr(a, mode="reduced"):
|
1067 | 1067 | """
|
1068 | 1068 |
|
1069 | 1069 | dpnp.check_supported_arrays_type(a)
|
1070 |
| - check_stacked_2d(a) |
| 1070 | + assert_stacked_2d(a) |
1071 | 1071 |
|
1072 | 1072 | if mode not in ("reduced", "complete", "r", "raw"):
|
1073 | 1073 | raise ValueError(f"Unrecognized mode {mode}")
|
@@ -1114,8 +1114,8 @@ def solve(a, b):
|
1114 | 1114 | """
|
1115 | 1115 |
|
1116 | 1116 | dpnp.check_supported_arrays_type(a, b)
|
1117 |
| - check_stacked_2d(a) |
1118 |
| - check_stacked_square(a) |
| 1117 | + assert_stacked_2d(a) |
| 1118 | + assert_stacked_square(a) |
1119 | 1119 |
|
1120 | 1120 | if not (
|
1121 | 1121 | (a.ndim == b.ndim or a.ndim == b.ndim + 1)
|
@@ -1222,7 +1222,7 @@ def svd(a, full_matrices=True, compute_uv=True, hermitian=False):
|
1222 | 1222 | """
|
1223 | 1223 |
|
1224 | 1224 | dpnp.check_supported_arrays_type(a)
|
1225 |
| - check_stacked_2d(a) |
| 1225 | + assert_stacked_2d(a) |
1226 | 1226 |
|
1227 | 1227 | return dpnp_svd(a, full_matrices, compute_uv, hermitian)
|
1228 | 1228 |
|
@@ -1277,8 +1277,8 @@ def slogdet(a):
|
1277 | 1277 | """
|
1278 | 1278 |
|
1279 | 1279 | dpnp.check_supported_arrays_type(a)
|
1280 |
| - check_stacked_2d(a) |
1281 |
| - check_stacked_square(a) |
| 1280 | + assert_stacked_2d(a) |
| 1281 | + assert_stacked_square(a) |
1282 | 1282 |
|
1283 | 1283 | return dpnp_slogdet(a)
|
1284 | 1284 |
|
|
0 commit comments