Rewrite scalar solve to division #1453
base: main
Conversation
Pull Request Overview

This PR optimizes scalar linear solves by rewriting them to simple division (and square root for Cholesky) when the matrix is (1, 1), avoiding expensive LAPACK calls.

- Add a rewriter (`scalar_solve_to_divison`) to turn `solve`/`solve_triangular`/`cho_solve` on 1×1 inputs into division.
- Extend `rewrite_cholesky_diag_to_sqrt_diag` to map a (1, 1) Cholesky to `sqrt`.
- Add tests in `test_linalg.py` to verify these rewrites discard the original `Solve*` ops.
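As a sanity check of the identity these rewrites rely on (a numerical illustration with NumPy/SciPy, not the PyTensor implementation itself): a 1×1 solve is just division, and a 1×1 Cholesky solve is division by the squared Cholesky factor.

```python
import numpy as np
from scipy.linalg import cho_factor, cho_solve, solve, solve_triangular

a = np.array([[4.0]])   # a (1, 1) "matrix" -- really just a scalar
b = np.array([[10.0]])

# solve(a, b) on a 1x1 system is just b / a
assert np.allclose(solve(a, b), b / a)

# solve_triangular degenerates to the same thing in the 1x1 case
assert np.allclose(solve_triangular(a, b), b / a)

# cho_solve(cho_factor(a), b) is b / a, i.e. b / chol(a)**2
c, low = cho_factor(a)
assert np.allclose(cho_solve((c, low), b), b / c**2)
```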
Reviewed Changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| tests/tensor/rewriting/test_linalg.py | Added parametrized tests for scalar solve rewrites |
| pytensor/tensor/rewriting/linalg.py | Implemented two rewriters for scalar solve cases |
Comments suppressed due to low confidence (3)

pytensor/tensor/rewriting/linalg.py:1035
- [nitpick] The function name `scalar_solve_to_divison` contains a typo. Rename it to `scalar_solve_to_division` for consistency.

  `def scalar_solve_to_divison(fgraph, node):`

pytensor/tensor/rewriting/linalg.py:1034
- Using `[Blockwise]` here means the rewrite won't match `SolveBase` nodes. Specify the correct op class (e.g., `SolveBase`) so this rule can fire.

  `@node_rewriter([Blockwise])`

tests/tensor/rewriting/test_linalg.py:32
- [nitpick] Indentation for this import is inconsistent with the surrounding lines; align it to match the existing style.

  `CholeskySolve,`
```python
@@ -908,6 +910,11 @@ def rewrite_cholesky_diag_to_sqrt_diag(fgraph, node):
        return None

    [input] = node.inputs

    # Check if input is a (1, 1) matrix
    if all(input.type.broadcastable[:-2]):
```
The check for a (1, 1) matrix is incorrect: it tests all dimensions except the last two. It should verify the last two dims are broadcastable, e.g. `all(input.type.broadcastable[-2:])`.

```diff
-    if all(input.type.broadcastable[:-2]):
+    if all(input.type.broadcastable[-2:]):
```
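The difference between the two slices is easy to see on a plain tuple. Using a hypothetical broadcastable pattern for a batch of five (1, 1) matrices (shape `(5, 1, 1)`), the buggy slice looks at the batch dims instead of the core matrix dims:

```python
# Broadcastable pattern for a stack of five (1, 1) matrices: shape (5, 1, 1)
broadcastable = (False, True, True)

# Buggy check: everything EXCEPT the last two (core) dims, i.e. the batch
# dims -- here (False,), so the rewrite would wrongly refuse to fire.
assert all(broadcastable[:-2]) is False

# Fixed check: the last two dims are the core matrix dims; both are
# length-1, so this correctly identifies a (1, 1) core matrix.
assert all(broadcastable[-2:]) is True

# For an unbatched matrix, broadcastable[:-2] is the empty tuple and
# all(()) is vacuously True -- which is why the buggy check can still
# appear to work in simple, unbatched tests.
unbatched = (True, True)
assert all(unbatched[:-2]) is True
```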
Copilot seems right
Re sqr(sqrt): we should perhaps canonicalize to power. However, we may need to introduce some nan switches if x < 0, depending on the values of a (and/or refuse to rewrite if either is unknown)

Oh I didn't even think about the domain issue. That's a good reason we don't do that rewrite.

We handle log(exp) and exp(log) fine, it's just that depending on the order the rewrite has a nan switch or not
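The domain issue discussed above can be seen directly with NumPy (an illustrative sketch, not a PyTensor rewrite): `sqrt(x)**2` silently turns negative inputs into nan, so rewriting it to plain `x` changes behavior unless a nan switch is inserted.

```python
import numpy as np

x = np.array([-2.0, 0.0, 3.0])

# sqrt(x)**2 is NOT equivalent to x: sqrt(-2.0) is nan, and squaring
# nan keeps it nan.
with np.errstate(invalid="ignore"):
    roundtrip = np.sqrt(x) ** 2

assert np.isnan(roundtrip[0])             # -2.0 became nan
assert np.allclose(roundtrip[1:], x[1:])  # non-negative values survive

# A semantics-preserving rewrite would need an explicit "nan switch":
guarded = np.where(x < 0, np.nan, x)
assert np.isnan(guarded[0]) and np.allclose(guarded[1:], x[1:])
```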
```python
    # Special handling for different types of solve
    match core_op:
        case SolveTriangular():
```
This is expensive? Creating an Op in every match case?
Definitely. I was just being too infatuated with match.
Match doesn't work with the class (instead of instance) as an instance check?
No. See the match PEP
It sounds like it could be clever and not really instantiate the Op in the match check. Can you test with a small example (need not be PyTensor specific)
Yeah it doesn't instantiate stuff:
```python
class AB:
    def __init__(self):
        print("AB was initialized")

class AA(AB):
    def __init__(self):
        print("AA was initialized")

class BB(AB):
    def __init__(self):
        print("BB was initialized")

aa = AA()
bb = BB()
print("Starting match")
for x in (aa, bb, 1):
    match x:
        case AA():
            print("match aa")
        case BB():
            print("match bb")
        case AB():
            print("match ab")
        case _:
            print("match nothing")

# AA was initialized
# BB was initialized
# Starting match
# match aa
# match bb
# match nothing
```
Not sure how it's implemented, very surprising syntax imo, but seems fine to use like you did
Wonder if we could start using it for `match x.owner: case Apply(op=Elemwise(scalar_op=Add())):` kind of stuff, instead of `x.owner is not None and isinstance(x.owner.op, Elemwise) and isinstance(x.owner.op.scalar_op, Add)` 🤔
Description

Small optimization that rewrites `solve(a, b) -> b / a` when the core shape of `a` is `(1, 1)`. This avoids calling a LAPACK routine in a case where it's simply not necessary.

This came up in the gradient of `minimize`, for cases where the function being minimized has only one input. The L_op in that case requires a bunch of linear algebra, but it can all be rewritten away when we're just dealing with scalars.

I also tweaked `rewrite_cholesky_diag_to_sqrt_diag` to apply to the (1, 1) case -- this allows `cho_solve` to be rewritten to just `b / a ** 2`, without any linalg calls. The only bummer is that the `sqr` and `sqrt` don't cancel. I get this graph:

Related Issue
Checklist
Type of change
📚 Documentation preview 📚: https://pytensor--1453.org.readthedocs.build/en/1453/