Skip to content

Commit c930399

Browse files
committed
Added alternative slogdet to return sign and logdet of det op
1 parent a377c22 commit c930399

File tree

2 files changed

+74
-72
lines changed

2 files changed

+74
-72
lines changed

pytensor/tensor/nlinalg.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,10 @@ def __str__(self):
266266
return "SLogDet"
267267

268268

269-
slogdet = Blockwise(SLogDet())
269+
# slogdet = Blockwise(SLogDet())
270+
def slogdet(x):
271+
det_val = det(x)
272+
return ptm.sign(det_val), ptm.log(ptm.abs(det_val))
270273

271274

272275
class Eig(Op):

pytensor/tensor/rewriting/linalg.py

Lines changed: 70 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232
inv,
3333
kron,
3434
pinv,
35-
slogdet,
3635
svd,
3736
)
3837
from pytensor.tensor.rewriting.basic import (
@@ -781,43 +780,43 @@ def rewrite_det_blockdiag(fgraph, node):
781780
return [prod(det_sub_matrices)]
782781

783782

784-
@register_canonicalize
785-
@register_stabilize
786-
@node_rewriter([slogdet])
787-
def rewrite_slogdet_blockdiag(fgraph, node):
788-
"""
789-
This rewrite simplifies the slogdet of a blockdiagonal matrix by extracting the individual sub matrices and returning the sign and logdet values computed using those
790-
791-
slogdet(block_diag(a,b,c,....)) = prod(sign(a), sign(b), sign(c),...), sum(logdet(a), logdet(b), logdet(c),....)
792-
793-
Parameters
794-
----------
795-
fgraph: FunctionGraph
796-
Function graph being optimized
797-
node: Apply
798-
Node of the function graph to be optimized
799-
800-
Returns
801-
-------
802-
list of Variable, optional
803-
List of optimized variables, or None if no optimization was performed
804-
"""
805-
# Check for inner block_diag operation
806-
potential_block_diag = node.inputs[0].owner
807-
if not (
808-
potential_block_diag
809-
and isinstance(potential_block_diag.op, Blockwise)
810-
and isinstance(potential_block_diag.op.core_op, BlockDiagonal)
811-
):
812-
return None
813-
814-
# Find the composing sub_matrices
815-
sub_matrices = potential_block_diag.inputs
816-
sign_sub_matrices, logdet_sub_matrices = zip(
817-
*[slogdet(sub_matrices[i]) for i in range(len(sub_matrices))]
818-
)
819-
820-
return [prod(sign_sub_matrices), sum(logdet_sub_matrices)]
783+
# @register_canonicalize
784+
# @register_stabilize
785+
# @node_rewriter([slogdet])
786+
# def rewrite_slogdet_blockdiag(fgraph, node):
787+
# """
788+
# This rewrite simplifies the slogdet of a blockdiagonal matrix by extracting the individual sub matrices and returning the sign and logdet values computed using those
789+
790+
# slogdet(block_diag(a,b,c,....)) = prod(sign(a), sign(b), sign(c),...), sum(logdet(a), logdet(b), logdet(c),....)
791+
792+
# Parameters
793+
# ----------
794+
# fgraph: FunctionGraph
795+
# Function graph being optimized
796+
# node: Apply
797+
# Node of the function graph to be optimized
798+
799+
# Returns
800+
# -------
801+
# list of Variable, optional
802+
# List of optimized variables, or None if no optimization was performed
803+
# """
804+
# # Check for inner block_diag operation
805+
# potential_block_diag = node.inputs[0].owner
806+
# if not (
807+
# potential_block_diag
808+
# and isinstance(potential_block_diag.op, Blockwise)
809+
# and isinstance(potential_block_diag.op.core_op, BlockDiagonal)
810+
# ):
811+
# return None
812+
813+
# # Find the composing sub_matrices
814+
# sub_matrices = potential_block_diag.inputs
815+
# sign_sub_matrices, logdet_sub_matrices = zip(
816+
# *[slogdet(sub_matrices[i]) for i in range(len(sub_matrices))]
817+
# )
818+
819+
# return [prod(sign_sub_matrices), sum(logdet_sub_matrices)]
821820

822821

823822
@register_canonicalize
@@ -854,39 +853,39 @@ def rewrite_diag_kronecker(fgraph, node):
854853
return [outer_prod_as_vector]
855854

856855

857-
@register_canonicalize
858-
@register_stabilize
859-
@node_rewriter([slogdet])
860-
def rewrite_slogdet_kronecker(fgraph, node):
861-
"""
862-
This rewrite simplifies the slogdet of a kronecker-structured matrix by extracting the individual sub matrices and returning the sign and logdet values computed using those
863-
864-
Parameters
865-
----------
866-
fgraph: FunctionGraph
867-
Function graph being optimized
868-
node: Apply
869-
Node of the function graph to be optimized
870-
871-
Returns
872-
-------
873-
list of Variable, optional
874-
List of optimized variables, or None if no optimization was performed
875-
"""
876-
# Check for inner kron operation
877-
potential_kron = node.inputs[0].owner
878-
if not (potential_kron and isinstance(potential_kron.op, KroneckerProduct)):
879-
return None
880-
881-
# Find the matrices
882-
a, b = potential_kron.inputs
883-
signs, logdets = zip(*[slogdet(a), slogdet(b)])
884-
sizes = [a.shape[-1], b.shape[-1]]
885-
prod_sizes = prod(sizes, no_zeros_in_input=True)
886-
signs_final = [signs[i] ** (prod_sizes / sizes[i]) for i in range(2)]
887-
logdet_final = [logdets[i] * prod_sizes / sizes[i] for i in range(2)]
888-
889-
return [prod(signs_final, no_zeros_in_input=True), sum(logdet_final)]
856+
# @register_canonicalize
857+
# @register_stabilize
858+
# @node_rewriter([slogdet])
859+
# def rewrite_slogdet_kronecker(fgraph, node):
860+
# """
861+
# This rewrite simplifies the slogdet of a kronecker-structured matrix by extracting the individual sub matrices and returning the sign and logdet values computed using those
862+
863+
# Parameters
864+
# ----------
865+
# fgraph: FunctionGraph
866+
# Function graph being optimized
867+
# node: Apply
868+
# Node of the function graph to be optimized
869+
870+
# Returns
871+
# -------
872+
# list of Variable, optional
873+
# List of optimized variables, or None if no optimization was performed
874+
# """
875+
# # Check for inner kron operation
876+
# potential_kron = node.inputs[0].owner
877+
# if not (potential_kron and isinstance(potential_kron.op, KroneckerProduct)):
878+
# return None
879+
880+
# # Find the matrices
881+
# a, b = potential_kron.inputs
882+
# signs, logdets = zip(*[slogdet(a), slogdet(b)])
883+
# sizes = [a.shape[-1], b.shape[-1]]
884+
# prod_sizes = prod(sizes, no_zeros_in_input=True)
885+
# signs_final = [signs[i] ** (prod_sizes / sizes[i]) for i in range(2)]
886+
# logdet_final = [logdets[i] * prod_sizes / sizes[i] for i in range(2)]
887+
888+
# return [prod(signs_final, no_zeros_in_input=True), sum(logdet_final)]
890889

891890

892891
@register_canonicalize

0 commit comments

Comments
 (0)