Skip to content

Commit 16d09b9

Browse files
Incorporate feedback
1 parent 67cce84 commit 16d09b9

File tree

2 files changed

+19
-13
lines changed

2 files changed

+19
-13
lines changed

pytensor/link/numba/dispatch/slinalg.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -124,19 +124,25 @@ def impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b):
124124
_N = np.int32(A.shape[-1])
125125
_solve_check_input_shapes(A, B)
126126

127+
# B_is_1d is derived from B.ndim rather than from the b_ndim input: when b_ndim was used directly, Numba complained
128+
# that the output type could potentially be 3d (it could not infer that b_ndim is always equal to B.ndim)
127129
B_is_1d = B.ndim == 1
128130

129-
A_copy = _copy_to_fortran_order(A)
131+
# This will only copy if A is not already fortran contiguous
132+
A_f = np.asfortranarray(A)
130133

131-
# This list is exhaustive, but numba freaks out if we include a final else clause
132-
if not overwrite_b and not B_is_1d:
133-
B_copy = _copy_to_fortran_order(B)
134-
elif overwrite_b and not B_is_1d:
135-
B_copy = np.asfortranarray(B)
136-
elif not overwrite_b and B_is_1d:
137-
B_copy = np.copy(np.expand_dims(B, -1))
138-
elif overwrite_b and B_is_1d:
139-
B_copy = np.expand_dims(B, -1)
134+
if overwrite_b:
135+
if B_is_1d:
136+
B_copy = np.expand_dims(B, -1)
137+
else:
138+
# This *will* allow inplace destruction of B, but only if it is already fortran contiguous.
139+
# Otherwise, there's no way to get around the need to copy the data before going into TRTRS
140+
B_copy = np.asfortranarray(B)
141+
else:
142+
if B_is_1d:
143+
B_copy = np.copy(np.expand_dims(B, -1))
144+
else:
145+
B_copy = _copy_to_fortran_order(B)
140146

141147
NRHS = 1 if B_is_1d else int(B_copy.shape[-1])
142148

@@ -155,7 +161,7 @@ def impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b):
155161
DIAG,
156162
N,
157163
NRHS,
158-
A_copy.view(w_type).ctypes,
164+
A_f.view(w_type).ctypes,
159165
LDA,
160166
B_copy.view(w_type).ctypes,
161167
LDB,

tests/link/numba/test_slinalg.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -141,8 +141,8 @@ def test_solve_triangular_overwrite_b_correct(overwrite_b):
141141
b_test_py = np.asfortranarray(rng.normal(size=(3, 2)))
142142

143143
# Make f-contiguous copies of the f-contiguous inputs; a plain .copy() defaults to c-contiguous order (this was previously done with the .T.copy().T idiom)
144-
a_test_nb = a_test_py.T.copy().T
145-
b_test_nb = b_test_py.T.copy().T
144+
a_test_nb = a_test_py.copy(order="F")
145+
b_test_nb = b_test_py.copy(order="F")
146146

147147
op = SolveTriangular(
148148
trans=0,

0 commit comments

Comments
 (0)