Skip to content

Commit 8c97bb2

Browse files
committed
Fix Numba pos solve condition number calculation
1 parent 2e5e38a commit 8c97bb2

File tree

1 file changed

+11
-7
lines changed

1 file changed

+11
-7
lines changed

pytensor/link/numba/dispatch/slinalg.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -884,7 +884,7 @@ def _posv(
884884
overwrite_b: bool,
885885
check_finite: bool,
886886
transposed: bool,
887-
) -> tuple[np.ndarray, int]:
887+
) -> tuple[np.ndarray, np.ndarray, int]:
888888
"""
889889
Placeholder for solving a linear system with a positive-definite matrix; used by linalg.solve.
890890
"""
@@ -901,7 +901,8 @@ def posv_impl(
901901
check_finite: bool,
902902
transposed: bool,
903903
) -> Callable[
904-
[np.ndarray, np.ndarray, bool, bool, bool, bool, bool], tuple[np.ndarray, int]
904+
[np.ndarray, np.ndarray, bool, bool, bool, bool, bool],
905+
tuple[np.ndarray, np.ndarray, int],
905906
]:
906907
ensure_lapack()
907908
_check_scipy_linalg_matrix(A, "solve")
@@ -918,7 +919,7 @@ def impl(
918919
overwrite_b: bool,
919920
check_finite: bool,
920921
transposed: bool,
921-
) -> tuple[np.ndarray, int]:
922+
) -> tuple[np.ndarray, np.ndarray, int]:
922923
_solve_check_input_shapes(A, B)
923924

924925
_N = np.int32(A.shape[-1])
@@ -962,8 +963,9 @@ def impl(
962963
)
963964

964965
if B_is_1d:
965-
return B_copy[..., 0], int_ptr_to_val(INFO)
966-
return B_copy, int_ptr_to_val(INFO)
966+
B_copy = B_copy[..., 0]
967+
968+
return A_copy, B_copy, int_ptr_to_val(INFO)
967969

968970
return impl
969971

@@ -1064,10 +1066,12 @@ def impl(
10641066
) -> np.ndarray:
10651067
_solve_check_input_shapes(A, B)
10661068

1067-
x, info = _posv(A, B, lower, overwrite_a, overwrite_b, check_finite, transposed)
1069+
C, x, info = _posv(
1070+
A, B, lower, overwrite_a, overwrite_b, check_finite, transposed
1071+
)
10681072
_solve_check(A.shape[-1], info)
10691073

1070-
rcond, info = _pocon(x, _xlange(A))
1074+
rcond, info = _pocon(C, _xlange(A))
10711075
_solve_check(A.shape[-1], info=info, lamch=True, rcond=rcond)
10721076

10731077
return x

0 commit comments

Comments
 (0)