Skip to content

Commit 0c880d2

Browse files
committed
Minor changes made to the fit_pathfinder function and added test
`fit_pathfinder` - Edited `fit_pathfinder` to produce `pathfinder_state`, `pathfinder_info`, `pathfinder_samples` and `pathfinder_idata` for closer examination of the outputs. - Changed the `num_samples` argument name to `num_draws` to avoid `TypeError` got multiple values for keyword argument 'num_samples'. - Initial points are automatically set to jitter as jitter is required for pathfinder. Extras - New function 'get_jaxified_logp_ravel_inputs' to simplify previous code structure in fit_pathfinder. Tests - Added extra test for pathfinder to test pathfinder_info variables and pathfinder_idata are consistent for a given random seed.
1 parent 4540b84 commit 0c880d2

File tree

2 files changed

+123
-19
lines changed

2 files changed

+123
-19
lines changed

pymc_experimental/inference/pathfinder.py

Lines changed: 49 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -15,20 +15,55 @@
1515
import collections
1616
import sys
1717

18+
from collections.abc import Callable
19+
1820
import arviz as az
1921
import blackjax
2022
import jax
2123
import numpy as np
2224
import pymc as pm
2325

2426
from packaging import version
27+
from pymc import Model
2528
from pymc.backends.arviz import coords_and_dims_for_inferencedata
2629
from pymc.blocking import DictToArrayBijection, RaveledVars
30+
from pymc.initial_point import make_initial_point_fn
2731
from pymc.model import modelcontext
32+
from pymc.model.core import Point
2833
from pymc.sampling.jax import get_jaxified_graph
2934
from pymc.util import RandomSeed, _get_seeds_per_chain, get_default_varnames
3035

3136

37+
def get_jaxified_logp_ravel_inputs(
38+
model: Model,
39+
initial_points: dict | None = None,
40+
) -> tuple[Callable, DictToArrayBijection]:
41+
"""
42+
Get jaxified logp function and ravel inputs for a PyMC model.
43+
44+
Parameters
45+
----------
46+
model : Model
47+
PyMC model to jaxify.
48+
49+
Returns
50+
-------
51+
tuple[Callable, DictToArrayBijection]
52+
A tuple containing the jaxified logp function and the DictToArrayBijection.
53+
"""
54+
55+
new_logprob, new_input = pm.pytensorf.join_nonshared_inputs(
56+
initial_points, (model.logp(),), model.value_vars, ()
57+
)
58+
59+
logprob_fn_list = get_jaxified_graph([new_input], new_logprob)
60+
61+
def logprob_fn(x):
62+
return logprob_fn_list(x)[0]
63+
64+
return logprob_fn, DictToArrayBijection.map(initial_points)
65+
66+
3267
def convert_flat_trace_to_idata(
3368
samples,
3469
include_transformed=False,
@@ -37,7 +72,7 @@ def convert_flat_trace_to_idata(
3772
):
3873
model = modelcontext(model)
3974
ip = model.initial_point()
40-
ip_point_map_info = pm.blocking.DictToArrayBijection.map(ip).point_map_info
75+
ip_point_map_info = DictToArrayBijection.map(ip).point_map_info
4176
trace = collections.defaultdict(list)
4277
for sample in samples:
4378
raveld_vars = RaveledVars(sample, ip_point_map_info)
@@ -62,10 +97,10 @@ def convert_flat_trace_to_idata(
6297

6398

6499
def fit_pathfinder(
65-
num_samples=1000,
100+
model=None,
101+
num_draws=1000,
66102
random_seed: RandomSeed | None = None,
67103
postprocessing_backend="cpu",
68-
model=None,
69104
**pathfinder_kwargs,
70105
):
71106
"""
@@ -99,22 +134,19 @@ def fit_pathfinder(
99134

100135
model = modelcontext(model)
101136

102-
ip = model.initial_point()
103-
ip_map = DictToArrayBijection.map(ip)
137+
[jitter_seed, pathfinder_seed, sample_seed] = _get_seeds_per_chain(random_seed, 3)
104138

105-
new_logprob, new_input = pm.pytensorf.join_nonshared_inputs(
106-
ip, (model.logp(),), model.value_vars, ()
139+
# set initial points. PF requires jittering of initial points
140+
ipfn = make_initial_point_fn(
141+
model=model,
142+
jitter_rvs=set(model.free_RVs),
143+
# TODO: add argument for jitter strategy
107144
)
108-
109-
logprob_fn_list = get_jaxified_graph([new_input], new_logprob)
110-
111-
def logprob_fn(x):
112-
return logprob_fn_list(x)[0]
113-
114-
[pathfinder_seed, sample_seed] = _get_seeds_per_chain(random_seed, 2)
145+
ip = Point(ipfn(jitter_seed), model=model)
146+
logprob_fn, ip_map = get_jaxified_logp_ravel_inputs(model, initial_points=ip)
115147

116148
print("Running pathfinder...", file=sys.stdout)
117-
pathfinder_state, _ = blackjax.vi.pathfinder.approximate(
149+
pathfinder_state, pathfinder_info = blackjax.vi.pathfinder.approximate(
118150
rng_key=jax.random.key(pathfinder_seed),
119151
logdensity_fn=logprob_fn,
120152
initial_position=ip_map.data,
@@ -123,12 +155,12 @@ def logprob_fn(x):
123155
pathfinder_samples, _ = blackjax.vi.pathfinder.sample(
124156
rng_key=jax.random.key(sample_seed),
125157
state=pathfinder_state,
126-
num_samples=num_samples,
158+
num_samples=num_draws,
127159
)
128160

129161
idata = convert_flat_trace_to_idata(
130162
pathfinder_samples,
131163
postprocessing_backend=postprocessing_backend,
132164
model=model,
133165
)
134-
return idata
166+
return pathfinder_state, pathfinder_info, pathfinder_samples, idata

tests/test_pathfinder.py

Lines changed: 74 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,14 @@
1717
import numpy as np
1818
import pymc as pm
1919
import pytest
20+
import xarray as xr
2021

2122
import pymc_experimental as pmx
2223

24+
from pymc_experimental.inference.pathfinder import fit_pathfinder
2325

24-
@pytest.mark.skipif(sys.platform == "win32", reason="JAX not supported on windows.")
25-
def test_pathfinder():
26+
27+
def build_eight_schools_model():
2628
# Data of the Eight Schools Model
2729
J = 8
2830
y = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])
@@ -35,6 +37,14 @@ def test_pathfinder():
3537
theta = pm.Normal("theta", mu=0, sigma=1, shape=J)
3638
obs = pm.Normal("obs", mu=mu + tau * theta, sigma=sigma, shape=J, observed=y)
3739

40+
return model
41+
42+
43+
@pytest.mark.skipif(sys.platform == "win32", reason="JAX not supported on windows.")
44+
def test_pathfinder():
45+
model = build_eight_schools_model()
46+
47+
with model:
3848
idata = pmx.fit(method="pathfinder", random_seed=41)
3949

4050
assert idata.posterior["mu"].shape == (1, 1000)
@@ -43,3 +53,65 @@ def test_pathfinder():
4353
# FIXME: pathfinder doesn't find a reasonable mean! Fix bug or choose model pathfinder can handle
4454
# np.testing.assert_allclose(idata.posterior["mu"].mean(), 5.0)
4555
np.testing.assert_allclose(idata.posterior["tau"].mean(), 4.15, atol=0.5)
56+
57+
58+
def test_pathfinder_pmx_equivalence():
59+
model = build_eight_schools_model()
60+
with model:
61+
idata_pmx = pmx.fit(method="pathfinder", random_seed=41)
62+
idata_pmx = idata_pmx[-1]
63+
64+
ntests = 2
65+
runs = dict()
66+
for k in range(ntests):
67+
runs[k] = {}
68+
(
69+
runs[k]["pathfinder_state"],
70+
runs[k]["pathfinder_info"],
71+
runs[k]["pathfinder_samples"],
72+
runs[k]["pathfinder_idata"],
73+
) = fit_pathfinder(model=model, random_seed=41)
74+
75+
runs[k]["finite_idx"] = (
76+
np.argwhere(np.isfinite(runs[k]["pathfinder_info"].path.elbo)).ravel()[-1] + 1
77+
)
78+
79+
np.testing.assert_allclose(
80+
runs[0]["pathfinder_info"].path.elbo[: runs[0]["finite_idx"]],
81+
runs[1]["pathfinder_info"].path.elbo[: runs[1]["finite_idx"]],
82+
)
83+
84+
np.testing.assert_allclose(
85+
runs[0]["pathfinder_info"].path.alpha,
86+
runs[1]["pathfinder_info"].path.alpha,
87+
)
88+
89+
np.testing.assert_allclose(
90+
runs[0]["pathfinder_info"].path.beta,
91+
runs[1]["pathfinder_info"].path.beta,
92+
)
93+
94+
np.testing.assert_allclose(
95+
runs[0]["pathfinder_info"].path.gamma,
96+
runs[1]["pathfinder_info"].path.gamma,
97+
)
98+
99+
np.testing.assert_allclose(
100+
runs[0]["pathfinder_info"].path.position,
101+
runs[1]["pathfinder_info"].path.position,
102+
)
103+
104+
np.testing.assert_allclose(
105+
runs[0]["pathfinder_info"].path.grad_position,
106+
runs[1]["pathfinder_info"].path.grad_position,
107+
)
108+
109+
xr.testing.assert_allclose(
110+
idata_pmx.posterior,
111+
runs[0]["pathfinder_idata"].posterior,
112+
)
113+
114+
xr.testing.assert_allclose(
115+
idata_pmx.posterior,
116+
runs[1]["pathfinder_idata"].posterior,
117+
)

0 commit comments

Comments
 (0)