-
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Improve collect_default_updates #6620
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Improve collect_default_updates #6620
Conversation
Codecov Report
Additional details and impacted files@@ Coverage Diff @@
## main #6620 +/- ##
==========================================
- Coverage 91.94% 91.44% -0.50%
==========================================
Files 94 94
Lines 15831 15845 +14
==========================================
- Hits 14556 14490 -66
- Misses 1275 1355 +80
|
230be69
to
04baa16
Compare
pymc/pytensorf.py
Outdated
for shared_rng in ( | ||
inp | ||
for inp in graph_inputs(outputs, blockers=inputs) | ||
# TODO: Test this directly | ||
if ( | ||
(not must_be_shared or isinstance(inp, SharedVariable)) | ||
and isinstance(inp.type, RandomType) | ||
) | ||
): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Extracting this iterator into a local variable would make it easier to read.
Extracting it into a function could make it easier to test too?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Removed the TODO, by adding a direct test. I think this test makes the behavior of the function pretty clear, including the hard to read iterator?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I still find the iterator hard to read, but proceed on your discretion
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am still being stubborn, but added a comment and renamed the iterator variable to be a bit more obvious. Will merge when tests pass again :)
Lines 1062 to 1073 in 87c4fe7
# Iterate over input RNGs. Only consider shared RNGs if `must_be_shared==True` | |
for input_rng in ( | |
inp | |
for inp in graph_inputs(outputs, blockers=inputs) | |
if ( | |
(not must_be_shared or isinstance(inp, SharedVariable)) | |
and isinstance(inp.type, RandomType) | |
) | |
): | |
# Even if an explicit default update is provided, we call it to | |
# issue any warnings about invalid random graphs. | |
default_update = find_default_update(clients, input_rng) |
Thanks!
d8c57f4
to
b44088c
Compare
* It works with nested RNGs * It raises error if RNG used in SymbolicRandomVariable is not given an update * It raises warning if same RNG is used in multiple nodes
b44088c
to
87c4fe7
Compare
This is part of some work to make CustomDist more powerful, specially for timeseries distributions.
Immediate changes: