@@ -884,7 +884,7 @@ def _posv(
884
884
overwrite_b : bool ,
885
885
check_finite : bool ,
886
886
transposed : bool ,
887
- ) -> tuple [np .ndarray , int ]:
887
+ ) -> tuple [np .ndarray , np . ndarray , int ]:
888
888
"""
889
889
Placeholder for solving a linear system with a positive-definite matrix; used by linalg.solve.
890
890
"""
@@ -901,7 +901,8 @@ def posv_impl(
901
901
check_finite : bool ,
902
902
transposed : bool ,
903
903
) -> Callable [
904
- [np .ndarray , np .ndarray , bool , bool , bool , bool , bool ], tuple [np .ndarray , int ]
904
+ [np .ndarray , np .ndarray , bool , bool , bool , bool , bool ],
905
+ tuple [np .ndarray , np .ndarray , int ],
905
906
]:
906
907
ensure_lapack ()
907
908
_check_scipy_linalg_matrix (A , "solve" )
@@ -918,7 +919,7 @@ def impl(
918
919
overwrite_b : bool ,
919
920
check_finite : bool ,
920
921
transposed : bool ,
921
- ) -> tuple [np .ndarray , int ]:
922
+ ) -> tuple [np .ndarray , np . ndarray , int ]:
922
923
_solve_check_input_shapes (A , B )
923
924
924
925
_N = np .int32 (A .shape [- 1 ])
@@ -962,8 +963,9 @@ def impl(
962
963
)
963
964
964
965
if B_is_1d :
965
- return B_copy [..., 0 ], int_ptr_to_val (INFO )
966
- return B_copy , int_ptr_to_val (INFO )
966
+ B_copy = B_copy [..., 0 ]
967
+
968
+ return A_copy , B_copy , int_ptr_to_val (INFO )
967
969
968
970
return impl
969
971
@@ -1064,10 +1066,12 @@ def impl(
1064
1066
) -> np .ndarray :
1065
1067
_solve_check_input_shapes (A , B )
1066
1068
1067
- x , info = _posv (A , B , lower , overwrite_a , overwrite_b , check_finite , transposed )
1069
+ C , x , info = _posv (
1070
+ A , B , lower , overwrite_a , overwrite_b , check_finite , transposed
1071
+ )
1068
1072
_solve_check (A .shape [- 1 ], info )
1069
1073
1070
- rcond , info = _pocon (x , _xlange (A ))
1074
+ rcond , info = _pocon (C , _xlange (A ))
1071
1075
_solve_check (A .shape [- 1 ], info = info , lamch = True , rcond = rcond )
1072
1076
1073
1077
return x
0 commit comments