|
34 | 34 | inv,
|
35 | 35 | kron,
|
36 | 36 | pinv,
|
37 |
| - slogdet, |
38 | 37 | svd,
|
39 | 38 | )
|
40 | 39 | from pytensor.tensor.rewriting.basic import (
|
@@ -785,43 +784,43 @@ def rewrite_det_blockdiag(fgraph, node):
|
785 | 784 | return [prod(det_sub_matrices)]
|
786 | 785 |
|
787 | 786 |
|
788 |
| -@register_canonicalize |
789 |
| -@register_stabilize |
790 |
| -@node_rewriter([slogdet]) |
791 |
| -def rewrite_slogdet_blockdiag(fgraph, node): |
792 |
| - """ |
793 |
| - This rewrite simplifies the slogdet of a blockdiagonal matrix by extracting the individual sub matrices and returning the sign and logdet values computed using those |
794 |
| -
|
795 |
| - slogdet(block_diag(a,b,c,....)) = prod(sign(a), sign(b), sign(c),...), sum(logdet(a), logdet(b), logdet(c),....) |
796 |
| -
|
797 |
| - Parameters |
798 |
| - ---------- |
799 |
| - fgraph: FunctionGraph |
800 |
| - Function graph being optimized |
801 |
| - node: Apply |
802 |
| - Node of the function graph to be optimized |
803 |
| -
|
804 |
| - Returns |
805 |
| - ------- |
806 |
| - list of Variable, optional |
807 |
| - List of optimized variables, or None if no optimization was performed |
808 |
| - """ |
809 |
| - # Check for inner block_diag operation |
810 |
| - potential_block_diag = node.inputs[0].owner |
811 |
| - if not ( |
812 |
| - potential_block_diag |
813 |
| - and isinstance(potential_block_diag.op, Blockwise) |
814 |
| - and isinstance(potential_block_diag.op.core_op, BlockDiagonal) |
815 |
| - ): |
816 |
| - return None |
817 |
| - |
818 |
| - # Find the composing sub_matrices |
819 |
| - sub_matrices = potential_block_diag.inputs |
820 |
| - sign_sub_matrices, logdet_sub_matrices = zip( |
821 |
| - *[slogdet(sub_matrices[i]) for i in range(len(sub_matrices))] |
822 |
| - ) |
823 |
| - |
824 |
| - return [prod(sign_sub_matrices), sum(logdet_sub_matrices)] |
| 787 | +# @register_canonicalize |
| 788 | +# @register_stabilize |
| 789 | +# @node_rewriter([slogdet]) |
| 790 | +# def rewrite_slogdet_blockdiag(fgraph, node): |
| 791 | +# """ |
| 792 | +# This rewrite simplifies the slogdet of a blockdiagonal matrix by extracting the individual sub matrices and returning the sign and logdet values computed using those |
| 793 | + |
| 794 | +# slogdet(block_diag(a,b,c,....)) = prod(sign(a), sign(b), sign(c),...), sum(logdet(a), logdet(b), logdet(c),....) |
| 795 | + |
| 796 | +# Parameters |
| 797 | +# ---------- |
| 798 | +# fgraph: FunctionGraph |
| 799 | +# Function graph being optimized |
| 800 | +# node: Apply |
| 801 | +# Node of the function graph to be optimized |
| 802 | + |
| 803 | +# Returns |
| 804 | +# ------- |
| 805 | +# list of Variable, optional |
| 806 | +# List of optimized variables, or None if no optimization was performed |
| 807 | +# """ |
| 808 | +# # Check for inner block_diag operation |
| 809 | +# potential_block_diag = node.inputs[0].owner |
| 810 | +# if not ( |
| 811 | +# potential_block_diag |
| 812 | +# and isinstance(potential_block_diag.op, Blockwise) |
| 813 | +# and isinstance(potential_block_diag.op.core_op, BlockDiagonal) |
| 814 | +# ): |
| 815 | +# return None |
| 816 | + |
| 817 | +# # Find the composing sub_matrices |
| 818 | +# sub_matrices = potential_block_diag.inputs |
| 819 | +# sign_sub_matrices, logdet_sub_matrices = zip( |
| 820 | +# *[slogdet(sub_matrices[i]) for i in range(len(sub_matrices))] |
| 821 | +# ) |
| 822 | + |
| 823 | +# return [prod(sign_sub_matrices), sum(logdet_sub_matrices)] |
825 | 824 |
|
826 | 825 |
|
827 | 826 | @register_canonicalize
|
@@ -858,39 +857,39 @@ def rewrite_diag_kronecker(fgraph, node):
|
858 | 857 | return [outer_prod_as_vector]
|
859 | 858 |
|
860 | 859 |
|
861 |
| -@register_canonicalize |
862 |
| -@register_stabilize |
863 |
| -@node_rewriter([slogdet]) |
864 |
| -def rewrite_slogdet_kronecker(fgraph, node): |
865 |
| - """ |
866 |
| - This rewrite simplifies the slogdet of a kronecker-structured matrix by extracting the individual sub matrices and returning the sign and logdet values computed using those |
867 |
| -
|
868 |
| - Parameters |
869 |
| - ---------- |
870 |
| - fgraph: FunctionGraph |
871 |
| - Function graph being optimized |
872 |
| - node: Apply |
873 |
| - Node of the function graph to be optimized |
874 |
| -
|
875 |
| - Returns |
876 |
| - ------- |
877 |
| - list of Variable, optional |
878 |
| - List of optimized variables, or None if no optimization was performed |
879 |
| - """ |
880 |
| - # Check for inner kron operation |
881 |
| - potential_kron = node.inputs[0].owner |
882 |
| - if not (potential_kron and isinstance(potential_kron.op, KroneckerProduct)): |
883 |
| - return None |
884 |
| - |
885 |
| - # Find the matrices |
886 |
| - a, b = potential_kron.inputs |
887 |
| - signs, logdets = zip(*[slogdet(a), slogdet(b)]) |
888 |
| - sizes = [a.shape[-1], b.shape[-1]] |
889 |
| - prod_sizes = prod(sizes, no_zeros_in_input=True) |
890 |
| - signs_final = [signs[i] ** (prod_sizes / sizes[i]) for i in range(2)] |
891 |
| - logdet_final = [logdets[i] * prod_sizes / sizes[i] for i in range(2)] |
892 |
| - |
893 |
| - return [prod(signs_final, no_zeros_in_input=True), sum(logdet_final)] |
| 860 | +# @register_canonicalize |
| 861 | +# @register_stabilize |
| 862 | +# @node_rewriter([slogdet]) |
| 863 | +# def rewrite_slogdet_kronecker(fgraph, node): |
| 864 | +# """ |
| 865 | +# This rewrite simplifies the slogdet of a kronecker-structured matrix by extracting the individual sub matrices and returning the sign and logdet values computed using those |
| 866 | + |
| 867 | +# Parameters |
| 868 | +# ---------- |
| 869 | +# fgraph: FunctionGraph |
| 870 | +# Function graph being optimized |
| 871 | +# node: Apply |
| 872 | +# Node of the function graph to be optimized |
| 873 | + |
| 874 | +# Returns |
| 875 | +# ------- |
| 876 | +# list of Variable, optional |
| 877 | +# List of optimized variables, or None if no optimization was performed |
| 878 | +# """ |
| 879 | +# # Check for inner kron operation |
| 880 | +# potential_kron = node.inputs[0].owner |
| 881 | +# if not (potential_kron and isinstance(potential_kron.op, KroneckerProduct)): |
| 882 | +# return None |
| 883 | + |
| 884 | +# # Find the matrices |
| 885 | +# a, b = potential_kron.inputs |
| 886 | +# signs, logdets = zip(*[slogdet(a), slogdet(b)]) |
| 887 | +# sizes = [a.shape[-1], b.shape[-1]] |
| 888 | +# prod_sizes = prod(sizes, no_zeros_in_input=True) |
| 889 | +# signs_final = [signs[i] ** (prod_sizes / sizes[i]) for i in range(2)] |
| 890 | +# logdet_final = [logdets[i] * prod_sizes / sizes[i] for i in range(2)] |
| 891 | + |
| 892 | +# return [prod(signs_final, no_zeros_in_input=True), sum(logdet_final)] |
894 | 893 |
|
895 | 894 |
|
896 | 895 | @register_canonicalize
|
|
0 commit comments