Skip to content

Commit 9bfc48c

Browse files
committed
feat: implement PyMC-based Pathfinder VI backend
Add a PyMC/PyTensor implementation of Pathfinder VI as an alternative to the existing BlackJAX backend. Key changes include:
- Implement core Pathfinder components using PyTensor with batched operations
- Add inference_backend parameter to select between PyMC and BlackJAX implementations
- Enable jittering of initial points for Pathfinder
1 parent 0db91fe commit 9bfc48c

File tree

2 files changed

+51
-30
lines changed

2 files changed

+51
-30
lines changed

pymc_experimental/inference/lbfgs.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,14 @@
22
from typing import NamedTuple
33

44
import numpy as np
5-
import pytensor.tensor as pt
65

7-
from pytensor.tensor.variable import TensorVariable
86
from scipy.optimize import fmin_l_bfgs_b
97

108

119
class LBFGSHistory(NamedTuple):
12-
x: TensorVariable
13-
f: TensorVariable
14-
g: TensorVariable
10+
x: np.ndarray
11+
f: np.ndarray
12+
g: np.ndarray
1513

1614

1715
class LBFGSHistoryManager:
@@ -41,9 +39,9 @@ def get_history(self):
4139
f = self.f_history[: self.count]
4240
g = self.g_history[: self.count] if self.g_history is not None else None
4341
return LBFGSHistory(
44-
x=pt.as_tensor(x, dtype="float64"),
45-
f=pt.as_tensor(f, dtype="float64"),
46-
g=pt.as_tensor(g, dtype="float64"),
42+
x=x,
43+
f=f,
44+
g=g,
4745
)
4846

4947
def __call__(self, x):

pymc_experimental/inference/pathfinder.py

Lines changed: 45 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from pymc.model.core import Point
3535
from pymc.sampling.jax import get_jaxified_graph
3636
from pymc.util import RandomSeed, _get_seeds_per_chain, get_default_varnames
37+
from pytensor.graph import Apply, Op
3738

3839
from pymc_experimental.inference.lbfgs import lbfgs
3940

@@ -311,7 +312,8 @@ def bfgs_sample(
311312
alpha,
312313
beta,
313314
gamma,
314-
random_seed: RandomSeed | None = None,
315+
rng,
316+
# random_seed: RandomSeed | None = None,
315317
):
316318
# batch: L = 8
317319
# alpha_l: (N,) => (L, N)
@@ -324,8 +326,6 @@ def bfgs_sample(
324326
# logdensity: (M,) => (L, M)
325327
# theta: (J, N)
326328

327-
rng = pytensor.shared(np.random.default_rng(seed=random_seed))
328-
329329
if not _batched(x, g, alpha, beta, gamma):
330330
x = pt.atleast_2d(x)
331331
g = pt.atleast_2d(g)
@@ -371,6 +371,24 @@ def bfgs_sample(
371371
return phi, logdensity
372372

373373

374+
class LogLike(Op):
375+
def __init__(self, logp_func):
376+
self.logp_func = logp_func
377+
super().__init__()
378+
379+
def make_node(self, phi_node):
380+
# Convert inputs to tensor variables
381+
phi_node = pt.as_tensor(phi_node)
382+
output_type = pt.tensor(dtype=phi_node.dtype, shape=(None, None))
383+
return Apply(self, [phi_node], [output_type])
384+
385+
def perform(self, node: Apply, phi_node, outputs) -> None:
386+
phi_node = phi_node[0]
387+
logp_node = np.apply_along_axis(self.logp_func, axis=-1, arr=phi_node)
388+
# outputs[0][0] = np.asarray(logp)
389+
outputs[0][0] = logp_node
390+
391+
374392
def _pymc_pathfinder(
375393
model,
376394
x0: np.float64,
@@ -406,38 +424,43 @@ def neg_dlogp_func(x):
406424
gtol=gtol,
407425
maxls=maxls,
408426
)
427+
x = pytensor.shared(history.x, "x")
428+
g = pytensor.shared(history.g, "g")
409429

410-
alpha, update_mask = alpha_recover(history.x, history.g)
411-
412-
beta, gamma = inverse_hessian_factors(alpha, history.x, history.g, update_mask, J=maxcor)
430+
alpha, update_mask = alpha_recover(x, g)
413431

414-
phi, logq_phi = bfgs_sample(
432+
beta, gamma = inverse_hessian_factors(alpha, x, g, update_mask, J=maxcor)
433+
rng = pytensor.shared(np.random.default_rng(seed=pathfinder_seed))
434+
_phi, _logq_phi = bfgs_sample(
415435
num_samples=num_elbo_draws,
416-
x=history.x,
417-
g=history.g,
436+
x=x,
437+
g=g,
418438
alpha=alpha,
419439
beta=beta,
420440
gamma=gamma,
421-
random_seed=pathfinder_seed,
441+
rng=rng,
422442
)
443+
sample_phi_fn = pytensor.function([alpha, beta, gamma], [_phi, _logq_phi])
444+
phi, logq_phi = sample_phi_fn(alpha.eval(), beta.eval(), gamma.eval())
423445

424446
# .vectorize is slower than apply_along_axis
425-
logp_phi = np.apply_along_axis(logp_func, axis=-1, arr=phi.eval())
426-
logq_phi = logq_phi.eval()
427-
elbo = (logp_phi - logq_phi).mean(axis=-1)
428-
lstar = np.argmax(elbo)
447+
loglike = LogLike(logp_func)
448+
logp_phi = loglike(phi)
449+
elbo = pt.mean(logp_phi - logq_phi, axis=-1)
450+
l_star = pt.argmax(elbo)
429451

452+
rng.set_value(np.random.default_rng(seed=sample_seed))
430453
psi, logq_psi = bfgs_sample(
431454
num_samples=num_draws,
432-
x=history.x[lstar],
433-
g=history.g[lstar],
434-
alpha=alpha[lstar],
435-
beta=beta[lstar],
436-
gamma=gamma[lstar],
437-
random_seed=sample_seed,
455+
x=x[l_star],
456+
g=g[l_star],
457+
alpha=alpha[l_star],
458+
beta=beta[l_star],
459+
gamma=gamma[l_star],
460+
rng=rng,
438461
)
439462

440-
return psi[0].eval(), logq_psi, logp_func
463+
return psi[0].eval(), logq_psi.eval()
441464

442465

443466
def fit_pathfinder(
@@ -492,7 +515,7 @@ def fit_pathfinder(
492515

493516
# TODO: make better
494517
if inference_backend == "pymc":
495-
pathfinder_samples, logq_psi, logp_func = _pymc_pathfinder(
518+
pathfinder_samples, logq_psi = _pymc_pathfinder(
496519
model,
497520
ip,
498521
maxcor=maxcor,

0 commit comments

Comments (0)