@@ -124,19 +124,25 @@ def impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b):
124
124
_N = np .int32 (A .shape [- 1 ])
125
125
_solve_check_input_shapes (A , B )
126
126
127
+ # Seems weird to not use the b_ndim input directly, but when I did that Numba complained that the output type
128
+ # could potentially be 3d (it didn't understand b_ndim was always equal to B.ndim)
127
129
B_is_1d = B .ndim == 1
128
130
129
- A_copy = _copy_to_fortran_order (A )
131
+ # This will only copy if A is not already fortran contiguous
132
+ A_f = np .asfortranarray (A )
130
133
131
- # This list is exhaustive, but numba freaks out if we include a final else clause
132
- if not overwrite_b and not B_is_1d :
133
- B_copy = _copy_to_fortran_order (B )
134
- elif overwrite_b and not B_is_1d :
135
- B_copy = np .asfortranarray (B )
136
- elif not overwrite_b and B_is_1d :
137
- B_copy = np .copy (np .expand_dims (B , - 1 ))
138
- elif overwrite_b and B_is_1d :
139
- B_copy = np .expand_dims (B , - 1 )
134
+ if overwrite_b :
135
+ if B_is_1d :
136
+ B_copy = np .expand_dims (B , - 1 )
137
+ else :
138
+ # This *will* allow inplace destruction of B, but only if it is already fortran contiguous.
139
+ # Otherwise, there's no way to get around the need to copy the data before going into TRTRS
140
+ B_copy = np .asfortranarray (B )
141
+ else :
142
+ if B_is_1d :
143
+ B_copy = np .copy (np .expand_dims (B , - 1 ))
144
+ else :
145
+ B_copy = _copy_to_fortran_order (B )
140
146
141
147
NRHS = 1 if B_is_1d else int (B_copy .shape [- 1 ])
142
148
@@ -155,7 +161,7 @@ def impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b):
155
161
DIAG ,
156
162
N ,
157
163
NRHS ,
158
- A_copy .view (w_type ).ctypes ,
164
+ A_f .view (w_type ).ctypes ,
159
165
LDA ,
160
166
B_copy .view (w_type ).ctypes ,
161
167
LDB ,
0 commit comments