Skip to content

Add Ops for LU Factorization #1218

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Apr 19, 2025

Conversation

jessegrabowski
Copy link
Member

@jessegrabowski jessegrabowski commented Feb 18, 2025

Description

This PR will add the following Ops:

  • lu
  • lu_factor
  • lu_solve

As well as dispatches for numba/jax (and maybe torch, though help is welcome there).

The reason for wanting these is that it will make the gradients of solve faster. I think this is a major reason why jax has faster gradients than us (at least when solve is implicated). They route everything to lu_solve(lu_factor(A), b), and reuse lu_factor(A) in the backward pass.

Related Issue

  • Closes #
  • Related to #

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

📚 Documentation preview 📚: https://pytensor--1218.org.readthedocs.build/en/1218/

@jessegrabowski jessegrabowski force-pushed the LU-factorization branch 3 times, most recently from b1f8c9d to 2a5b361 Compare February 20, 2025 06:05
@ricardoV94
Copy link
Member

ricardoV94 commented Feb 28, 2025

Any benchmarks on solve written with these Ops?

@jessegrabowski
Copy link
Member Author

Working on that next, just ironing out some bugs in the lu_solve Op (which is what all this is building towards)

@jessegrabowski jessegrabowski force-pushed the LU-factorization branch 4 times, most recently from a03bcce to b7aa9f8 Compare March 21, 2025 11:22
@jessegrabowski jessegrabowski marked this pull request as ready for review March 22, 2025 04:37
@jessegrabowski jessegrabowski force-pushed the LU-factorization branch 5 times, most recently from a3ea772 to d648a15 Compare April 19, 2025 00:14
@jessegrabowski jessegrabowski force-pushed the LU-factorization branch 3 times, most recently from 6a245ab to 8e8311e Compare April 19, 2025 01:54
Copy link

codecov bot commented Apr 19, 2025

Codecov Report

Attention: Patch coverage is 74.80720% with 98 lines in your changes missing coverage. Please review.

Project coverage is 82.07%. Comparing base (676296c) to head (1e42069).
Report is 47 commits behind head on main.

Files with missing lines Patch % Lines
...ensor/link/numba/dispatch/linalg/solve/lu_solve.py 41.50% 31 Missing ⚠️
...sor/link/numba/dispatch/linalg/decomposition/lu.py 66.66% 20 Missing ⚠️
...k/numba/dispatch/linalg/decomposition/lu_factor.py 56.81% 19 Missing ⚠️
pytensor/link/numba/dispatch/slinalg.py 71.42% 8 Missing and 6 partials ⚠️
pytensor/tensor/slinalg.py 92.20% 6 Missing and 6 partials ⚠️
pytensor/link/jax/dispatch/slinalg.py 92.30% 1 Missing and 1 partial ⚠️
Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main    #1218      +/-   ##
==========================================
+ Coverage   82.05%   82.07%   +0.02%     
==========================================
  Files         203      206       +3     
  Lines       48863    49174     +311     
  Branches     8695     8720      +25     
==========================================
+ Hits        40093    40359     +266     
- Misses       6619     6656      +37     
- Partials     2151     2159       +8     
Files with missing lines Coverage Δ
pytensor/link/numba/dispatch/basic.py 80.38% <ø> (ø)
...tensor/link/numba/dispatch/linalg/solve/general.py 51.72% <100.00%> (+7.17%) ⬆️
pytensor/tensor/blockwise.py 90.40% <100.00%> (+4.75%) ⬆️
pytensor/tensor/elemwise.py 90.01% <ø> (+0.41%) ⬆️
pytensor/link/jax/dispatch/slinalg.py 85.33% <92.30%> (+3.70%) ⬆️
pytensor/tensor/slinalg.py 93.10% <92.20%> (-0.30%) ⬇️
pytensor/link/numba/dispatch/slinalg.py 69.76% <71.42%> (+0.66%) ⬆️
...k/numba/dispatch/linalg/decomposition/lu_factor.py 56.81% <56.81%> (ø)
...sor/link/numba/dispatch/linalg/decomposition/lu.py 66.66% <66.66%> (ø)
...ensor/link/numba/dispatch/linalg/solve/lu_solve.py 41.50% <41.50%> (ø)

... and 2 files with indirect coverage changes

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@jessegrabowski jessegrabowski merged commit e98cbbc into pymc-devs:main Apr 19, 2025
72 of 73 checks passed
@ricardoV94 ricardoV94 added the enhancement New feature or request label May 19, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants