Skip to content

Commit b2b7e28

Browse files
Fix bug with dummy output clients in local_det_chol rewrite (#393)
* check for dummy outputs in local_det_chol rewrite * add rewrite check to 2nd test case * fix test
1 parent 82aeefc commit b2b7e28

File tree

2 files changed

+19
-1
lines changed

2 files changed

+19
-1
lines changed

pytensor/tensor/rewriting/linalg.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,8 @@ def local_det_chol(fgraph, node):
162162
if isinstance(node.op, Det):
163163
(x,) = node.inputs
164164
for cl, xpos in fgraph.clients[x]:
165+
if cl == "output":
166+
continue
165167
if isinstance(cl.op, Cholesky):
166168
L = cl.outputs[0]
167169
return [prod(at.extract_diag(L) ** 2)]

tests/tensor/rewriting/test_linalg.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from pytensor.configdefaults import config
1212
from pytensor.tensor.elemwise import DimShuffle
1313
from pytensor.tensor.math import _allclose
14-
from pytensor.tensor.nlinalg import MatrixInverse, matrix_inverse
14+
from pytensor.tensor.nlinalg import Det, MatrixInverse, matrix_inverse
1515
from pytensor.tensor.rewriting.linalg import inv_as_solve
1616
from pytensor.tensor.slinalg import Cholesky, Solve, SolveTriangular, solve
1717
from pytensor.tensor.type import dmatrix, matrix, vector
@@ -202,3 +202,19 @@ def test_cholesky_ldotlt(tag, cholesky_form, product):
202202
f(Av),
203203
)
204204
)
205+
206+
207+
def test_local_det_chol():
208+
X = matrix("X")
209+
L = at.linalg.cholesky(X)
210+
det_X = at.linalg.det(X)
211+
212+
f = function([X], [L, det_X])
213+
214+
nodes = f.maker.fgraph.toposort()
215+
assert not any(isinstance(node, Det) for node in nodes)
216+
217+
# This previously raised an error (issue #392)
218+
f = function([X], [L, det_X, X])
219+
nodes = f.maker.fgraph.toposort()
220+
assert not any(isinstance(node, Det) for node in nodes)

0 commit comments

Comments
 (0)