Skip to content

Commit 5d5a9ce

Browse files
committed
simplifed logic for inv check
1 parent 1e42e10 commit 5d5a9ce

File tree

2 files changed

+27
-32
lines changed

2 files changed

+27
-32
lines changed

pytensor/tensor/rewriting/linalg.py

Lines changed: 24 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from typing import cast
44

55
from pytensor import Variable
6-
from pytensor.graph import Apply, FunctionGraph
6+
from pytensor.graph import Apply, Constant, FunctionGraph
77
from pytensor.graph.rewriting.basic import (
88
PatternNodeRewriter,
99
copy_stack_trace,
@@ -555,13 +555,16 @@ def _find_solve_with_eye(node) -> bool:
555555
return False
556556
# If the current op is solve, we check for b. If b is an identity matrix (Eye), we can return True
557557
solve_inputs = node.inputs
558-
eye_input = solve_inputs[1].owner
558+
eye_node = solve_inputs[1].owner
559559

560560
# We check for b = Eye and also make sure that if it was an Eye, then k = 0 (1's are only across the main diagonal)
561-
if not (eye_input and isinstance(eye_input.op, Eye)):
561+
if not (eye_node and isinstance(eye_node.op, Eye)):
562562
return False
563563

564-
if eye_input.inputs[-1].data.item() != 0:
564+
if (
565+
isinstance(eye_node.inputs[-1], Constant)
566+
and eye_node.inputs[-1].data.item() != 0
567+
):
565568
return False
566569
return True
567570

@@ -593,37 +596,35 @@ def rewrite_inv_inv(fgraph, node):
593596
# In case the outer operation is an inverse, it directly goes to the next step of finding inner operation
594597
# If the outer operation is a solve op with b = Eye, it treats it as inverse and finds the inner operation
595598
# If the outer operation is not an inverse (neither inv nor solve with eye), we do not apply this rewrite
596-
inv_check = False
597-
if isinstance(node.op, Blockwise) and isinstance(node.op.core_op, valid_inverses):
598-
inv_check = True
599-
if isinstance(node.op.core_op, valid_solves):
600-
inv_check = _find_solve_with_eye(node)
599+
if not isinstance(node.op.core_op, valid_inverses):
600+
return None
601601

602-
if not inv_check:
602+
if isinstance(node.op.core_op, valid_solves) and not _find_solve_with_eye(node):
603603
return None
604604

605605
potential_inner_inv = node.inputs[0].owner
606606
if potential_inner_inv is None or potential_inner_inv.op is None:
607607
return None
608608

609-
# Similar to the check for outer operation, we now run the same checks for the inner op.
610-
# If its an inverse or solve with eye, we apply the rewrite. Otherwise, we return None.
611-
inv_check_inner = False
612-
if isinstance(potential_inner_inv.op, Blockwise) and isinstance(
613-
potential_inner_inv.op.core_op, valid_inverses
614-
):
615-
inv_check_inner = True
616-
if isinstance(potential_inner_inv.op.core_op, valid_solves):
617-
inv_check_inner = _find_solve_with_eye(potential_inner_inv)
618-
619-
if not inv_check_inner:
620-
return None
621-
609+
# Check if inner op is blockwise and and possible inv
622610
if not (
623611
potential_inner_inv
624612
and isinstance(potential_inner_inv.op, Blockwise)
625613
and isinstance(node.op.core_op, valid_inverses)
626614
):
627615
return None
628616

617+
# Similar to the check for outer operation, we now run the same checks for the inner op.
618+
# If its an inverse or solve with eye, we apply the rewrite. Otherwise, we return None.
619+
if not (
620+
isinstance(potential_inner_inv.op, Blockwise)
621+
and isinstance(potential_inner_inv.op.core_op, valid_inverses)
622+
):
623+
return None
624+
625+
if isinstance(
626+
potential_inner_inv.op.core_op, valid_solves
627+
) and not _find_solve_with_eye(potential_inner_inv):
628+
return None
629+
629630
return [potential_inner_inv.inputs[0]]

tests/tensor/rewriting/test_linalg.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from pytensor import tensor as pt
1111
from pytensor.compile import get_default_mode
1212
from pytensor.configdefaults import config
13+
from pytensor.graph.rewriting.utils import rewrite_graph
1314
from pytensor.tensor import swapaxes
1415
from pytensor.tensor.blockwise import Blockwise
1516
from pytensor.tensor.elemwise import DimShuffle
@@ -558,12 +559,5 @@ def get_pt_function(x, op_name):
558559
x = pt.matrix("x")
559560
op1 = get_pt_function(x, inv_op_1)
560561
op2 = get_pt_function(op1, inv_op_2)
561-
f_rewritten = function([x], op2, mode="FAST_RUN")
562-
print(f_rewritten.dprint())
563-
nodes = f_rewritten.maker.fgraph.apply_nodes
564-
565-
valid_inverses = (MatrixInverse, MatrixPinv, Solve, SolveTriangular)
566-
567-
assert not any(isinstance(node.op, valid_inverses) for node in nodes)
568-
x_testing = np.random.rand(10, 10).astype(config.floatX)
569-
np.testing.assert_allclose(f_rewritten(x_testing), x_testing)
562+
rewritten_out = rewrite_graph(op2)
563+
assert rewritten_out == x

0 commit comments

Comments
 (0)