-
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Allow jitter boolean to be set through nuts_sampler_kwargs
#7083
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Allow jitter boolean to be set through nuts_sampler_kwargs
#7083
Conversation
…uts_sampler_kwargs
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks! Looks good. I just found an issue with the type hint changes.
pymc/sampling/jax.py
Outdated
@@ -256,7 +256,7 @@ def _get_batched_jittered_initial_points( | |||
chains: int, | |||
initvals: Optional[Union[StartDict, Sequence[Optional[StartDict]]]], | |||
random_seed: RandomSeed, | |||
jitter: bool = True, | |||
jitter: Optional[bool] = True, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The Optional
hint doesn't make sense. That's for when a variable can have a value of None, but here bool with True/False should suffice
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @ricardoV94! Updated
pymc/sampling/jax.py
Outdated
@@ -564,6 +568,7 @@ def sample_numpyro_nuts( | |||
target_accept: float = 0.8, | |||
random_seed: Optional[RandomState] = None, | |||
initvals: Optional[Union[StartDict, Sequence[Optional[StartDict]]]] = None, | |||
jitter: Optional[bool] = True, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Updated
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## main #7083 +/- ##
==========================================
+ Coverage 92.21% 92.23% +0.02%
==========================================
Files 101 101
Lines 16912 16884 -28
==========================================
- Hits 15595 15573 -22
+ Misses 1317 1311 -6
|
nuts_sampler_kwargs
Currently, initial values to the numpyro and jax nuts samplers are automatically jittered.
This PR exposes the jitter option and allows the boolean to be set through nuts_sampler_kwargs.
Description
Related Issue
There is no issue related to this change.
Checklist
Type of change
📚 Documentation preview 📚: https://pymc--7083.org.readthedocs.build/en/7083/