Skip to content

Commit 85e1145

Browse files
committed
Draft revision of jax/numba notebook
1 parent 36a6732 commit 85e1145

File tree

4 files changed

+393
-177
lines changed

4 files changed

+393
-177
lines changed

.gitignore

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,7 @@ examples/gallery.rst
1111

1212
pixi.lock
1313

14+
15+
# pixi environments
16+
.pixi
17+
*.egg-info

examples/samplers/fast_sampling_with_jax_and_numba.ipynb

Lines changed: 263 additions & 133 deletions
Large diffs are not rendered by default.

examples/samplers/fast_sampling_with_jax_and_numba.myst.md

Lines changed: 116 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@ jupytext:
55
format_name: myst
66
format_version: 0.13
77
kernelspec:
8-
display_name: pymc5recent
8+
display_name: default
99
language: python
10-
name: pymc5recent
10+
name: python3
1111
---
1212

1313
(faster_sampling_notebook)=
@@ -22,18 +22,61 @@ kernelspec:
2222

2323
+++
2424

25-
PyMC can compile its models to various execution backends through PyTensor, including:
26-
* C
27-
* JAX
28-
* Numba
25+
PyMC offers multiple sampling backends that can dramatically improve performance depending on your model size and requirements. Each backend has distinct advantages and is optimized for different use cases.
2926

30-
By default, PyMC is using the C backend which then gets called by the Python-based samplers.
27+
### PyMC's Built-in Sampler
3128

32-
However, by compiling to other backends, we can use samplers written in other languages than Python that call the PyMC model without any Python-overhead.
29+
```python
30+
pm.sample()
31+
```
32+
33+
The default PyMC sampler uses a Python-based NUTS implementation that provides maximum compatibility with all PyMC features. This sampler is always used when working with models that contain discrete variables, as it's the only option that supports non-gradient based samplers like Slice and Metropolis. While this sampler can compile the underlying model to different backends (C, Numba, or JAX) using the `compile_kwargs` parameter, it still maintains Python overhead that can limit performance for large models.
34+
35+
### Nutpie Sampler
36+
37+
```python
38+
pm.sample(nuts_sampler="nutpie", nuts_sampler_kwargs={"backend": "numba"})
39+
pm.sample(nuts_sampler="nutpie", nuts_sampler_kwargs={"backend": "jax"})
40+
pm.sample(nuts_sampler="nutpie", nuts_sampler_kwargs={"backend": "jax", "gradient_backend": "pytensor"})
41+
```
42+
43+
Nutpie is on the cutting-edge of PyMC sampling performance. Written in Rust, it eliminates most Python overhead and provides exceptional performance for continuous models. The Numba backend typically offers the highest performance for most use cases, while the JAX backend excels with very large models and provides GPU acceleration capabilities. Nutpie is particularly well-suited for production workflows where sampling speed is critical.
44+
45+
### NumPyro Sampler
46+
47+
```python
48+
pm.sample(nuts_sampler="numpyro", nuts_sampler_kwargs={"chain_method": "parallel"})
49+
# GPU-accelerated
50+
pm.sample(nuts_sampler="numpyro", nuts_sampler_kwargs={"chain_method": "vectorized"})
51+
```
52+
53+
NumPyro provides a mature JAX-based sampling implementation that integrates seamlessly with the broader JAX ecosystem. This sampler typically performs best with small to medium-sized models and offers excellent GPU support. NumPyro benefits from years of development within the JAX community and provides reliable performance characteristics, though it may have compilation overhead for very large models.
54+
55+
### BlackJAX Sampler
56+
57+
```python
58+
pm.sample(nuts_sampler="blackjax")
59+
```
3360

34-
For the JAX backend there is the NumPyro and BlackJAX NUTS sampler available. To use these samplers, you have to install `numpyro` and `blackjax`. Both of them are available through conda/mamba: `mamba install -c conda-forge numpyro blackjax`.
61+
BlackJAX offers another JAX-based sampling implementation focused on flexibility and research applications. While it provides similar capabilities to NumPyro, it's less commonly used in production environments. BlackJAX can be valuable for experimental workflows or when specific JAX-based features are required that aren't available in other samplers.
3562

36-
For the Numba backend, there is the [Nutpie sampler](https://github.com/pymc-devs/nutpie) written in Rust. To use this sampler you need `nutpie` installed: `mamba install -c conda-forge nutpie`.
63+
## Performance Guidelines
64+
65+
Understanding when to use each sampler depends on several key factors including model size, variable types, and computational requirements.
66+
67+
**Model Size Considerations**
68+
69+
For small models, NumPyro typically provides the best balance of performance and reliability. The compilation overhead is minimal, and the mature JAX implementation handles these models efficiently. Larger models often benefit from Nutpie with the Numba backend, which provides excellent performance without the memory overhead sometimes associated with JAX compilation.
70+
71+
Large models generally perform best with either Nutpie's JAX backend or Nutpie's Numba backend. The choice between these depends on whether GPU acceleration is needed and how the model's computational graph interacts with each backend's optimization strategies.
72+
73+
**Variable Type Requirements**
74+
75+
Models containing discrete variables have no choice but to use PyMC's built-in sampler, as it's the only implementation that supports the necessary Slice and Metropolis sampling algorithms. For purely continuous models, all sampling backends are available, making performance the primary consideration.
76+
77+
**Computational Backend Selection**
78+
79+
Numba excels at CPU optimization and provides consistent performance across different model types. It's particularly effective for models with complex mathematical operations that benefit from just-in-time compilation. JAX offers superior performance for very large models and provides natural GPU acceleration, making it ideal when computational resources are a limiting factor. The traditional C backend serves as a reliable fallback option with broad compatibility but typically offers lower performance than the alternatives.
3780

3881
```{code-cell} ipython3
3982
import arviz as az
@@ -50,7 +93,7 @@ print(f"Running on PyMC v{pm.__version__}")
5093
az.style.use("arviz-darkgrid")
5194
```
5295

53-
We will use a simple probabilistic PCA model as our example.
96+
We'll demonstrate the performance differences using a Probabilistic Principal Component Analysis (PPCA) model.
5497

5598
```{code-cell} ipython3
5699
def build_toy_dataset(N, D, K, sigma=1):
@@ -91,44 +134,97 @@ with pm.Model() as PPCA:
91134
x = pm.Normal("x", mu=w.dot(z.T), sigma=1, shape=[D, N], observed=data)
92135
```
93136

94-
## Sampling using Python NUTS sampler
137+
## Performance Comparison
138+
139+
Now let's compare the performance of different sampling backends on our PPCA model. We'll measure both compilation time and sampling speed.
140+
141+
### 1. PyMC Default Sampler (Python NUTS)
95142

96143
```{code-cell} ipython3
97144
%%time
98145
with PPCA:
99146
idata_pymc = pm.sample()
100147
```
101148

102-
## Sampling using NumPyro JAX NUTS sampler
149+
### 2. Nutpie with Numba Backend
103150

104151
```{code-cell} ipython3
105152
%%time
106153
with PPCA:
107-
idata_numpyro = pm.sample(nuts_sampler="numpyro", progressbar=False)
154+
idata_nutpie_numba = pm.sample(
155+
nuts_sampler="nutpie", nuts_sampler_kwargs={"backend": "numba"}, progressbar=False
156+
)
108157
```
109158

110-
## Sampling using BlackJAX NUTS sampler
159+
### 3. Nutpie with JAX Backend
111160

112161
```{code-cell} ipython3
113162
%%time
114163
with PPCA:
115-
idata_blackjax = pm.sample(nuts_sampler="blackjax")
164+
idata_nutpie_jax = pm.sample(
165+
nuts_sampler="nutpie", nuts_sampler_kwargs={"backend": "jax"}, progressbar=False
166+
)
116167
```
117168

118-
## Sampling using Nutpie Rust NUTS sampler
169+
### 4. NumPyro Sampler
119170

120171
```{code-cell} ipython3
121172
%%time
122173
with PPCA:
123-
idata_nutpie = pm.sample(nuts_sampler="nutpie")
174+
idata_numpyro = pm.sample(nuts_sampler="numpyro", progressbar=False)
175+
```
176+
177+
## Installation Requirements
178+
179+
To use the various sampling backends, you need to install the corresponding packages. Nutpie is the recommended high-performance option and can be installed with pip or conda/mamba (e.g. `conda install nutpie`). For JAX-based workflows, NumPyro provides mature functionality and is installed with the `numpyro` package. BlackJAX offers an alternative JAX implementation and is available in the `blackjax` package.
180+
181+
+++
182+
183+
## Special Cases and Advanced Usage
184+
185+
### Using PyMC's Built-in Sampler with Different Backends
186+
187+
In certain scenarios, you may need to use PyMC's Python-based sampler while still benefiting from faster computational backends. This situation commonly arises when working with models that contain discrete variables, which require PyMC's specialized sampling algorithms. Even in these cases, you can significantly improve performance by compiling the model's computational graph to more efficient backends.
188+
189+
The following examples demonstrate how to use PyMC's built-in sampler with different compilation targets. The `fast_run` mode uses optimized C compilation, which provides good performance while maintaining full compatibility. The `numba` mode offers the only way to access Numba's just-in-time compilation benefits when using PyMC's sampler. The `jax` mode enables JAX compilation, though for JAX workflows, Nutpie or NumPyro typically provide better performance.
190+
191+
```{code-cell} ipython3
192+
with PPCA:
193+
idata_c = pm.sample(nuts_sampler="pymc", compile_kwargs={"mode": "fast_run"})
194+
195+
# with PPCA:
196+
# idata_pymc_numba = pm.sample(nuts_sampler="pymc", compile_kwargs={"mode": "numba"})
197+
198+
# with PPCA:
199+
# idata_pymc_jax = pm.sample(nuts_sampler="pymc", compile_kwargs={"mode": "jax"})
200+
```
201+
202+
The above examples are commented out to avoid redundant sampling in this demonstration notebook. In practice, you would uncomment and run the configuration that matches your model's requirements. These compilation modes allow you to access faster computational backends even when you must use PyMC's Python-based sampler for compatibility reasons.
203+
204+
+++
205+
206+
### Models with Discrete Variables
207+
208+
When working with models that contain discrete variables, you have no choice but to use PyMC's built-in sampler. This is because discrete variables require specialized sampling algorithms like Slice sampling or Metropolis-Hastings that are only available in PyMC's Python implementation. The example below demonstrates a typical scenario where this constraint applies.
209+
210+
```{code-cell} ipython3
211+
with pm.Model() as discrete_model:
212+
cluster = pm.Categorical("cluster", p=[0.3, 0.7], shape=100)
213+
mu = pm.Normal("mu", 0, 1, shape=2)
214+
sigma = pm.HalfNormal("sigma", 1, shape=2)
215+
obs = pm.Normal("obs", mu=mu[cluster], sigma=sigma[cluster], observed=rng.normal(0, 1, 100))
216+
217+
trace_discrete = pm.sample()
124218
```
125219

126220
## Authors
127-
Authored by Thomas Wiecki in July 2023
221+
222+
- Originally authored by Thomas Wiecki in July 2023
223+
- Substantially updated and expanded by Chris Fonnesbeck in May 2025
128224

129225
```{code-cell} ipython3
130226
%load_ext watermark
131-
%watermark -n -u -v -iv -w -p pytensor,arviz,pymc,numpyro,blackjax,nutpie
227+
%watermark -n -u -v -iv -w -p pytensor,aeppl,xarray
132228
```
133229

134230
:::{include} ../page_footer.md

pixi.toml

Lines changed: 10 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,21 @@
1-
[project]
1+
[workspace]
22
authors = ["Chris Fonnesbeck <[email protected]>"]
33
channels = ["conda-forge"]
4-
description = "Add a short description here"
54
name = "pymc-examples"
65
platforms = ["linux-64"]
76
version = "0.1.0"
87

98
[tasks]
109

1110
[dependencies]
12-
python = ">=3.12.5,<4"
13-
pymc = ">=5.16.2,<6"
14-
jupyter = ">=1.1.1,<2"
11+
pymc = ">=5.22.0,<6"
12+
nutpie = ">=0.14.3,<0.15"
13+
numpyro = ">=0.18.0,<0.19"
14+
numba = ">=0.61.2,<0.62"
15+
ipywidgets = ">=8.1.7,<9"
16+
arviz = ">=0.21.0,<0.22"
17+
matplotlib = ">=3.10.3,<4"
18+
python = ">=3.12.10,<3.13"
1519
ipykernel = ">=6.29.5,<7"
16-
ipywidgets = ">=8.1.5,<9"
17-
numpy = ">=1.26.4,<2"
18-
arviz = ">=0.19.0,<0.20"
19-
numpyro = ">=0.15.2,<0.16"
20-
seaborn = ">=0.13.2,<0.14"
21-
matplotlib = ">=3.9.2,<4"
22-
pandas = ">=2.2.2,<3"
23-
polars = ">=1.6.0,<2"
24-
esbonio = ">=0.16.4,<0.17"
20+
blackjax = ">=1.2.4,<2"
2521
watermark = ">=2.5.0,<3"
26-
nutpie = ">=0.13.2,<0.14"
27-
numba = ">=0.60.0,<0.61"
28-
scikit-learn = ">=1.5.2,<2"
29-
blackjax = ">=1.2.3,<2"
30-
networkx = ">=3.4.2,<4"
31-
bokeh = ">=3.7.2,<4"
32-
33-
[pypi-dependencies]
34-
pymc-experimental = ">=0.1.2, <0.2"
35-
pymc-extras = ">=0.2.0, <0.3"

0 commit comments

Comments
 (0)