Skip to content

Allow jitter boolean to be set through nuts_sampler_kwargs #7083

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Conversation

VMBoehm
Copy link
Contributor

@VMBoehm VMBoehm commented Jan 5, 2024

Currently, initial values to the numpyro and jax nuts samplers are automatically jittered.

This PR exposes the jitter option and allows the boolean to be set through nuts_sampler_kwargs.

Description

Related Issue

There is no issue related to this change.

  • Closes #
  • Related to #

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

📚 Documentation preview 📚: https://pymc--7083.org.readthedocs.build/en/7083/

Copy link

welcome bot commented Jan 5, 2024

Thank You Banner
💖 Thanks for opening this pull request! 💖 The PyMC community really appreciates your time and effort to contribute to the project. Please make sure you have read our Contributing Guidelines and filled in our pull request template to the best of your ability.

@VMBoehm VMBoehm marked this pull request as draft January 5, 2024 23:11
@VMBoehm VMBoehm marked this pull request as ready for review January 6, 2024 00:54
Copy link
Member

@ricardoV94 ricardoV94 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! Looks good. I just found an issue with the type hint changes.

@@ -256,7 +256,7 @@ def _get_batched_jittered_initial_points(
chains: int,
initvals: Optional[Union[StartDict, Sequence[Optional[StartDict]]]],
random_seed: RandomSeed,
jitter: bool = True,
jitter: Optional[bool] = True,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The Optional hint doesn't make sense. That's for when a variable can have a value of None, but here bool with True/False should suffice

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @ricardoV94! Updated

@@ -564,6 +568,7 @@ def sample_numpyro_nuts(
target_accept: float = 0.8,
random_seed: Optional[RandomState] = None,
initvals: Optional[Union[StartDict, Sequence[Optional[StartDict]]]] = None,
jitter: Optional[bool] = True,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated

@ricardoV94 ricardoV94 changed the title allow jitter boolean to be set through nuts_sampler_kwargs Allow jitter boolean to be set through nuts_sampler_kwargs Jan 11, 2024
Copy link

codecov bot commented Jan 11, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Comparison is base (7bb2ccd) 92.21% compared to head (c4ea24a) 92.23%.
Report is 3 commits behind head on main.

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main    #7083      +/-   ##
==========================================
+ Coverage   92.21%   92.23%   +0.02%     
==========================================
  Files         101      101              
  Lines       16912    16884      -28     
==========================================
- Hits        15595    15573      -22     
+ Misses       1317     1311       -6     
Files Coverage Δ
pymc/sampling/jax.py 93.07% <ø> (ø)

... and 53 files with indirect coverage changes

@VMBoehm VMBoehm requested a review from ricardoV94 January 11, 2024 19:33
@ricardoV94 ricardoV94 merged commit 2da4050 into pymc-devs:main Jan 23, 2024
Copy link

welcome bot commented Jan 23, 2024

Congratulations Banner]
Congrats on merging your first pull request! 🎉 We here at PyMC are proud of you! 💖 Thank you so much for your contribution 🎁

@ricardoV94 ricardoV94 changed the title Allow jitter boolean to be set through nuts_sampler_kwargs Allow jitter boolean to be set through nuts_sampler_kwargs Feb 14, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants