Skip to content

Commit af31cea

Browse files
committed
Added alternative slogdet to return sign and logdet of det op
1 parent 6132203 commit af31cea

File tree

2 files changed

+74
-72
lines changed

2 files changed

+74
-72
lines changed

pytensor/tensor/nlinalg.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,10 @@ def __str__(self):
266266
return "SLogDet"
267267

268268

269-
slogdet = Blockwise(SLogDet())
269+
# slogdet = Blockwise(SLogDet())
270+
def slogdet(x):
271+
det_val = det(x)
272+
return ptm.sign(det_val), ptm.log(ptm.abs(det_val))
270273

271274

272275
class Eig(Op):

pytensor/tensor/rewriting/linalg.py

Lines changed: 70 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@
3434
inv,
3535
kron,
3636
pinv,
37-
slogdet,
3837
svd,
3938
)
4039
from pytensor.tensor.rewriting.basic import (
@@ -785,43 +784,43 @@ def rewrite_det_blockdiag(fgraph, node):
785784
return [prod(det_sub_matrices)]
786785

787786

788-
@register_canonicalize
789-
@register_stabilize
790-
@node_rewriter([slogdet])
791-
def rewrite_slogdet_blockdiag(fgraph, node):
792-
"""
793-
This rewrite simplifies the slogdet of a blockdiagonal matrix by extracting the individual sub matrices and returning the sign and logdet values computed using those
794-
795-
slogdet(block_diag(a,b,c,....)) = prod(sign(a), sign(b), sign(c),...), sum(logdet(a), logdet(b), logdet(c),....)
796-
797-
Parameters
798-
----------
799-
fgraph: FunctionGraph
800-
Function graph being optimized
801-
node: Apply
802-
Node of the function graph to be optimized
803-
804-
Returns
805-
-------
806-
list of Variable, optional
807-
List of optimized variables, or None if no optimization was performed
808-
"""
809-
# Check for inner block_diag operation
810-
potential_block_diag = node.inputs[0].owner
811-
if not (
812-
potential_block_diag
813-
and isinstance(potential_block_diag.op, Blockwise)
814-
and isinstance(potential_block_diag.op.core_op, BlockDiagonal)
815-
):
816-
return None
817-
818-
# Find the composing sub_matrices
819-
sub_matrices = potential_block_diag.inputs
820-
sign_sub_matrices, logdet_sub_matrices = zip(
821-
*[slogdet(sub_matrices[i]) for i in range(len(sub_matrices))]
822-
)
823-
824-
return [prod(sign_sub_matrices), sum(logdet_sub_matrices)]
787+
# @register_canonicalize
788+
# @register_stabilize
789+
# @node_rewriter([slogdet])
790+
# def rewrite_slogdet_blockdiag(fgraph, node):
791+
# """
792+
# This rewrite simplifies the slogdet of a blockdiagonal matrix by extracting the individual sub matrices and returning the sign and logdet values computed using those
793+
794+
# slogdet(block_diag(a,b,c,....)) = prod(sign(a), sign(b), sign(c),...), sum(logdet(a), logdet(b), logdet(c),....)
795+
796+
# Parameters
797+
# ----------
798+
# fgraph: FunctionGraph
799+
# Function graph being optimized
800+
# node: Apply
801+
# Node of the function graph to be optimized
802+
803+
# Returns
804+
# -------
805+
# list of Variable, optional
806+
# List of optimized variables, or None if no optimization was performed
807+
# """
808+
# # Check for inner block_diag operation
809+
# potential_block_diag = node.inputs[0].owner
810+
# if not (
811+
# potential_block_diag
812+
# and isinstance(potential_block_diag.op, Blockwise)
813+
# and isinstance(potential_block_diag.op.core_op, BlockDiagonal)
814+
# ):
815+
# return None
816+
817+
# # Find the composing sub_matrices
818+
# sub_matrices = potential_block_diag.inputs
819+
# sign_sub_matrices, logdet_sub_matrices = zip(
820+
# *[slogdet(sub_matrices[i]) for i in range(len(sub_matrices))]
821+
# )
822+
823+
# return [prod(sign_sub_matrices), sum(logdet_sub_matrices)]
825824

826825

827826
@register_canonicalize
@@ -858,39 +857,39 @@ def rewrite_diag_kronecker(fgraph, node):
858857
return [outer_prod_as_vector]
859858

860859

861-
@register_canonicalize
862-
@register_stabilize
863-
@node_rewriter([slogdet])
864-
def rewrite_slogdet_kronecker(fgraph, node):
865-
"""
866-
This rewrite simplifies the slogdet of a kronecker-structured matrix by extracting the individual sub matrices and returning the sign and logdet values computed using those
867-
868-
Parameters
869-
----------
870-
fgraph: FunctionGraph
871-
Function graph being optimized
872-
node: Apply
873-
Node of the function graph to be optimized
874-
875-
Returns
876-
-------
877-
list of Variable, optional
878-
List of optimized variables, or None if no optimization was performed
879-
"""
880-
# Check for inner kron operation
881-
potential_kron = node.inputs[0].owner
882-
if not (potential_kron and isinstance(potential_kron.op, KroneckerProduct)):
883-
return None
884-
885-
# Find the matrices
886-
a, b = potential_kron.inputs
887-
signs, logdets = zip(*[slogdet(a), slogdet(b)])
888-
sizes = [a.shape[-1], b.shape[-1]]
889-
prod_sizes = prod(sizes, no_zeros_in_input=True)
890-
signs_final = [signs[i] ** (prod_sizes / sizes[i]) for i in range(2)]
891-
logdet_final = [logdets[i] * prod_sizes / sizes[i] for i in range(2)]
892-
893-
return [prod(signs_final, no_zeros_in_input=True), sum(logdet_final)]
860+
# @register_canonicalize
861+
# @register_stabilize
862+
# @node_rewriter([slogdet])
863+
# def rewrite_slogdet_kronecker(fgraph, node):
864+
# """
865+
# This rewrite simplifies the slogdet of a kronecker-structured matrix by extracting the individual sub matrices and returning the sign and logdet values computed using those
866+
867+
# Parameters
868+
# ----------
869+
# fgraph: FunctionGraph
870+
# Function graph being optimized
871+
# node: Apply
872+
# Node of the function graph to be optimized
873+
874+
# Returns
875+
# -------
876+
# list of Variable, optional
877+
# List of optimized variables, or None if no optimization was performed
878+
# """
879+
# # Check for inner kron operation
880+
# potential_kron = node.inputs[0].owner
881+
# if not (potential_kron and isinstance(potential_kron.op, KroneckerProduct)):
882+
# return None
883+
884+
# # Find the matrices
885+
# a, b = potential_kron.inputs
886+
# signs, logdets = zip(*[slogdet(a), slogdet(b)])
887+
# sizes = [a.shape[-1], b.shape[-1]]
888+
# prod_sizes = prod(sizes, no_zeros_in_input=True)
889+
# signs_final = [signs[i] ** (prod_sizes / sizes[i]) for i in range(2)]
890+
# logdet_final = [logdets[i] * prod_sizes / sizes[i] for i in range(2)]
891+
892+
# return [prod(signs_final, no_zeros_in_input=True), sum(logdet_final)]
894893

895894

896895
@register_canonicalize

0 commit comments

Comments
 (0)