
Commit e7dec4d

Fix solve_triangular output when overwrite_b=True (#1235)
* Fix bug in solve_triangular when `overwrite_b = True`
* Add regression test
1 parent 5d4e9e0 commit e7dec4d
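
For context, and not part of the commit itself: `overwrite_b` only gives the kernel permission to reuse B's buffer; whatever is returned must still be the correct solution, matching the SciPy routine that the Numba implementation mirrors. A minimal sketch of that contract, using plain `scipy.linalg.solve_triangular`:

import numpy as np
import scipy.linalg

rng = np.random.default_rng(0)
A = np.tril(rng.normal(size=(3, 3))) + 3 * np.eye(3)  # well-conditioned lower-triangular matrix
b = rng.normal(size=(3, 2))

# overwrite_b=True merely allows the routine to reuse b's storage;
# the returned x must still satisfy A @ x = b
x = scipy.linalg.solve_triangular(A, b.copy(), lower=True, overwrite_b=True)
np.testing.assert_allclose(A @ x, b)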


2 files changed (+56, -7 lines)


pytensor/link/numba/dispatch/slinalg.py

Lines changed: 13 additions & 7 deletions
@@ -124,20 +124,26 @@ def impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b):
         _N = np.int32(A.shape[-1])
         _solve_check_input_shapes(A, B)

+        # Seems weird to not use the b_ndim input directly, but when I did that Numba complained that the output type
+        # could potentially be 3d (it didn't understand b_ndim was always equal to B.ndim)
         B_is_1d = B.ndim == 1

+        # This will only copy if A is not already fortran contiguous
+        A_f = np.asfortranarray(A)
+
         if overwrite_b:
-            B_copy = B
+            if B_is_1d:
+                B_copy = np.expand_dims(B, -1)
+            else:
+                # This *will* allow inplace destruction of B, but only if it is already fortran contiguous.
+                # Otherwise, there's no way to get around the need to copy the data before going into TRTRS
+                B_copy = np.asfortranarray(B)
         else:
             if B_is_1d:
-                # _copy_to_fortran_order does nothing with vectors
-                B_copy = np.copy(B)
+                B_copy = np.copy(np.expand_dims(B, -1))
             else:
                 B_copy = _copy_to_fortran_order(B)

-        if B_is_1d:
-            B_copy = np.expand_dims(B_copy, -1)
-
         NRHS = 1 if B_is_1d else int(B_copy.shape[-1])

         UPLO = val_to_int_ptr(ord("L") if lower else ord("U"))
@@ -155,7 +161,7 @@ def impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b):
             DIAG,
             N,
             NRHS,
-            np.asfortranarray(A).T.view(w_type).ctypes,
+            A_f.view(w_type).ctypes,
             LDA,
             B_copy.view(w_type).ctypes,
             LDB,
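
The comments in the hunk above hinge on NumPy's memory-layout rules: `np.asfortranarray` copies only when its argument is not already Fortran-ordered, and LAPACK's TRTRS reads its arguments as column-major, so handing it a C-ordered buffer without copying silently reinterprets the data. A small NumPy-only sketch of both facts (not part of the commit):

import numpy as np

rng = np.random.default_rng(0)
A = rng.normal(size=(3, 3))   # C-contiguous by default
B = rng.normal(size=(3, 2))

# np.asfortranarray copies a C-ordered array, but makes no copy for an already
# Fortran-ordered one -- which is why the overwrite_b branch can only destroy B
# in place when B arrives F-contiguous.
A_f = np.asfortranarray(A)
assert not np.shares_memory(A, A_f)
assert np.shares_memory(A_f, np.asfortranarray(A_f))

# Reading a C-ordered buffer as column-major (what TRTRS would do) scrambles it:
scrambled = B.ravel(order="C").reshape(B.shape, order="F")
assert not np.allclose(B, scrambled)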

tests/link/numba/test_slinalg.py

Lines changed: 43 additions & 0 deletions
@@ -10,6 +10,7 @@
 import pytensor
 import pytensor.tensor as pt
 from pytensor import config
+from pytensor.tensor.slinalg import SolveTriangular
 from tests import unittest_tools as utt
 from tests.link.numba.test_basic import compare_numba_and_py

@@ -130,6 +131,48 @@ def A_func_pt(x):
     )


+@pytest.mark.parametrize("overwrite_b", [True, False], ids=["inplace", "not_inplace"])
+def test_solve_triangular_overwrite_b_correct(overwrite_b):
+    # Regression test for issue #1233
+
+    rng = np.random.default_rng(utt.fetch_seed())
+    a_test_py = np.asfortranarray(rng.normal(size=(3, 3)))
+    a_test_py = np.tril(a_test_py)
+    b_test_py = np.asfortranarray(rng.normal(size=(3, 2)))
+
+    # .T.copy().T creates an f-contiguous copy of an f-contiguous array (otherwise the copy is c-contiguous)
+    a_test_nb = a_test_py.copy(order="F")
+    b_test_nb = b_test_py.copy(order="F")
+
+    op = SolveTriangular(
+        trans=0,
+        unit_diagonal=False,
+        lower=False,
+        check_finite=True,
+        b_ndim=2,
+        overwrite_b=overwrite_b,
+    )
+
+    a_pt = pt.matrix("a", shape=(3, 3))
+    b_pt = pt.matrix("b", shape=(3, 2))
+    out = op(a_pt, b_pt)
+
+    py_fn = pytensor.function([a_pt, b_pt], out, accept_inplace=True)
+    numba_fn = pytensor.function([a_pt, b_pt], out, accept_inplace=True, mode="NUMBA")
+
+    x_py = py_fn(a_test_py, b_test_py)
+    x_nb = numba_fn(a_test_nb, b_test_nb)
+
+    np.testing.assert_allclose(
+        py_fn(a_test_py, b_test_py), numba_fn(a_test_nb, b_test_nb)
+    )
+    np.testing.assert_allclose(b_test_py, b_test_nb)
+
+    if overwrite_b:
+        np.testing.assert_allclose(b_test_py, x_py)
+        np.testing.assert_allclose(b_test_nb, x_nb)
+
+
 @pytest.mark.parametrize("value", [np.nan, np.inf])
 @pytest.mark.filterwarnings(
     'ignore:Cannot cache compiled function "numba_funcified_fgraph"'
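
The `copy(order="F")` lines in the new test exist because a plain `.copy()` of a Fortran-ordered array comes back C-ordered, so the NumPy and Numba functions would otherwise be fed differently laid-out buffers. A quick NumPy check of that detail (illustrative only, not part of the commit):

import numpy as np

rng = np.random.default_rng(0)
b = np.asfortranarray(rng.normal(size=(3, 2)))

assert b.flags["F_CONTIGUOUS"]
assert b.copy().flags["C_CONTIGUOUS"]           # ndarray.copy() defaults to C order
assert b.copy(order="F").flags["F_CONTIGUOUS"]  # order="F" preserves the column-major layout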
