
Reuse cholesky decomposition with cho_solve in graphs with multiple pt.solve when assume_a = "pos" #1467

Merged (4 commits) on Jun 13, 2025
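
To make the intent concrete, here is a hedged sketch (not part of the PR) of the kind of graph this rewrite targets: two solves against the same symmetric positive-definite A that, after this change, should share a single Cholesky factorization. The Cholesky-counting check at the end is illustrative only; the exact structure of the compiled graph depends on the active rewrite database.

import pytensor
from pytensor.tensor import matrix, vector
from pytensor.tensor.slinalg import Cholesky, solve

A = matrix("A")
b1 = vector("b1")
b2 = vector("b2")

# Two solves that share the same A and both assume it is positive definite
x1 = solve(A, b1, assume_a="pos", lower=True)
x2 = solve(A, b2, assume_a="pos", lower=True)

fn = pytensor.function([A, b1, b2], [x1, x2])

# After the rewrite, a single Cholesky node should feed both cho_solve calls
chol_nodes = [
    node
    for node in fn.maker.fgraph.apply_nodes
    if isinstance(getattr(node.op, "core_op", node.op), Cholesky)
]
print(len(chol_nodes))  # expected: 1
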
75 changes: 49 additions & 26 deletions pytensor/tensor/_linalg/solve/rewriting.py
@@ -15,24 +15,29 @@
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.rewriting.basic import register_specialize
from pytensor.tensor.rewriting.linalg import is_matrix_transpose
from pytensor.tensor.slinalg import Solve, lu_factor, lu_solve
from pytensor.tensor.slinalg import Solve, cho_solve, cholesky, lu_factor, lu_solve
from pytensor.tensor.variable import TensorVariable


def decompose_A(A, assume_a, check_finite):
def decompose_A(A, assume_a, check_finite, lower):
if assume_a == "gen":
return lu_factor(A, check_finite=check_finite)
elif assume_a == "tridiagonal":
# We didn't implement check_finite for tridiagonal LU factorization
return tridiagonal_lu_factor(A)
elif assume_a == "pos":
return cholesky(A, lower=lower, check_finite=check_finite)
else:
raise NotImplementedError


def solve_lu_decomposed_system(A_decomp, b, transposed=False, *, core_solve_op: Solve):
def solve_decomposed_system(
A_decomp, b, transposed=False, lower=False, *, core_solve_op: Solve
):
b_ndim = core_solve_op.b_ndim
check_finite = core_solve_op.check_finite
assume_a = core_solve_op.assume_a

if assume_a == "gen":
return lu_solve(
A_decomp,
@@ -49,11 +54,19 @@ def solve_lu_decomposed_system(A_decomp, b, transposed=False, *, core_solve_op:
b_ndim=b_ndim,
transposed=transposed,
)
elif assume_a == "pos":
Member:
What's True in (A_decomp, True)? Whether it's upper or lower? Don't we need to know?

Member Author:
Yes, it's the lower flag. I was thinking it doesn't matter because we will be adding in the decomposition ourselves via rewrite, so we control which one is done. I could respect the setting on the solve Op if you think that's better.

Member:
That's fine, maybe add a comment for future devs?

Member:
And the transposed doesn't matter because it's symmetric?

Member Author:
Yeah exactly

Member Author (@jessegrabowski, Jun 13, 2025):
Yeah exactly. I brought up that it doesn't have a flag because the nodes should still be merged, right? Or do the inputs need to be the same as well? I was thinking cho_solve((A, False), b) and cho_solve((A, True), b) would be the same function (with different inputs, of course).

We could change (A, False) to (A.T, True), but then the inputs still aren't the same. The more I think about it, the more I believe we have to respect the user's flags, in case only one half of A is being stored.

Member:
But the user isn't creating the chol_factor nor the chol solve in these rewrites, so does it ever matter? Unless I'm missing something, your first approach was correct.

Member:
We probably shouldn't even allow chol_factor with lower=True at the graph level, but always do upper and transpose if the user requested it.

It's like the transposed solve: we handle the transpositions symbolically to keep fewer variations floating around.

Member Author:
solve(A, b, lower=False, assume_a='pos') will only ever look at the upper triangle of A to do the computation. So the user might pass in data that is structured in a special way, taking this into account (for example, only storing half of the matrix in memory).

When we rewrite, if we choose to always use c_and_lower = (cholesky(A), True), regardless of what was requested, we are assuming that the A matrix is actually symmetrical. That assumption isn't consistent with what LAPACK actually requires, so it could lead to (silent!) incorrect computation.

I don't see any downside to respecting what the user asked for in the rewrite.
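
As a minimal SciPy sketch of the point above (illustrative, not code from this PR): if the caller only stored the upper triangle of A, the solve is still correct with lower=False, but a rewrite that unconditionally factorizes the lower triangle would silently return a wrong answer.

import numpy as np
from scipy.linalg import cho_solve, cholesky, solve

rng = np.random.default_rng(0)
X = rng.normal(size=(4, 4))
A_full = X @ X.T + 4 * np.eye(4)   # symmetric positive definite
A_upper = np.triu(A_full)          # only the upper triangle is stored
b = rng.normal(size=4)

x_ref = solve(A_full, b, assume_a="pos")

# Correct: with lower=False, LAPACK reads only the upper triangle
x_ok = solve(A_upper, b, assume_a="pos", lower=False)

# Silently wrong: factorizing the (zeroed-out) lower triangle instead
L = cholesky(A_upper, lower=True)
x_bad = cho_solve((L, True), b)

print(np.allclose(x_ok, x_ref))    # True
print(np.allclose(x_bad, x_ref))   # False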

Member (@ricardoV94, Jun 13, 2025):
We can just transpose A in that case. The issue is one of merging fewer scenarios. What happens now if the user has a Solve(A, b1, lower=True) and another Solve(A.T, b2, lower=False)?

Are we merging them here correctly? You were ignoring the transpose info coming from the rewrite that's used by the other solves.

That's what should determine the flag, not the original lower. Or the two together. Here we actually have two lowers; which one is used?

Note that if we never represented one of the forms, our scenario simplifies.
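
For context, a small NumPy/SciPy check (illustrative, not part of the diff) of the canonicalization suggested here: for a genuinely symmetric A, the upper and lower Cholesky factors are transposes of one another, so representing only one form and handling transposition symbolically keeps equivalent decompositions mergeable.

import numpy as np
from scipy.linalg import cholesky

rng = np.random.default_rng(1)
X = rng.normal(size=(4, 4))
A = X @ X.T + 4 * np.eye(4)   # genuinely symmetric positive definite

U = cholesky(A, lower=False)
L = cholesky(A, lower=True)
print(np.allclose(U, L.T))    # True: only one factor form needs to exist in the graph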

# We can ignore the transposed argument here because A is symmetric by assumption
return cho_solve(
(A_decomp, lower),
b,
b_ndim=b_ndim,
check_finite=check_finite,
)
else:
raise NotImplementedError


def _split_lu_solve_steps(
def _split_decomp_and_solve_steps(
fgraph, node, *, eager: bool, allowed_assume_a: Container[str]
):
if not isinstance(node.op.core_op, Solve):
@@ -133,13 +146,21 @@ def find_solve_clients(var, assume_a):
if client.op.core_op.check_finite:
check_finite_decomp = True
break
A_decomp = decompose_A(A, assume_a=assume_a, check_finite=check_finite_decomp)

lower = node.op.core_op.lower
A_decomp = decompose_A(
A, assume_a=assume_a, check_finite=check_finite_decomp, lower=lower
)

replacements = {}
for client, transposed in A_solve_clients_and_transpose:
_, b = client.inputs
new_x = solve_lu_decomposed_system(
A_decomp, b, transposed=transposed, core_solve_op=client.op.core_op
new_x = solve_decomposed_system(
A_decomp,
b,
transposed=transposed,
lower=lower,
core_solve_op=client.op.core_op,
)
[old_x] = client.outputs
new_x = atleast_Nd(new_x, n=old_x.type.ndim).astype(old_x.type.dtype)
@@ -149,7 +170,7 @@ def find_solve_clients(var, assume_a):
return replacements


def _scan_split_non_sequence_lu_decomposition_solve(
def _scan_split_non_sequence_decomposition_and_solve(
fgraph, node, *, allowed_assume_a: Container[str]
):
"""If the A of a Solve within a Scan is a function of non-sequences, split the LU decomposition step.
@@ -179,7 +200,7 @@ def _scan_split_non_sequence_lu_decomposition_solve(
non_sequences = {equiv[non_seq] for non_seq in non_sequences}
inner_node = equiv[inner_node] # type: ignore

replace_dict = _split_lu_solve_steps(
replace_dict = _split_decomp_and_solve_steps(
new_scan_fgraph,
inner_node,
eager=True,
@@ -207,22 +228,22 @@ def _scan_split_non_sequence_lu_decomposition_solve(

@register_specialize
@node_rewriter([Blockwise])
def reuse_lu_decomposition_multiple_solves(fgraph, node):
return _split_lu_solve_steps(
fgraph, node, eager=False, allowed_assume_a={"gen", "tridiagonal"}
def reuse_decomposition_multiple_solves(fgraph, node):
return _split_decomp_and_solve_steps(
fgraph, node, eager=False, allowed_assume_a={"gen", "tridiagonal", "pos"}
)


@node_rewriter([Scan])
def scan_split_non_sequence_lu_decomposition_solve(fgraph, node):
return _scan_split_non_sequence_lu_decomposition_solve(
fgraph, node, allowed_assume_a={"gen", "tridiagonal"}
def scan_split_non_sequence_decomposition_and_solve(fgraph, node):
return _scan_split_non_sequence_decomposition_and_solve(
fgraph, node, allowed_assume_a={"gen", "tridiagonal", "pos"}
)


scan_seqopt1.register(
"scan_split_non_sequence_lu_decomposition_solve",
in2out(scan_split_non_sequence_lu_decomposition_solve, ignore_newtrees=True),
scan_split_non_sequence_decomposition_and_solve.__name__,
in2out(scan_split_non_sequence_decomposition_and_solve, ignore_newtrees=True),
"fast_run",
"scan",
"scan_pushout",
@@ -231,28 +252,30 @@ def scan_split_non_sequence_lu_decomposition_solve(fgraph, node):


@node_rewriter([Blockwise])
def reuse_lu_decomposition_multiple_solves_jax(fgraph, node):
return _split_lu_solve_steps(fgraph, node, eager=False, allowed_assume_a={"gen"})
def reuse_decomposition_multiple_solves_jax(fgraph, node):
return _split_decomp_and_solve_steps(
fgraph, node, eager=False, allowed_assume_a={"gen", "pos"}
)


optdb["specialize"].register(
reuse_lu_decomposition_multiple_solves_jax.__name__,
in2out(reuse_lu_decomposition_multiple_solves_jax, ignore_newtrees=True),
reuse_decomposition_multiple_solves_jax.__name__,
in2out(reuse_decomposition_multiple_solves_jax, ignore_newtrees=True),
"jax",
use_db_name_as_tag=False,
)


@node_rewriter([Scan])
def scan_split_non_sequence_lu_decomposition_solve_jax(fgraph, node):
return _scan_split_non_sequence_lu_decomposition_solve(
fgraph, node, allowed_assume_a={"gen"}
def scan_split_non_sequence_decomposition_and_solve_jax(fgraph, node):
return _scan_split_non_sequence_decomposition_and_solve(
fgraph, node, allowed_assume_a={"gen", "pos"}
)


scan_seqopt1.register(
scan_split_non_sequence_lu_decomposition_solve_jax.__name__,
in2out(scan_split_non_sequence_lu_decomposition_solve_jax, ignore_newtrees=True),
scan_split_non_sequence_decomposition_and_solve_jax.__name__,
in2out(scan_split_non_sequence_decomposition_and_solve_jax, ignore_newtrees=True),
"jax",
use_db_name_as_tag=False,
position=2,