Rewrite scalar solve to division #1453
base: main
Changes from all commits
@@ -47,8 +47,10 @@
from pytensor.tensor.slinalg import (
    BlockDiagonal,
    Cholesky,
    CholeskySolve,
    Solve,
    SolveBase,
    SolveTriangular,
    _bilinear_solve_discrete_lyapunov,
    block_diag,
    cholesky,
@@ -908,6 +910,11 @@ def rewrite_cholesky_diag_to_sqrt_diag(fgraph, node):
        return None

    [input] = node.inputs

    # Check if input is a (1, 1) matrix
    if all(input.type.broadcastable[:-2]):
        return [pt.sqrt(input)]

    # Check for use of pt.diag first
    if (
        input.owner
@@ -1020,3 +1027,42 @@ def slogdet_specialization(fgraph, node):
        k: slogdet_specialization_map[v] for k, v in dummy_replacements.items()
    }
    return replacements


@register_stabilize
@register_canonicalize
@node_rewriter([Blockwise])
def scalar_solve_to_divison(fgraph, node):
    """
    Replace solve(a, b) with b / a if a is a (1, 1) matrix
    """

    core_op = node.op.core_op
    if not isinstance(core_op, SolveBase):
        return None

    a, b = node.inputs
    old_out = node.outputs[0]
    if not all(a.broadcastable[-2:]):
        return None

    # Special handling for different types of solve
    match core_op:
        case SolveTriangular():
Review thread on this line:

This is expensive? Creating an Op in every match case?

Definitely. I was just being too infatuated with match.

Match doesn't work with the class (instead of instance) as an instance check?

No. See the match PEP.

It sounds like it could be clever and not really instantiate the Op in the match check. Can you test with a small example (need not be PyTensor specific)?

Yeah, it doesn't instantiate stuff:

class AB:
    def __init__(self):
        print("AB was initialized")

class AA(AB):
    def __init__(self):
        print("AA was initialized")

class BB(AB):
    def __init__(self):
        print("BB was initialized")

aa = AA()
bb = BB()

print("Starting match")
for x in (aa, bb, 1):
    match x:
        case AA():
            print("match aa")
        case BB():
            print("match bb")
        case AB():
            print("match ab")
        case _:
            print("match nothing")

# AA was initialized
# BB was initialized
# Starting match
# match aa
# match bb
# match nothing

Not sure how it's implemented, very surprising syntax imo, but seems fine to use like you did.

Wonder if we could start using it for
The diff continues:

            # Corner case: if user asked for a triangular solve with a unit diagonal, a is taken to be 1
            new_out = b / a if not core_op.unit_diagonal else b
        case CholeskySolve():
            new_out = b / a**2
        case Solve():
            new_out = b / a
        case _:
            raise NotImplementedError(
                f"Unsupported core_op type: {type(core_op)} in scalar_solve_to_divison"
            )

    if core_op.b_ndim == 1:
        new_out = new_out.squeeze(-1)

    copy_stack_trace(old_out, new_out)

    return [new_out]

(jessegrabowski marked a review conversation on the CholeskySolve case as resolved.)
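For reference, a small SciPy check (not part of the diff; shapes and values are made up) of the three identities behind the match cases: a general solve on a (1, 1) system is division by a, a unit-diagonal triangular solve returns b unchanged, and a Cholesky solve divides by the squared factor.

import numpy as np
from scipy.linalg import cho_solve, solve, solve_triangular

a = np.array([[2.0]])        # (1, 1) system matrix
b = np.array([[6.0, 10.0]])  # (1, k) right-hand side

# Solve: x solves a @ x = b, so x == b / a
np.testing.assert_allclose(solve(a, b), b / a)

# SolveTriangular with unit_diagonal=True: the diagonal is taken to be 1, so x == b
np.testing.assert_allclose(solve_triangular(a, b, unit_diagonal=True), b)

# CholeskySolve: a is the Cholesky factor c of A = c @ c.T, so x == b / c**2
np.testing.assert_allclose(cho_solve((a, True), b), b / a**2)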
Copilot, on the (1, 1) check added in rewrite_cholesky_diag_to_sqrt_diag: The check for a (1, 1) matrix is incorrect: it tests all dimensions except the last two. It should verify that the last two dims are broadcastable, e.g. all(input.type.broadcastable[-2:]).

Copilot seems right.
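A short illustration of the slice mix-up Copilot is pointing at, assuming a recent PyTensor where pt.tensor accepts a name and a static shape (the shapes below are hypothetical):

import pytensor.tensor as pt

x = pt.tensor("x", shape=(1, 1, 5, 5))  # one 5 x 5 matrix with broadcastable batch dims
y = pt.tensor("y", shape=(3, 3, 1, 1))  # a (3, 3) batch of (1, 1) matrices

print(all(x.type.broadcastable[:-2]))   # True  -- the committed check would fire on a 5 x 5 matrix
print(all(x.type.broadcastable[-2:]))   # False -- the suggested check correctly rejects it
print(all(y.type.broadcastable[-2:]))   # True  -- and accepts a genuine (1, 1) core matrix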