Speedup truncated_graph_inputs #394


Merged: 7 commits into main on Jul 26, 2023

Conversation

@ferrine (Member) commented Jul 21, 2023

Motivation for these changes

In some cases compilation runs for an extremely long time and appears to enter an infinite loop. This fix resolves the symptom; the fact that this check can loop like that suggests the real cause is some other bug.

As proposed by @aseyboldt, I implemented the same logic with a small refactoring:

diff --git a/pytensor/graph/basic.py b/pytensor/graph/basic.py
index 0f9655647..ea1dfaed3 100644
--- a/pytensor/graph/basic.py
+++ b/pytensor/graph/basic.py
@@ -1091,6 +1091,8 @@ def truncated_graph_inputs(
     # enforce O(1) check for node in ancestors to include
     ancestors_to_include = blockers.copy()
 
+    seen = {}
+
     while candidates:
         # on any new candidate
         node = candidates.pop()
@@ -1099,6 +1101,12 @@ def truncated_graph_inputs(
         if node in truncated_inputs:
             continue
 
+        seen.setdefault(node, 0)
+        seen[node] += 1
+        if seen[node] > 1:
+            continue
+
         # check if the node is independent, never go above blockers;
         # blockers are independent nodes and ancestors to include
         if node in ancestors_to_include:
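
For intuition, here is a minimal, self-contained sketch (not pytensor code; the function and graph names are made up for illustration) of why tracking already-seen candidates matters: in a worklist traversal over a DAG, a node reachable through many paths is otherwise re-expanded once per path, which can blow up exponentially in deep, branchy graphs.

def reachable_leaves(outputs, parents_of, use_seen=True):
    # Worklist traversal: collect leaf nodes reachable from `outputs`.
    # `parents_of` maps a node to the nodes it depends on.
    candidates = list(outputs)
    seen = set()
    leaves = []
    expansions = 0
    while candidates:
        node = candidates.pop()
        if use_seen:
            if node in seen:
                continue
            seen.add(node)
        expansions += 1
        parents = parents_of.get(node, [])
        if parents:
            candidates.extend(parents)
        else:
            leaves.append(node)
    return leaves, expansions

# A ten-level "ladder" in which each node depends on the previous one twice.
parents_of = {f"a{i}": [f"a{i - 1}", f"a{i - 1}"] for i in range(1, 10)}
parents_of["a0"] = ["leaf", "leaf"]

_, naive = reachable_leaves(["a9"], parents_of, use_seen=False)
_, memoized = reachable_leaves(["a9"], parents_of, use_seen=True)
print(naive, memoized)  # 2047 expansions without the check vs. 11 with it

The actual fix keeps a count per node rather than a plain set, but the effect is the same: each candidate is processed at most once.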

This is similar to the closed PR linked below, which aimed to optimize the same check. Full caching seems more involved to implement, so I think we can go with this fix.

https://github.com/pymc-devs/pytensor/pull/30/files

Implementation details

Skip (continue) a candidate node once it has already been checked for dependency, instead of re-expanding it.

Checklist

Bugfixes

  • Avoid infinite loops in truncated_graph_inputs

@codecov-commenter commented

Codecov Report

Merging #394 (ec878f7) into main (b2b7e28) will decrease coverage by 0.01%.
The diff coverage is 100.00%.

Additional details and impacted files


@@            Coverage Diff             @@
##             main     #394      +/-   ##
==========================================
- Coverage   80.47%   80.47%   -0.01%     
==========================================
  Files         156      156              
  Lines       45516    45520       +4     
  Branches    11149    11150       +1     
==========================================
+ Hits        36629    36631       +2     
- Misses       6685     6686       +1     
- Partials     2202     2203       +1     
Impacted Files             Coverage Δ
pytensor/graph/basic.py    89.13% <100.00%> (-0.22%) ⬇️

@ferrine force-pushed the truncated_graph_inputs_optimize branch from ec878f7 to 81638c3 on July 21, 2023 15:10
@ferrine added the bug label on Jul 21, 2023
@ferrine changed the title from "add seen set" to "add seen set to truncated_graph_inputs" on Jul 21, 2023
@ricardoV94 (Member) commented

> This fix resolves the symptom; the fact that this check can loop like that suggests the real cause is some other bug.

Do you have a reproducible example for the original bug?

Does this PR fix the bug, hide it, or do nothing about it?

@aseyboldt (Member) commented

I'd also feel better if we had a proper reproducing example. This one at least triggers the new seen branch, but it still doesn't lead to an infinite loop:

import pytensor
import pytensor.tensor as pt

x = pt.dmatrix("x")
m = x.shape[0][None, None]

f = x / m
w = x / m - f

pytensor.graph.basic.truncated_graph_inputs([w], [x])

@ricardoV94 (Member) commented

So... the question is: do we want these changes even if they don't fix any bug / the original bug?

Was there actually a bug, or perhaps an invalid cyclic graph or something?

@ferrine (Member, Author) commented Jul 24, 2023

> So... the question is: do we want these changes even if they don't fix any bug / the original bug?
>
> Was there actually a bug, or perhaps an invalid cyclic graph or something?

There is indeed a bug somewhere, yet the hotfix still optimizes the function by dropping repeated search paths.

@ferrine (Member, Author) commented Jul 24, 2023

The proper thing to do is to locate the infinite-loop issue and fix it. If that is possible to do before this function is called, I'll try it on the model.

@ericmjl (Member) commented Jul 24, 2023

@ricardoV94 would you be open to sitting down with @ferrine and @Armavica to look through this problem live? I think it may be helpful to just "see" the thing. 👀

@ricardoV94 (Member) commented Jul 25, 2023

I suspect exponential growth for deep graphs with many branches rather than an actual infinite loop. In the example that Max shared with me, it took 15 minutes to derive the graph inputs for one of the variables alone:

truncated_graph_inputs for grad 0 took 5 s
truncated_graph_inputs for grad 1 took 72 s
truncated_graph_inputs for grad 2 took 214 s
truncated_graph_inputs for grad 3 took 73 s
truncated_graph_inputs for grad 4 took 963 s
truncated_graph_inputs for grad 5 took 140 s
truncated_graph_inputs for grad 6 took 139 s
truncated_graph_inputs for grad 7 took 142 s
truncated_graph_inputs for grad 8 took 606 s
truncated_graph_inputs for grad 9 took 568 s
truncated_graph_inputs for grad 10 took 614 s

~1 hour total
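
If one wanted to probe this scaling directly, a rough, hypothetical benchmark could time truncated_graph_inputs on a synthetic graph in which each level reuses the previous variable several times, so the same ancestors are reachable through many paths. The graph construction below is made up for illustration; whether it reproduces the pre-fix slowdown will depend on the exact graph shape.

import time

import pytensor.tensor as pt
from pytensor.graph.basic import truncated_graph_inputs

x = pt.dvector("x")
y = x
for _ in range(20):
    # every level refers to `y` twice, multiplying the number of paths back to `x`
    y = y * y + y

start = time.perf_counter()
inputs = truncated_graph_inputs([y], [x])
print(inputs, f"{time.perf_counter() - start:.3f} s")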

@ricardoV94 changed the title from "add seen set to truncated_graph_inputs" to "Speedup truncated_graph_inputs" on Jul 25, 2023
@ricardoV94 merged commit e9a7d7c into main on Jul 26, 2023
@ericmjl deleted the truncated_graph_inputs_optimize branch on July 26, 2023 21:57