|
32 | 32 | inv,
|
33 | 33 | kron,
|
34 | 34 | pinv,
|
35 |
| - slogdet, |
36 | 35 | svd,
|
37 | 36 | )
|
38 | 37 | from pytensor.tensor.rewriting.basic import (
|
@@ -781,43 +780,43 @@ def rewrite_det_blockdiag(fgraph, node):
|
781 | 780 | return [prod(det_sub_matrices)]
|
782 | 781 |
|
783 | 782 |
|
784 |
| -@register_canonicalize |
785 |
| -@register_stabilize |
786 |
| -@node_rewriter([slogdet]) |
787 |
| -def rewrite_slogdet_blockdiag(fgraph, node): |
788 |
| - """ |
789 |
| - This rewrite simplifies the slogdet of a blockdiagonal matrix by extracting the individual sub matrices and returning the sign and logdet values computed using those |
790 |
| -
|
791 |
| - slogdet(block_diag(a,b,c,....)) = prod(sign(a), sign(b), sign(c),...), sum(logdet(a), logdet(b), logdet(c),....) |
792 |
| -
|
793 |
| - Parameters |
794 |
| - ---------- |
795 |
| - fgraph: FunctionGraph |
796 |
| - Function graph being optimized |
797 |
| - node: Apply |
798 |
| - Node of the function graph to be optimized |
799 |
| -
|
800 |
| - Returns |
801 |
| - ------- |
802 |
| - list of Variable, optional |
803 |
| - List of optimized variables, or None if no optimization was performed |
804 |
| - """ |
805 |
| - # Check for inner block_diag operation |
806 |
| - potential_block_diag = node.inputs[0].owner |
807 |
| - if not ( |
808 |
| - potential_block_diag |
809 |
| - and isinstance(potential_block_diag.op, Blockwise) |
810 |
| - and isinstance(potential_block_diag.op.core_op, BlockDiagonal) |
811 |
| - ): |
812 |
| - return None |
813 |
| - |
814 |
| - # Find the composing sub_matrices |
815 |
| - sub_matrices = potential_block_diag.inputs |
816 |
| - sign_sub_matrices, logdet_sub_matrices = zip( |
817 |
| - *[slogdet(sub_matrices[i]) for i in range(len(sub_matrices))] |
818 |
| - ) |
819 |
| - |
820 |
| - return [prod(sign_sub_matrices), sum(logdet_sub_matrices)] |
| 783 | +# @register_canonicalize |
| 784 | +# @register_stabilize |
| 785 | +# @node_rewriter([slogdet]) |
| 786 | +# def rewrite_slogdet_blockdiag(fgraph, node): |
| 787 | +# """ |
| 788 | +# This rewrite simplifies the slogdet of a blockdiagonal matrix by extracting the individual sub matrices and returning the sign and logdet values computed using those |
| 789 | + |
| 790 | +# slogdet(block_diag(a,b,c,....)) = prod(sign(a), sign(b), sign(c),...), sum(logdet(a), logdet(b), logdet(c),....) |
| 791 | + |
| 792 | +# Parameters |
| 793 | +# ---------- |
| 794 | +# fgraph: FunctionGraph |
| 795 | +# Function graph being optimized |
| 796 | +# node: Apply |
| 797 | +# Node of the function graph to be optimized |
| 798 | + |
| 799 | +# Returns |
| 800 | +# ------- |
| 801 | +# list of Variable, optional |
| 802 | +# List of optimized variables, or None if no optimization was performed |
| 803 | +# """ |
| 804 | +# # Check for inner block_diag operation |
| 805 | +# potential_block_diag = node.inputs[0].owner |
| 806 | +# if not ( |
| 807 | +# potential_block_diag |
| 808 | +# and isinstance(potential_block_diag.op, Blockwise) |
| 809 | +# and isinstance(potential_block_diag.op.core_op, BlockDiagonal) |
| 810 | +# ): |
| 811 | +# return None |
| 812 | + |
| 813 | +# # Find the composing sub_matrices |
| 814 | +# sub_matrices = potential_block_diag.inputs |
| 815 | +# sign_sub_matrices, logdet_sub_matrices = zip( |
| 816 | +# *[slogdet(sub_matrices[i]) for i in range(len(sub_matrices))] |
| 817 | +# ) |
| 818 | + |
| 819 | +# return [prod(sign_sub_matrices), sum(logdet_sub_matrices)] |
821 | 820 |
|
822 | 821 |
|
823 | 822 | @register_canonicalize
|
@@ -854,39 +853,39 @@ def rewrite_diag_kronecker(fgraph, node):
|
854 | 853 | return [outer_prod_as_vector]
|
855 | 854 |
|
856 | 855 |
|
857 |
| -@register_canonicalize |
858 |
| -@register_stabilize |
859 |
| -@node_rewriter([slogdet]) |
860 |
| -def rewrite_slogdet_kronecker(fgraph, node): |
861 |
| - """ |
862 |
| - This rewrite simplifies the slogdet of a kronecker-structured matrix by extracting the individual sub matrices and returning the sign and logdet values computed using those |
863 |
| -
|
864 |
| - Parameters |
865 |
| - ---------- |
866 |
| - fgraph: FunctionGraph |
867 |
| - Function graph being optimized |
868 |
| - node: Apply |
869 |
| - Node of the function graph to be optimized |
870 |
| -
|
871 |
| - Returns |
872 |
| - ------- |
873 |
| - list of Variable, optional |
874 |
| - List of optimized variables, or None if no optimization was performed |
875 |
| - """ |
876 |
| - # Check for inner kron operation |
877 |
| - potential_kron = node.inputs[0].owner |
878 |
| - if not (potential_kron and isinstance(potential_kron.op, KroneckerProduct)): |
879 |
| - return None |
880 |
| - |
881 |
| - # Find the matrices |
882 |
| - a, b = potential_kron.inputs |
883 |
| - signs, logdets = zip(*[slogdet(a), slogdet(b)]) |
884 |
| - sizes = [a.shape[-1], b.shape[-1]] |
885 |
| - prod_sizes = prod(sizes, no_zeros_in_input=True) |
886 |
| - signs_final = [signs[i] ** (prod_sizes / sizes[i]) for i in range(2)] |
887 |
| - logdet_final = [logdets[i] * prod_sizes / sizes[i] for i in range(2)] |
888 |
| - |
889 |
| - return [prod(signs_final, no_zeros_in_input=True), sum(logdet_final)] |
| 856 | +# @register_canonicalize |
| 857 | +# @register_stabilize |
| 858 | +# @node_rewriter([slogdet]) |
| 859 | +# def rewrite_slogdet_kronecker(fgraph, node): |
| 860 | +# """ |
| 861 | +# This rewrite simplifies the slogdet of a kronecker-structured matrix by extracting the individual sub matrices and returning the sign and logdet values computed using those |
| 862 | + |
| 863 | +# Parameters |
| 864 | +# ---------- |
| 865 | +# fgraph: FunctionGraph |
| 866 | +# Function graph being optimized |
| 867 | +# node: Apply |
| 868 | +# Node of the function graph to be optimized |
| 869 | + |
| 870 | +# Returns |
| 871 | +# ------- |
| 872 | +# list of Variable, optional |
| 873 | +# List of optimized variables, or None if no optimization was performed |
| 874 | +# """ |
| 875 | +# # Check for inner kron operation |
| 876 | +# potential_kron = node.inputs[0].owner |
| 877 | +# if not (potential_kron and isinstance(potential_kron.op, KroneckerProduct)): |
| 878 | +# return None |
| 879 | + |
| 880 | +# # Find the matrices |
| 881 | +# a, b = potential_kron.inputs |
| 882 | +# signs, logdets = zip(*[slogdet(a), slogdet(b)]) |
| 883 | +# sizes = [a.shape[-1], b.shape[-1]] |
| 884 | +# prod_sizes = prod(sizes, no_zeros_in_input=True) |
| 885 | +# signs_final = [signs[i] ** (prod_sizes / sizes[i]) for i in range(2)] |
| 886 | +# logdet_final = [logdets[i] * prod_sizes / sizes[i] for i in range(2)] |
| 887 | + |
| 888 | +# return [prod(signs_final, no_zeros_in_input=True), sum(logdet_final)] |
890 | 889 |
|
891 | 890 |
|
892 | 891 | @register_canonicalize
|
|
0 commit comments