|
3 | 3 | from typing import cast
|
4 | 4 |
|
5 | 5 | from pytensor import Variable
|
6 |
| -from pytensor.graph import Apply, FunctionGraph |
| 6 | +from pytensor.graph import Apply, Constant, FunctionGraph |
7 | 7 | from pytensor.graph.rewriting.basic import (
|
8 | 8 | PatternNodeRewriter,
|
9 | 9 | copy_stack_trace,
|
@@ -555,13 +555,16 @@ def _find_solve_with_eye(node) -> bool:
|
555 | 555 | return False
|
556 | 556 | # If the current op is solve, we check for b. If b is an identity matrix (Eye), we can return True
|
557 | 557 | solve_inputs = node.inputs
|
558 |
| - eye_input = solve_inputs[1].owner |
| 558 | + eye_node = solve_inputs[1].owner |
559 | 559 |
|
560 | 560 | # We check for b = Eye and also make sure that if it was an Eye, then k = 0 (1's are only across the main diagonal)
|
561 |
| - if not (eye_input and isinstance(eye_input.op, Eye)): |
| 561 | + if not (eye_node and isinstance(eye_node.op, Eye)): |
562 | 562 | return False
|
563 | 563 |
|
564 |
| - if eye_input.inputs[-1].data.item() != 0: |
| 564 | + if ( |
| 565 | + isinstance(eye_node.inputs[-1], Constant) |
| 566 | + and eye_node.inputs[-1].data.item() != 0 |
| 567 | + ): |
565 | 568 | return False
|
566 | 569 | return True
|
567 | 570 |
|
@@ -593,37 +596,35 @@ def rewrite_inv_inv(fgraph, node):
|
593 | 596 | # In case the outer operation is an inverse, it directly goes to the next step of finding inner operation
|
594 | 597 | # If the outer operation is a solve op with b = Eye, it treats it as inverse and finds the inner operation
|
595 | 598 | # If the outer operation is not an inverse (neither inv nor solve with eye), we do not apply this rewrite
|
596 |
| - inv_check = False |
597 |
| - if isinstance(node.op, Blockwise) and isinstance(node.op.core_op, valid_inverses): |
598 |
| - inv_check = True |
599 |
| - if isinstance(node.op.core_op, valid_solves): |
600 |
| - inv_check = _find_solve_with_eye(node) |
| 599 | + if not isinstance(node.op.core_op, valid_inverses): |
| 600 | + return None |
601 | 601 |
|
602 |
| - if not inv_check: |
| 602 | + if isinstance(node.op.core_op, valid_solves) and not _find_solve_with_eye(node): |
603 | 603 | return None
|
604 | 604 |
|
605 | 605 | potential_inner_inv = node.inputs[0].owner
|
606 | 606 | if potential_inner_inv is None or potential_inner_inv.op is None:
|
607 | 607 | return None
|
608 | 608 |
|
609 |
| - # Similar to the check for outer operation, we now run the same checks for the inner op. |
610 |
| - # If its an inverse or solve with eye, we apply the rewrite. Otherwise, we return None. |
611 |
| - inv_check_inner = False |
612 |
| - if isinstance(potential_inner_inv.op, Blockwise) and isinstance( |
613 |
| - potential_inner_inv.op.core_op, valid_inverses |
614 |
| - ): |
615 |
| - inv_check_inner = True |
616 |
| - if isinstance(potential_inner_inv.op.core_op, valid_solves): |
617 |
| - inv_check_inner = _find_solve_with_eye(potential_inner_inv) |
618 |
| - |
619 |
| - if not inv_check_inner: |
620 |
| - return None |
621 |
| - |
| 609 | + # Check if inner op is blockwise and and possible inv |
622 | 610 | if not (
|
623 | 611 | potential_inner_inv
|
624 | 612 | and isinstance(potential_inner_inv.op, Blockwise)
|
625 | 613 | and isinstance(node.op.core_op, valid_inverses)
|
626 | 614 | ):
|
627 | 615 | return None
|
628 | 616 |
|
| 617 | + # Similar to the check for outer operation, we now run the same checks for the inner op. |
| 618 | + # If its an inverse or solve with eye, we apply the rewrite. Otherwise, we return None. |
| 619 | + if not ( |
| 620 | + isinstance(potential_inner_inv.op, Blockwise) |
| 621 | + and isinstance(potential_inner_inv.op.core_op, valid_inverses) |
| 622 | + ): |
| 623 | + return None |
| 624 | + |
| 625 | + if isinstance( |
| 626 | + potential_inner_inv.op.core_op, valid_solves |
| 627 | + ) and not _find_solve_with_eye(potential_inner_inv): |
| 628 | + return None |
| 629 | + |
629 | 630 | return [potential_inner_inv.inputs[0]]
|
0 commit comments