Skip to content

Commit 05aeeaf

Browse files
committed
Merge branch 'replicate_pathfinder_w_pytensor' into scipy_lbfgs
2 parents 663a60a + 0c880d2 commit 05aeeaf

File tree

1 file changed

+48
-16
lines changed

1 file changed

+48
-16
lines changed

pymc_experimental/inference/pathfinder.py

Lines changed: 48 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -15,20 +15,55 @@
1515
import collections
1616
import sys
1717

18+
from collections.abc import Callable
19+
1820
import arviz as az
1921
import blackjax
2022
import jax
2123
import numpy as np
2224
import pymc as pm
2325

2426
from packaging import version
27+
from pymc import Model
2528
from pymc.backends.arviz import coords_and_dims_for_inferencedata
2629
from pymc.blocking import DictToArrayBijection, RaveledVars
30+
from pymc.initial_point import make_initial_point_fn
2731
from pymc.model import modelcontext
32+
from pymc.model.core import Point
2833
from pymc.sampling.jax import get_jaxified_graph
2934
from pymc.util import RandomSeed, _get_seeds_per_chain, get_default_varnames
3035

3136

37+
def get_jaxified_logp_ravel_inputs(
38+
model: Model,
39+
initial_points: dict | None = None,
40+
) -> tuple[Callable, DictToArrayBijection]:
41+
"""
42+
Get jaxified logp function and ravel inputs for a PyMC model.
43+
44+
Parameters
45+
----------
46+
model : Model
47+
PyMC model to jaxify.
48+
49+
Returns
50+
-------
51+
tuple[Callable, DictToArrayBijection]
52+
A tuple containing the jaxified logp function and the DictToArrayBijection.
53+
"""
54+
55+
new_logprob, new_input = pm.pytensorf.join_nonshared_inputs(
56+
initial_points, (model.logp(),), model.value_vars, ()
57+
)
58+
59+
logprob_fn_list = get_jaxified_graph([new_input], new_logprob)
60+
61+
def logprob_fn(x):
62+
return logprob_fn_list(x)[0]
63+
64+
return logprob_fn, DictToArrayBijection.map(initial_points)
65+
66+
3267
def convert_flat_trace_to_idata(
3368
samples,
3469
include_transformed=False,
@@ -37,7 +72,7 @@ def convert_flat_trace_to_idata(
3772
):
3873
model = modelcontext(model)
3974
ip = model.initial_point()
40-
ip_point_map_info = pm.blocking.DictToArrayBijection.map(ip).point_map_info
75+
ip_point_map_info = DictToArrayBijection.map(ip).point_map_info
4176
trace = collections.defaultdict(list)
4277
for sample in samples:
4378
raveld_vars = RaveledVars(sample, ip_point_map_info)
@@ -62,10 +97,10 @@ def convert_flat_trace_to_idata(
6297

6398

6499
def fit_pathfinder(
100+
model=None,
65101
num_draws=1000,
66102
random_seed: RandomSeed | None = None,
67103
postprocessing_backend="cpu",
68-
model=None,
69104
**pathfinder_kwargs,
70105
):
71106
"""
@@ -99,19 +134,16 @@ def fit_pathfinder(
99134

100135
model = modelcontext(model)
101136

102-
ip = model.initial_point()
103-
ip_map = DictToArrayBijection.map(ip)
137+
[jitter_seed, pathfinder_seed, sample_seed] = _get_seeds_per_chain(random_seed, 3)
104138

105-
new_logprob, new_input = pm.pytensorf.join_nonshared_inputs(
106-
ip, (model.logp(),), model.value_vars, ()
139+
# set initial points. PF requires jittering of initial points
140+
ipfn = make_initial_point_fn(
141+
model=model,
142+
jitter_rvs=set(model.free_RVs),
143+
# TODO: add argument for jitter strategy
107144
)
108-
109-
logprob_fn_list = get_jaxified_graph([new_input], new_logprob)
110-
111-
def logprob_fn(x):
112-
return logprob_fn_list(x)[0]
113-
114-
[pathfinder_seed, sample_seed] = _get_seeds_per_chain(random_seed, 2)
145+
ip = Point(ipfn(jitter_seed), model=model)
146+
logprob_fn, ip_map = get_jaxified_logp_ravel_inputs(model, initial_points=ip)
115147

116148
print("Running pathfinder...", file=sys.stdout)
117149
pathfinder_state, pathfinder_info = blackjax.vi.pathfinder.approximate(
@@ -120,17 +152,17 @@ def logprob_fn(x):
120152
initial_position=ip_map.data,
121153
**pathfinder_kwargs,
122154
)
123-
124-
# retrieved logq
125-
pathfinder_samples, logq = blackjax.vi.pathfinder.sample(
155+
pathfinder_samples, _ = blackjax.vi.pathfinder.sample(
126156
rng_key=jax.random.key(sample_seed),
127157
state=pathfinder_state,
128158
num_samples=num_draws,
129159
)
130160

131161
idata = convert_flat_trace_to_idata(
162+
pathfinder_samples,
132163
pathfinder_samples,
133164
postprocessing_backend=postprocessing_backend,
134165
model=model,
135166
)
136167
return pathfinder_state, pathfinder_info, pathfinder_samples, idata
168+
return pathfinder_state, pathfinder_info, pathfinder_samples, idata

0 commit comments

Comments
 (0)