Skip to content

Add nuts_sampler_kwargs to pm.sample #6581

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Mar 12, 2023
Merged

Conversation

fonnesbeck
Copy link
Member

@fonnesbeck fonnesbeck commented Mar 8, 2023

What is this PR about?

In the current release, optional parameters for sample_numpyro_nuts were not making it through when passed from pm.sample:

ValueError: Unused step method arguments: {'postprocessing_backend', 'chain_method', 'nuts_kwargs'}

This PR reinstates the nuts_kwargs argument as a means for getting NUTS arguments to the sampler, and adds a sampler_kwargs argument for passing sampler-specific arguments, such as "chain_method" or "postprocessing_backend".

with model:
    trace = pm.sample(draws=500, tune=1000, chains=2, nuts_sampler='numpyro', target_accept=0.7, nuts_kwargs={"max_tree_depth": 12}, sampler_kwargs="chain_method":'vectorized'}, idata_kwargs={"log_likelihood": False})

Checklist

Major / Breaking Changes

New features

  • Add nuts_sampler_kwargs optional argument to pm.sample() to forward kwargs to an external nuts implementation.

Bugfixes

  • ...

Documentation

  • ...

Maintenance

  • ...

@fonnesbeck fonnesbeck requested a review from twiecki March 8, 2023 15:05
@codecov
Copy link

codecov bot commented Mar 8, 2023

Codecov Report

Merging #6581 (29f41ce) into main (a41d524) will decrease coverage by 0.01%.
The diff coverage is 75.00%.

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main    #6581      +/-   ##
==========================================
- Coverage   92.01%   92.00%   -0.01%     
==========================================
  Files          92       92              
  Lines       15535    15539       +4     
==========================================
+ Hits        14294    14297       +3     
- Misses       1241     1242       +1     
Impacted Files Coverage Δ
pymc/sampling/jax.py 98.26% <ø> (ø)
pymc/sampling/mcmc.py 90.13% <75.00%> (-0.15%) ⬇️

... and 1 file with indirect coverage changes

@fonnesbeck fonnesbeck changed the title Add sampler_kwargs and nuts_kwargs to pm.sample Add nuts_sampler_kwargs and nuts_kwargs to pm.sample Mar 8, 2023
@fonnesbeck fonnesbeck requested a review from twiecki March 12, 2023 18:22
@fonnesbeck
Copy link
Member Author

Greatly simplified now.

Copy link
Member

@twiecki twiecki left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See minor comment, otherwise this looks great.

Apply suggested nuts_sampler_kwarg docstring change

Co-authored-by: Thomas Wiecki <[email protected]>
@fonnesbeck fonnesbeck merged commit da68d11 into pymc-devs:main Mar 12, 2023
@fonnesbeck fonnesbeck deleted the nuts_kwargs2 branch March 12, 2023 18:32
import pytest

from pymc import Model, Normal, sample

# turns all warnings into errors for this module
pytestmark = pytest.mark.filterwarnings("error")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If a test raises a warning it should be caught explicitly with pytest.warns.

@ricardoV94
Copy link
Member

pre-commit failed

@twiecki
Copy link
Member

twiecki commented Mar 13, 2023

image
it did?

@ricardoV94
Copy link
Member

Weird, this is what I see here:
image

@ricardoV94
Copy link
Member

@twiecki
Copy link
Member

twiecki commented Mar 13, 2023

Oh dang, it was my last-minute edit.

@fonnesbeck fonnesbeck changed the title Add nuts_sampler_kwargs and nuts_kwargs to pm.sample Add nuts_sampler_kwargs to pm.sample Mar 18, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants