Skip to content

Rewrite scalar solve to division #1453

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions pytensor/tensor/rewriting/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,10 @@
from pytensor.tensor.slinalg import (
BlockDiagonal,
Cholesky,
CholeskySolve,
Solve,
SolveBase,
SolveTriangular,
_bilinear_solve_discrete_lyapunov,
block_diag,
cholesky,
Expand Down Expand Up @@ -908,6 +910,11 @@ def rewrite_cholesky_diag_to_sqrt_diag(fgraph, node):
return None

[input] = node.inputs

# Check if input is a (1, 1) matrix
if all(input.type.broadcastable[:-2]):
Copy link
Preview

Copilot AI Jun 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The check for a (1,1) matrix is incorrect: it tests all dimensions except the last two. It should verify the last two dims are broadcastable, e.g. all(input.type.broadcastable[-2:]).

Suggested change
if all(input.type.broadcastable[:-2]):
if all(input.type.broadcastable[-2:]):

Copilot uses AI. Check for mistakes.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copilot seems right

return [pt.sqrt(input)]

# Check for use of pt.diag first
if (
input.owner
Expand Down Expand Up @@ -1020,3 +1027,42 @@ def slogdet_specialization(fgraph, node):
k: slogdet_specialization_map[v] for k, v in dummy_replacements.items()
}
return replacements


@register_stabilize
@register_canonicalize
@node_rewriter([Blockwise])
def scalar_solve_to_divison(fgraph, node):
"""
Replace solve(a, b) with b / a if a is a (1, 1) matrix
"""

core_op = node.op.core_op
if not isinstance(core_op, SolveBase):
return None

a, b = node.inputs
old_out = node.outputs[0]
if not all(a.broadcastable[-2:]):
return None

# Special handling for different types of solve
match core_op:
case SolveTriangular():
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is expensive? Creating an Op in every match case?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Definitely. I was just being too infatuated with match.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Match doesn't work with the class (instead of instance) as an instance check?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No. See the match PEP

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It sounds like it could be clever and not really instantiate the Op in the match check. Can you test with a small example (need not be PyTensor specific)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah it doesn't instantiate stuff:

class AB:
    def __init__(self):
        print("AB was initialized")

class AA(AB):
    def __init__(self):
        print("AA was initialized")

class BB(AB):
    def __init__(self):
        print("BB was initialized")

aa = AA()
bb = BB()

print("Starting match")
for x in (aa, bb, 1):
    match x:
	case AA():
            print("match aa")
        case BB():
            print("match bb")
        case AB():
            print("match ab")
        case _:
            print("match nothing")

# AA was initialized
# BB was initialized
# Starting match
# match aa
# match bb
# match nothing

Not sure how it's implemented, very surprising syntax imo, but seems fine to use like you did

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wonder if we could start using it for match x.owner: Apply(op=Elemwise(scalar_op=Add)) kind of stuff instead of x.owner is not None and isinstance(x.owner.op, Elemwise) and isinstance(x.owner.op, Add)) 🤔

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

# Corner case: if user asked for a triangular solve with a unit diagonal, a is taken to be 1
new_out = b / a if not core_op.unit_diagonal else b
case CholeskySolve():
new_out = b / a**2
case Solve():
new_out = b / a
case _:
raise NotImplementedError(
f"Unsupported core_op type: {type(core_op)} in scalar_solve_to_divison"
)

if core_op.b_ndim == 1:
new_out = new_out.squeeze(-1)

copy_stack_trace(old_out, new_out)

return [new_out]
35 changes: 35 additions & 0 deletions tests/tensor/rewriting/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from pytensor.tensor.slinalg import (
BlockDiagonal,
Cholesky,
CholeskySolve,
Solve,
SolveBase,
SolveTriangular,
Expand Down Expand Up @@ -993,3 +994,37 @@ def test_slogdet_specialization():
f = function([x], [exp_det_x, sign_det_x], mode="FAST_RUN")
nodes = f.maker.fgraph.apply_nodes
assert not any(isinstance(node.op, SLogDet) for node in nodes)


@pytest.mark.parametrize(
    "Op, fn",
    [
        (Solve, pt.linalg.solve),
        (SolveTriangular, pt.linalg.solve_triangular),
        (CholeskySolve, pt.linalg.cho_solve),
    ],
)
def test_scalar_solve_to_division_rewrite(Op, fn):
    """The scalar-solve rewrite removes the solve Op and stays numerically correct."""
    rng = np.random.default_rng(sum(map(ord, "scalar_solve_to_division_rewrite")))

    a = pt.dmatrix("a", shape=(1, 1))
    b = pt.dvector("b")

    if Op is CholeskySolve:
        # cho_solve expects a tuple (c, lower) as the first input
        c = fn((pt.linalg.cholesky(a), True), b, b_ndim=1)
    else:
        c = fn(a, b, b_ndim=1)

    f = function([a, b], c, mode="FAST_RUN")
    nodes = f.maker.fgraph.apply_nodes

    # The rewrite should have removed the solve Op entirely
    assert not any(isinstance(node.op, Op) for node in nodes)

    # The graph inputs are declared float64 (dmatrix/dvector), so generate
    # float64 values and compare at float64 precision regardless of
    # config.floatX — casting to floatX here would fight the declared dtype.
    a_val = rng.normal(size=(1, 1))
    b_val = rng.normal(size=(1,))

    c_val = np.linalg.solve(a_val, b_val)
    np.testing.assert_allclose(f(a_val, b_val), c_val, rtol=1e-7)
Loading