Skip to content

Commit 3f23a5d

Browse files
Incorporate feedback
1 parent 67cce84 commit 3f23a5d

File tree

2 files changed

+17
-13
lines changed

2 files changed

+17
-13
lines changed

pytensor/link/numba/dispatch/slinalg.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -126,17 +126,21 @@ def impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b):
126126

127127
B_is_1d = B.ndim == 1
128128

129-
A_copy = _copy_to_fortran_order(A)
130-
131-
# This list is exhaustive, but numba freaks out if we include a final else clause
132-
if not overwrite_b and not B_is_1d:
133-
B_copy = _copy_to_fortran_order(B)
134-
elif overwrite_b and not B_is_1d:
135-
B_copy = np.asfortranarray(B)
136-
elif not overwrite_b and B_is_1d:
137-
B_copy = np.copy(np.expand_dims(B, -1))
138-
elif overwrite_b and B_is_1d:
139-
B_copy = np.expand_dims(B, -1)
129+
# Despite the name, this won't copy if A is already fortran contiguous
130+
A_copy = np.asfortranarray(A)
131+
132+
if overwrite_b:
133+
if B_is_1d:
134+
B_copy = np.expand_dims(B, -1)
135+
else:
136+
# Same here, this *will* allow inplace destruction of B, but only if it is already fortran contiguous.
137+
# Otherwise, there's no way to get around the need to copy the data before going into TRTRS
138+
B_copy = np.asfortranarray(B)
139+
else:
140+
if B_is_1d:
141+
B_copy = np.copy(np.expand_dims(B, -1))
142+
else:
143+
B_copy = _copy_to_fortran_order(B)
140144

141145
NRHS = 1 if B_is_1d else int(B_copy.shape[-1])
142146

tests/link/numba/test_slinalg.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -141,8 +141,8 @@ def test_solve_triangular_overwrite_b_correct(overwrite_b):
141141
b_test_py = np.asfortranarray(rng.normal(size=(3, 2)))
142142

143143
# .T.copy().T creates an f-contiguous copy of an f-contiguous array (otherwise the copy is c-contiguous)
144-
a_test_nb = a_test_py.T.copy().T
145-
b_test_nb = b_test_py.T.copy().T
144+
a_test_nb = a_test_py.copy(order="F")
145+
b_test_nb = b_test_py.copy(order="F")
146146

147147
op = SolveTriangular(
148148
trans=0,

0 commit comments

Comments
 (0)