Skip to content

Updates a rewrite for det(kronecker) instead of slogdet #1042

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 78 additions & 45 deletions pytensor/tensor/rewriting/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
inv,
kron,
pinv,
slogdet,
svd,
)
from pytensor.tensor.rewriting.basic import (
Expand Down Expand Up @@ -781,43 +780,43 @@ def rewrite_det_blockdiag(fgraph, node):
return [prod(det_sub_matrices)]


@register_canonicalize
@register_stabilize
@node_rewriter([slogdet])
def rewrite_slogdet_blockdiag(fgraph, node):
"""
This rewrite simplifies the slogdet of a blockdiagonal matrix by extracting the individual sub matrices and returning the sign and logdet values computed using those

slogdet(block_diag(a,b,c,....)) = prod(sign(a), sign(b), sign(c),...), sum(logdet(a), logdet(b), logdet(c),....)

Parameters
----------
fgraph: FunctionGraph
Function graph being optimized
node: Apply
Node of the function graph to be optimized

Returns
-------
list of Variable, optional
List of optimized variables, or None if no optimization was performed
"""
# Check for inner block_diag operation
potential_block_diag = node.inputs[0].owner
if not (
potential_block_diag
and isinstance(potential_block_diag.op, Blockwise)
and isinstance(potential_block_diag.op.core_op, BlockDiagonal)
):
return None

# Find the composing sub_matrices
sub_matrices = potential_block_diag.inputs
sign_sub_matrices, logdet_sub_matrices = zip(
*[slogdet(sub_matrices[i]) for i in range(len(sub_matrices))]
)

return [prod(sign_sub_matrices), sum(logdet_sub_matrices)]
# @register_canonicalize
# @register_stabilize
# @node_rewriter([slogdet])
# def rewrite_slogdet_blockdiag(fgraph, node):
# """
# This rewrite simplifies the slogdet of a blockdiagonal matrix by extracting the individual sub matrices and returning the sign and logdet values computed using those

# slogdet(block_diag(a,b,c,....)) = prod(sign(a), sign(b), sign(c),...), sum(logdet(a), logdet(b), logdet(c),....)

# Parameters
# ----------
# fgraph: FunctionGraph
# Function graph being optimized
# node: Apply
# Node of the function graph to be optimized

# Returns
# -------
# list of Variable, optional
# List of optimized variables, or None if no optimization was performed
# """
# # Check for inner block_diag operation
# potential_block_diag = node.inputs[0].owner
# if not (
# potential_block_diag
# and isinstance(potential_block_diag.op, Blockwise)
# and isinstance(potential_block_diag.op.core_op, BlockDiagonal)
# ):
# return None

# # Find the composing sub_matrices
# sub_matrices = potential_block_diag.inputs
# sign_sub_matrices, logdet_sub_matrices = zip(
# *[slogdet(sub_matrices[i]) for i in range(len(sub_matrices))]
# )

# return [prod(sign_sub_matrices), sum(logdet_sub_matrices)]


@register_canonicalize
Expand Down Expand Up @@ -854,12 +853,47 @@ def rewrite_diag_kronecker(fgraph, node):
return [outer_prod_as_vector]


# @register_canonicalize
# @register_stabilize
# @node_rewriter([slogdet])
# def rewrite_slogdet_kronecker(fgraph, node):
# """
# This rewrite simplifies the slogdet of a kronecker-structured matrix by extracting the individual sub matrices and returning the sign and logdet values computed using those

# Parameters
# ----------
# fgraph: FunctionGraph
# Function graph being optimized
# node: Apply
# Node of the function graph to be optimized

# Returns
# -------
# list of Variable, optional
# List of optimized variables, or None if no optimization was performed
# """
# # Check for inner kron operation
# potential_kron = node.inputs[0].owner
# if not (potential_kron and isinstance(potential_kron.op, KroneckerProduct)):
# return None

# # Find the matrices
# a, b = potential_kron.inputs
# signs, logdets = zip(*[slogdet(a), slogdet(b)])
# sizes = [a.shape[-1], b.shape[-1]]
# prod_sizes = prod(sizes, no_zeros_in_input=True)
# signs_final = [signs[i] ** (prod_sizes / sizes[i]) for i in range(2)]
# logdet_final = [logdets[i] * prod_sizes / sizes[i] for i in range(2)]

# return [prod(signs_final, no_zeros_in_input=True), sum(logdet_final)]


@register_canonicalize
@register_stabilize
@node_rewriter([slogdet])
def rewrite_slogdet_kronecker(fgraph, node):
@node_rewriter([det])
def rewrite_det_kronecker(fgraph, node):
"""
This rewrite simplifies the slogdet of a kronecker-structured matrix by extracting the individual sub matrices and returning the sign and logdet values computed using those
This rewrite simplifies the determinant of a kronecker-structured matrix by extracting the individual sub matrices and returning the det values computed using those

Parameters
----------
Expand All @@ -880,13 +914,12 @@ def rewrite_slogdet_kronecker(fgraph, node):

# Find the matrices
a, b = potential_kron.inputs
signs, logdets = zip(*[slogdet(a), slogdet(b)])
dets = [det(a), det(b)]
sizes = [a.shape[-1], b.shape[-1]]
prod_sizes = prod(sizes, no_zeros_in_input=True)
signs_final = [signs[i] ** (prod_sizes / sizes[i]) for i in range(2)]
logdet_final = [logdets[i] * prod_sizes / sizes[i] for i in range(2)]
det_final = prod([dets[i] ** (prod_sizes / sizes[i]) for i in range(2)])

return [prod(signs_final, no_zeros_in_input=True), sum(logdet_final)]
return [det_final]


@register_canonicalize
Expand Down
49 changes: 36 additions & 13 deletions tests/tensor/rewriting/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -776,11 +776,40 @@ def test_diag_kronecker_rewrite():
)


def test_slogdet_kronecker_rewrite():
# def test_slogdet_kronecker_rewrite():
# a, b = pt.dmatrices("a", "b")
# kron_prod = pt.linalg.kron(a, b)
# sign_output, logdet_output = pt.linalg.slogdet(kron_prod)
# f_rewritten = function([kron_prod], [sign_output, logdet_output], mode="FAST_RUN")

# # Rewrite Test
# nodes = f_rewritten.maker.fgraph.apply_nodes
# assert not any(isinstance(node.op, KroneckerProduct) for node in nodes)

# # Value Test
# a_test, b_test = np.random.rand(2, 20, 20)
# kron_prod_test = np.kron(a_test, b_test)
# sign_output_test, logdet_output_test = np.linalg.slogdet(kron_prod_test)
# rewritten_sign_val, rewritten_logdet_val = f_rewritten(kron_prod_test)
# assert_allclose(
# sign_output_test,
# rewritten_sign_val,
# atol=1e-3 if config.floatX == "float32" else 1e-8,
# rtol=1e-3 if config.floatX == "float32" else 1e-8,
# )
# assert_allclose(
# logdet_output_test,
# rewritten_logdet_val,
# atol=1e-3 if config.floatX == "float32" else 1e-8,
# rtol=1e-3 if config.floatX == "float32" else 1e-8,
# )


def test_det_kronecker_rewrite():
a, b = pt.dmatrices("a", "b")
kron_prod = pt.linalg.kron(a, b)
sign_output, logdet_output = pt.linalg.slogdet(kron_prod)
f_rewritten = function([kron_prod], [sign_output, logdet_output], mode="FAST_RUN")
det_output = pt.linalg.det(kron_prod)
f_rewritten = function([kron_prod], [det_output], mode="FAST_RUN")

# Rewrite Test
nodes = f_rewritten.maker.fgraph.apply_nodes
Expand All @@ -789,17 +818,11 @@ def test_slogdet_kronecker_rewrite():
# Value Test
a_test, b_test = np.random.rand(2, 20, 20)
kron_prod_test = np.kron(a_test, b_test)
sign_output_test, logdet_output_test = np.linalg.slogdet(kron_prod_test)
rewritten_sign_val, rewritten_logdet_val = f_rewritten(kron_prod_test)
det_output_test = np.linalg.det(kron_prod_test)
rewritten_det_val = f_rewritten(kron_prod_test)
assert_allclose(
sign_output_test,
rewritten_sign_val,
atol=1e-3 if config.floatX == "float32" else 1e-8,
rtol=1e-3 if config.floatX == "float32" else 1e-8,
)
assert_allclose(
logdet_output_test,
rewritten_logdet_val,
det_output_test,
rewritten_det_val,
atol=1e-3 if config.floatX == "float32" else 1e-8,
rtol=1e-3 if config.floatX == "float32" else 1e-8,
)
Expand Down
Loading