Skip to content

Improve collect_default_updates #6620

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

ricardoV94
Copy link
Member

@ricardoV94 ricardoV94 commented Mar 24, 2023

This is part of some work to make CustomDist more powerful, specially for timeseries distributions.

Immediate changes:

  • It works with nested RNGs
  • It raises error if RNG used in SymbolicRandomVariable is not given an update
  • It raises warning if same RNG is used in multiple nodes

@codecov
Copy link

codecov bot commented Mar 24, 2023

Codecov Report

Merging #6620 (87c4fe7) into main (46f8e2f) will decrease coverage by 0.50%.
The diff coverage is 100.00%.

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main    #6620      +/-   ##
==========================================
- Coverage   91.94%   91.44%   -0.50%     
==========================================
  Files          94       94              
  Lines       15831    15845      +14     
==========================================
- Hits        14556    14490      -66     
- Misses       1275     1355      +80     
Impacted Files Coverage Δ
pymc/distributions/distribution.py 96.30% <100.00%> (-0.31%) ⬇️
pymc/pytensorf.py 88.78% <100.00%> (-3.57%) ⬇️

... and 4 files with indirect coverage changes

@ricardoV94 ricardoV94 force-pushed the improve_collect_default_updates branch from 230be69 to 04baa16 Compare March 24, 2023 07:31
@ricardoV94 ricardoV94 marked this pull request as draft March 24, 2023 07:31
@ricardoV94 ricardoV94 marked this pull request as ready for review March 24, 2023 19:22
Comment on lines 1062 to 1070
for shared_rng in (
inp
for inp in graph_inputs(outputs, blockers=inputs)
# TODO: Test this directly
if (
(not must_be_shared or isinstance(inp, SharedVariable))
and isinstance(inp.type, RandomType)
)
):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Extracting this iterator into a local variable would make it easier to read.

Extracting it into a function could make it easier to test too?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed the TODO, by adding a direct test. I think this test makes the behavior of the function pretty clear, including the hard to read iterator?

https://github.com/pymc-devs/pymc/blob/d8c57f4af5e272a0d924c76caae625d6f1d0fd1f/tests/test_pytensorf.py#L542-L553

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I still find the iterator hard to read, but proceed on your discretion

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am still being stubborn, but added a comment and renamed the iterator variable to be a bit more obvious. Will merge when tests pass again :)

pymc/pymc/pytensorf.py

Lines 1062 to 1073 in 87c4fe7

# Iterate over input RNGs. Only consider shared RNGs if `must_be_shared==True`
for input_rng in (
inp
for inp in graph_inputs(outputs, blockers=inputs)
if (
(not must_be_shared or isinstance(inp, SharedVariable))
and isinstance(inp.type, RandomType)
)
):
# Even if an explicit default update is provided, we call it to
# issue any warnings about invalid random graphs.
default_update = find_default_update(clients, input_rng)

Thanks!

@ricardoV94 ricardoV94 force-pushed the improve_collect_default_updates branch 2 times, most recently from d8c57f4 to b44088c Compare March 30, 2023 09:27
* It works with nested RNGs
* It raises error if RNG used in SymbolicRandomVariable is not given an update
* It raises warning if same RNG is used in multiple nodes
@ricardoV94 ricardoV94 force-pushed the improve_collect_default_updates branch from b44088c to 87c4fe7 Compare March 30, 2023 09:29
@ricardoV94 ricardoV94 merged commit 3f2a1da into pymc-devs:main Mar 30, 2023
@ricardoV94 ricardoV94 deleted the improve_collect_default_updates branch April 7, 2023 05:52
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants