
Rewrite scalar solve to division #1453


Open · wants to merge 2 commits into main

Conversation

@jessegrabowski (Member) commented Jun 7, 2025

Description

Small optimization that rewrites solve(a,b) -> b / a when the core shape of a is (1, 1). This avoids calling a LAPACK routine in a case where it's simply not necessary.
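The equivalence is easy to check numerically; a quick sketch using plain NumPy (outside PyTensor):

```python
import numpy as np

# For a (1, 1) matrix a, solve(a, b) is just division by the single scalar a[0, 0]
a = np.array([[2.0]])
b = np.array([[3.0, 5.0, 7.0]])  # shape (1, 3), as required by solve

expected = np.linalg.solve(a, b)  # the LAPACK path
rewritten = b / a[0, 0]           # what the rewrite produces

print(np.allclose(expected, rewritten))  # True
```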

This came up in the gradient of minimize, for cases where the function being minimized has only one input. The L_op in that case requires a bunch of linear algebra, but it can all be rewritten away when we're just dealing with scalars.

I also tweaked rewrite_cholesky_diag_to_sqrt_diag to apply to the (1, 1) case -- this allows cho_solve to be rewritten to just b / a ** 2, without any linalg calls. The only bummer is that the sqr and sqrt don't cancel. I get this graph:

import pytensor.tensor as pt
import pytensor
x = pt.tensor('x', shape=(1, 1))
c_and_lower = pt.linalg.cholesky(x), True
b = pt.tensor('b', shape=(None,))

f = pytensor.function([x, b], pt.linalg.cho_solve(c_and_lower, b))

f.dprint()

Composite{(i1 / sqr(sqrt(i0)))} [id A] 2
 ├─ Squeeze{axis=1} [id B] 1
 │  └─ x [id C]
 └─ DimShuffle{order=[x]} [id D] 0
    └─ b [id E]

Inner graphs:

Composite{(i1 / sqr(sqrt(i0)))} [id A]
 ← true_div [id F] 'o0'
    ├─ i1 [id G]
    └─ sqr [id H]
       └─ sqrt [id I]
          └─ i0 [id J]
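For the (1, 1) case the rewritten graph can be checked against a hand-rolled cho_solve (two triangular solves), using NumPy only:

```python
import numpy as np

a = np.array([[4.0]])
b = np.array([[2.0, 6.0]])  # shape (1, 2)

L = np.linalg.cholesky(a)  # (1, 1) lower factor, i.e. sqrt(a)
# cho_solve via two triangular solves: a x = b  <=>  L L^T x = b
expected = np.linalg.solve(L.T, np.linalg.solve(L, b))
rewritten = b / a[0, 0]  # b / sqr(sqrt(a)), with the sqr/sqrt cancelled by hand

print(np.allclose(expected, rewritten))  # True
```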

Related Issue

  • Closes #
  • Related to #

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

📚 Documentation preview 📚: https://pytensor--1453.org.readthedocs.build/en/1453/

Copilot AI left a comment

Pull Request Overview

This PR optimizes scalar linear solves by rewriting them to simple division (and square root for Cholesky) when the matrix is (1,1), avoiding expensive LAPACK calls.

  • Add a rewriter (scalar_solve_to_divison) to turn solve/solve_triangular/cho_solve on 1×1 matrices into division.
  • Extend rewrite_cholesky_diag_to_sqrt_diag to map a (1,1) Cholesky to sqrt.
  • Add tests in test_linalg.py to verify these rewrites discard the original Solve* ops.

Reviewed Changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.

File Description
tests/tensor/rewriting/test_linalg.py Added parametrized tests for scalar solve rewrites
pytensor/tensor/rewriting/linalg.py Implemented two rewriters for scalar solve cases
Comments suppressed due to low confidence (3)

pytensor/tensor/rewriting/linalg.py:1035

  • [nitpick] The function name scalar_solve_to_divison contains a typo. Rename it to scalar_solve_to_division for consistency.
def scalar_solve_to_divison(fgraph, node):

pytensor/tensor/rewriting/linalg.py:1034

  • Using [Blockwise] here means the rewrite won’t match SolveBase nodes. Specify the correct op class (e.g., SolveBase) so this rule can fire.
@node_rewriter([Blockwise])

tests/tensor/rewriting/test_linalg.py:32

  • [nitpick] Indentation for this import is inconsistent with the surrounding lines; align it to match the existing style.
    CholeskySolve,

@@ -908,6 +910,11 @@ def rewrite_cholesky_diag_to_sqrt_diag(fgraph, node):
return None

[input] = node.inputs

# Check if input is a (1, 1) matrix
if all(input.type.broadcastable[:-2]):
Copilot AI commented Jun 7, 2025

The check for a (1,1) matrix is incorrect: it tests all dimensions except the last two. It should verify the last two dims are broadcastable, e.g. all(input.type.broadcastable[-2:]).

Suggested change
if all(input.type.broadcastable[:-2]):
if all(input.type.broadcastable[-2:]):
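The difference is easy to see on the static broadcastable patterns themselves; a plain-Python sketch where the tuples mimic `input.type.broadcastable`:

```python
# Broadcastable pattern for shape (None, 1, 1): a batch of 1x1 matrices
batched_1x1 = (False, True, True)
# Broadcastable pattern for shape (1, None, None): batch dim of 1, general core matrix
batch_of_one = (True, False, False)

# Original check: inspects everything EXCEPT the last two (core) dims
print(all(batched_1x1[:-2]))   # False -- misses the 1x1 core it should catch
print(all(batch_of_one[:-2]))  # True  -- fires on a non-1x1 core

# Suggested check: inspects the last two (core) dims
print(all(batched_1x1[-2:]))   # True
print(all(batch_of_one[-2:]))  # False
```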


Member

Copilot seems right

@jessegrabowski jessegrabowski added enhancement New feature or request graph rewriting linalg Linear algebra labels Jun 7, 2025
@ricardoV94 (Member) commented Jun 7, 2025

Re sqr(sqrt): we should perhaps canonicalize to power, and add the pow(pow(x, a1), a2) -> pow(x, a1*a2) rewrite.

However, we may need to introduce some nan switches if x < 0, depending on the values of the exponents (and/or refuse to rewrite if either is unknown).
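The domain issue is concrete: collapsing pow(pow(x, 1/2), 2) to x changes the result for negative x. A NumPy sketch:

```python
import numpy as np

x = np.array([4.0, -4.0])
with np.errstate(invalid="ignore"):  # suppress the sqrt-of-negative warning
    composed = np.sqrt(x) ** 2       # sqr(sqrt(x))

print(composed)                      # [ 4. nan] -- not [4., -4.], so a naive collapse to x is wrong
print(np.where(x < 0, np.nan, x))    # the nan switch that would make the rewrite safe
```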

@jessegrabowski (Member, Author)

Oh I didn't even think about the domain issue. That's a good reason we don't do that rewrite.

@ricardoV94 (Member) commented Jun 8, 2025

Oh I didn't even think about the domain issue. That's a good reason we don't do that rewrite.

We handle log(exp) and exp(log) fine; it's just that, depending on the order, the rewrite needs a nan switch or not.

@jessegrabowski (Member, Author)

Common complex L


# Special handling for different types of solve
match core_op:
case SolveTriangular():
Member

This is expensive? Creating an Op in every match case?

Member Author

Definitely. I was just being too infatuated with match.

Member

Match doesn't work with the class (instead of instance) as an instance check?

Member Author
The reason will be displayed to describe this comment to others. Learn more.

No. See the match PEP

Member
The reason will be displayed to describe this comment to others. Learn more.

It sounds like it could be clever and not really instantiate the Op in the match check. Can you test with a small example (need not be PyTensor specific)

Member
The reason will be displayed to describe this comment to others. Learn more.

Yeah it doesn't instantiate stuff:

class AB:
    def __init__(self):
        print("AB was initialized")

class AA(AB):
    def __init__(self):
        print("AA was initialized")

class BB(AB):
    def __init__(self):
        print("BB was initialized")

aa = AA()
bb = BB()

print("Starting match")
for x in (aa, bb, 1):
    match x:
        case AA():
            print("match aa")
        case BB():
            print("match bb")
        case AB():
            print("match ab")
        case _:
            print("match nothing")

# AA was initialized
# BB was initialized
# Starting match
# match aa
# match bb
# match nothing

Not sure how it's implemented, very surprising syntax imo, but seems fine to use like you did
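Per PEP 634, a class pattern like `case AA():` performs an isinstance check (plus checks for any keyword sub-patterns) and never calls `__init__`. The match above is equivalent to this isinstance chain:

```python
class AB: ...
class AA(AB): ...
class BB(AB): ...

def classify(x):
    # isinstance-chain equivalent of the match statement above;
    # note the subclass case AA must be tested before the base class AB
    if isinstance(x, AA):
        return "match aa"
    if isinstance(x, BB):
        return "match bb"
    if isinstance(x, AB):
        return "match ab"
    return "match nothing"

print(classify(AA()), classify(BB()), classify(1))
# match aa match bb match nothing
```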

Member

Wonder if we could start using it for `match x.owner: case Apply(op=Elemwise(scalar_op=Add())):` kind of stuff instead of `x.owner is not None and isinstance(x.owner.op, Elemwise) and isinstance(x.owner.op.scalar_op, Add)` 🤔

