PyMC can compile its models to various execution backends through PyTensor, including:
26
-
* C
27
-
* JAX
28
-
* Numba
25
+
PyMC offers multiple sampling backends that can dramatically improve performance depending on your model size and requirements. Each backend has distinct advantages and is optimized for different use cases.
29
26
30
-
By default, PyMC is using the C backend which then gets called by the Python-based samplers.
27
+
### PyMC's Built-in Sampler
31
28
32
-
However, by compiling to other backends, we can use samplers written in other languages than Python that call the PyMC model without any Python-overhead.
29
+
```python
30
+
pm.sample()
31
+
```
32
+
33
+
The default PyMC sampler uses a Python-based NUTS implementation that provides maximum compatibility with all PyMC features. It is always used for models that contain discrete variables, because it is the only option that supports non-gradient-based step methods such as Slice and Metropolis. Although this sampler can compile the underlying model to different backends (C, Numba, or JAX) through the `compile_kwargs` parameter, it still incurs Python overhead that can limit performance for large models.
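As a rough sketch of the backend selection described above, the helper below builds the keyword arguments one might pass to `pm.sample`. The name `builtin_sampler_kwargs` is hypothetical, and the `{"mode": ...}` shape of `compile_kwargs` is an assumption. The mode strings `fast_run`, `numba`, and `jax` are the ones this guide names for C, Numba, and JAX compilation.

```python
# Hypothetical helper: builds keyword arguments for PyMC's built-in sampler.
SUPPORTED_MODES = ("fast_run", "numba", "jax")  # C, Numba, and JAX compilation


def builtin_sampler_kwargs(mode="fast_run"):
    """Return kwargs selecting PyMC's own NUTS with the given compile mode."""
    if mode not in SUPPORTED_MODES:
        raise ValueError(f"unknown mode {mode!r}; expected one of {SUPPORTED_MODES}")
    # Assumed shape: compile_kwargs={"mode": ...} is forwarded to model compilation.
    return {"nuts_sampler": "pymc", "compile_kwargs": {"mode": mode}}


kwargs = builtin_sampler_kwargs("numba")
print(kwargs)

# Inside a model context this would be used as (requires PyMC to be installed):
# with model:
#     idata = pm.sample(**kwargs)
```

Keeping the configuration in a plain dictionary makes it easy to switch compilation targets without touching the model code.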
Nutpie is at the cutting edge of PyMC sampling performance. Written in Rust, it eliminates most Python overhead and performs very well on continuous models. Its Numba backend typically offers the highest performance for most use cases, while its JAX backend excels with very large models and provides GPU acceleration. Nutpie is particularly well suited to production workflows where sampling speed is critical.
NumPyro provides a mature JAX-based sampling implementation that integrates seamlessly with the broader JAX ecosystem. This sampler typically performs best with small to medium-sized models and offers excellent GPU support. NumPyro benefits from years of development within the JAX community and provides reliable performance characteristics, though it may have compilation overhead for very large models.
54
+
55
+
### BlackJAX Sampler
56
+
57
+
```python
58
+
pm.sample(nuts_sampler="blackjax")
59
+
```
33
60
34
-
For the JAX backend there is the NumPyro and BlackJAX NUTS sampler available. To use these samplers, you have to install `numpyro` and `blackjax`. Both of them are available through conda/mamba: `mamba install -c conda-forge numpyro blackjax`.
61
+
BlackJAX offers another JAX-based sampling implementation focused on flexibility and research applications. While it provides similar capabilities to NumPyro, it's less commonly used in production environments. BlackJAX can be valuable for experimental workflows or when specific JAX-based features are required that aren't available in other samplers.
35
62
36
-
For the Numba backend, there is the [Nutpie sampler](https://github.com/pymc-devs/nutpie) written in Rust. To use this sampler you need `nutpie` installed: `mamba install -c conda-forge nutpie`.
63
+
## Performance Guidelines
64
+
65
+
Understanding when to use each sampler depends on several key factors including model size, variable types, and computational requirements.
66
+
67
+
**Model Size Considerations**
68
+
69
+
For small models, NumPyro typically provides the best balance of performance and reliability. The compilation overhead is minimal, and the mature JAX implementation handles these models efficiently. Medium-sized models often benefit from Nutpie with the Numba backend, which provides excellent performance without the memory overhead sometimes associated with JAX compilation.
70
+
71
+
Large models generally perform best with either Nutpie's JAX backend or Nutpie's Numba backend. The choice between these depends on whether GPU acceleration is needed and how the model's computational graph interacts with each backend's optimization strategies.
72
+
73
+
**Variable Type Requirements**
74
+
75
+
Models containing discrete variables must use PyMC's built-in sampler, because it is the only implementation that supports the Slice and Metropolis step methods these variables need. For purely continuous models, all sampling backends are available, so performance becomes the primary consideration.
76
+
77
+
**Computational Backend Selection**
78
+
79
+
Numba excels at CPU optimization and provides consistent performance across different model types. It is particularly effective for models with complex mathematical operations that benefit from just-in-time compilation. JAX offers superior performance for very large models and provides natural GPU acceleration, making it a strong option when CPU-only computation is the bottleneck. The traditional C backend serves as a reliable fallback with broad compatibility but typically offers lower performance than the alternatives.
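The guidelines above can be condensed into a small rule-of-thumb function. This is only a sketch: the name `suggest_sampler` and the three size labels are hypothetical, reading the middle tier as "medium" is an assumption, and so is the use of `nuts_sampler_kwargs={"backend": ...}` to pick Nutpie's backend. Real models should still be benchmarked.

```python
SIZES = ("small", "medium", "large")  # hypothetical labels for this sketch


def suggest_sampler(has_discrete, model_size="medium", needs_gpu=False):
    """Map the guidelines above to pm.sample keyword arguments (rule of thumb)."""
    if model_size not in SIZES:
        raise ValueError(f"model_size must be one of {SIZES}")
    if has_discrete:
        # Discrete variables need Slice/Metropolis steps: built-in sampler only.
        return {"nuts_sampler": "pymc"}
    if model_size == "small":
        return {"nuts_sampler": "numpyro"}
    # Medium and large continuous models: Nutpie; JAX only for large GPU runs.
    backend = "jax" if (model_size == "large" and needs_gpu) else "numba"
    # Assumption: the backend choice is forwarded via nuts_sampler_kwargs.
    return {"nuts_sampler": "nutpie", "nuts_sampler_kwargs": {"backend": backend}}


print(suggest_sampler(has_discrete=False, model_size="large", needs_gpu=True))

# Usage inside a model context (requires PyMC and the chosen sampler package):
# with model:
#     idata = pm.sample(**suggest_sampler(False, "large", needs_gpu=True))
```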
37
80
38
81
```{code-cell} ipython3
39
82
import arviz as az
@@ -50,7 +93,7 @@ print(f"Running on PyMC v{pm.__version__}")
50
93
az.style.use("arviz-darkgrid")
51
94
```
52
95
53
-
We will use a simple probabilistic PCA model as our example.
96
+
We'll demonstrate the performance differences using a Probabilistic Principal Component Analysis (PPCA) model.
54
97
55
98
```{code-cell} ipython3
56
99
def build_toy_dataset(N, D, K, sigma=1):
@@ -91,44 +134,97 @@ with pm.Model() as PPCA:
91
134
x = pm.Normal("x", mu=w.dot(z.T), sigma=1, shape=[D, N], observed=data)
92
135
```
93
136
94
-
## Sampling using Python NUTS sampler
137
+
## Performance Comparison
138
+
139
+
Now let's compare the performance of different sampling backends on our PPCA model. We'll measure both compilation time and sampling speed.
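One simple way to separate compilation cost from steady-state speed is to time the first call to a sampling function separately from repeated calls. The sketch below assumes that the first call includes any just-in-time compilation, which depends on each backend's caching behavior. `time_first_and_repeat` and `dummy_workload` are hypothetical names, and the workload is a stand-in for a real call such as `pm.sample(...)`.

```python
import time


def time_first_and_repeat(run, repeats=3):
    """Time the first call of `run` separately from the best of later calls."""
    if repeats < 1:
        raise ValueError("repeats must be at least 1")
    start = time.perf_counter()
    run()  # the first call may include compilation for JIT-compiled backends
    first = time.perf_counter() - start

    later = []
    for _ in range(repeats):
        start = time.perf_counter()
        run()
        later.append(time.perf_counter() - start)
    return {"first_call_s": first, "best_repeat_s": min(later)}


def dummy_workload():
    # Stand-in for a sampling call; replace with a function that runs pm.sample.
    return sum(i * i for i in range(100_000))


timings = time_first_and_repeat(dummy_workload)
print(f"first call: {timings['first_call_s']:.4f}s, "
      f"best repeat: {timings['best_repeat_s']:.4f}s")
```

The gap between the two numbers gives a rough estimate of one-off compilation overhead, while the repeat timing is closer to pure sampling speed.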
To use the various sampling backends, you need to install the corresponding packages. Nutpie is the recommended high-performance option and can be installed with pip or conda/mamba (e.g. `conda install -c conda-forge nutpie`). For JAX-based workflows, NumPyro provides mature functionality and is installed with the `numpyro` package. BlackJAX offers an alternative JAX implementation and is available in the `blackjax` package.
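Before requesting a particular sampler, it can help to check which of these optional packages are importable in the current environment. The following is a minimal sketch using only the standard library. The helper name `available_samplers` is hypothetical; the package names are the ones listed above.

```python
from importlib.util import find_spec

OPTIONAL_SAMPLERS = ("nutpie", "numpyro", "blackjax")


def available_samplers(packages=OPTIONAL_SAMPLERS):
    """Report which optional sampler packages can be imported."""
    # find_spec returns None when a top-level package is not installed.
    return {name: find_spec(name) is not None for name in packages}


for name, installed in available_samplers().items():
    print(f"{name}: {'installed' if installed else 'missing'}")
```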
180
+
181
+
+++
182
+
183
+
## Special Cases and Advanced Usage
184
+
185
+
### Using PyMC's Built-in Sampler with Different Backends
186
+
187
+
In certain scenarios, you may need to use PyMC's Python-based sampler while still benefiting from faster computational backends. This situation commonly arises when working with models that contain discrete variables, which require PyMC's specialized sampling algorithms. Even in these cases, you can significantly improve performance by compiling the model's computational graph to more efficient backends.
188
+
189
+
The following examples demonstrate how to use PyMC's built-in sampler with different compilation targets. The `fast_run` mode uses optimized C compilation, which provides good performance while maintaining full compatibility. The `numba` mode offers the only way to access Numba's just-in-time compilation benefits when using PyMC's sampler. The `jax` mode enables JAX compilation, though for JAX workflows, Nutpie or NumPyro typically provide better performance.
The above examples are commented out to avoid redundant sampling in this demonstration notebook. In practice, you would uncomment and run the configuration that matches your model's requirements. These compilation modes allow you to access faster computational backends even when you must use PyMC's Python-based sampler for compatibility reasons.
203
+
204
+
+++
205
+
206
+
### Models with Discrete Variables
207
+
208
+
Models that contain discrete variables must use PyMC's built-in sampler, because discrete variables require step methods such as Slice sampling or Metropolis-Hastings, which are available only in PyMC's Python implementation. The example below demonstrates a typical scenario where this constraint applies.