Commit 0db91fe

feat(pathfinder): add PyMC-based Pathfinder VI implementation
Add a new PyMC-based implementation of Pathfinder VI, built on PyTensor operations, and support both PyMC and BlackJAX backends in fit_pathfinder.
1 parent 05aeeaf commit 0db91fe
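
As a rough usage sketch of the backend switch described in the commit message (not part of this diff): the pmx.fit entry point, the inference_backend keyword, and the toy model below are assumptions for illustration, not the confirmed API of this commit.

import numpy as np
import pymc as pm
import pymc_experimental as pmx

# Toy model; any PyMC model with continuous free variables would do.
with pm.Model():
    mu = pm.Normal("mu", 0.0, 1.0)
    pm.Normal("y", mu, 1.0, observed=np.random.normal(size=100))

    # Hypothetical keyword: 'inference_backend' is assumed here to switch
    # between the new PyMC (PyTensor) implementation and the BlackJAX one.
    idata_pymc = pmx.fit(method="pathfinder", inference_backend="pymc")
    idata_blackjax = pmx.fit(method="pathfinder", inference_backend="blackjax")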

File tree

3 files changed: +532 -25 lines changed


pymc_experimental/inference/lbfgs.py

Lines changed: 99 additions & 0 deletions
@@ -0,0 +1,99 @@
from collections.abc import Callable
from typing import NamedTuple

import numpy as np
import pytensor.tensor as pt

from pytensor.tensor.variable import TensorVariable
from scipy.optimize import fmin_l_bfgs_b


class LBFGSHistory(NamedTuple):
    x: TensorVariable
    f: TensorVariable
    g: TensorVariable


class LBFGSHistoryManager:
    def __init__(self, fn: Callable, grad_fn: Callable, x0: np.ndarray, maxiter: int):
        dim = x0.shape[0]
        maxiter_add_one = maxiter + 1
        # Preallocate arrays to save memory and improve speed
        self.x_history = np.empty((maxiter_add_one, dim), dtype=np.float64)
        self.f_history = np.empty(maxiter_add_one, dtype=np.float64)
        self.g_history = np.empty((maxiter_add_one, dim), dtype=np.float64)
        self.count = 0
        self.fn = fn
        self.grad_fn = grad_fn
        self.add_entry(x0, fn(x0), grad_fn(x0))

    def add_entry(self, x, f, g=None):
        # Store the values directly in preallocated arrays
        self.x_history[self.count] = x
        self.f_history[self.count] = f
        if self.g_history is not None and g is not None:
            self.g_history[self.count] = g
        self.count += 1

    def get_history(self):
        # Return trimmed arrays up to the number of entries actually used
        x = self.x_history[: self.count]
        f = self.f_history[: self.count]
        g = self.g_history[: self.count] if self.g_history is not None else None
        return LBFGSHistory(
            x=pt.as_tensor(x, dtype="float64"),
            f=pt.as_tensor(f, dtype="float64"),
            g=pt.as_tensor(g, dtype="float64"),
        )

    def __call__(self, x):
        self.add_entry(x, self.fn(x), self.grad_fn(x))


def lbfgs(
    fn,
    grad_fn,
    x0: np.ndarray,
    maxcor: int | None = None,
    maxiter=1000,
    ftol=1e-5,
    gtol=1e-8,
    maxls=1000,
):
    def callback(xk):
        lbfgs_history_manager(xk)

    lbfgs_history_manager = LBFGSHistoryManager(
        fn=fn,
        grad_fn=grad_fn,
        x0=x0,
        maxiter=maxiter,
    )

    # options = dict(
    #     maxcor=maxcor,
    #     maxiter=maxiter,
    #     ftol=ftol,
    #     gtol=gtol,
    #     maxls=maxls,
    # )
    # minimize(
    #     fn,
    #     x0,
    #     method="L-BFGS-B",
    #     jac=grad_fn,
    #     options=options,
    #     callback=callback,
    # )
    fmin_l_bfgs_b(
        func=fn,
        fprime=grad_fn,
        x0=x0,
        pgtol=gtol,
        factr=ftol / np.finfo(float).eps,
        maxls=maxls,
        maxiter=maxiter,
        m=maxcor,
        callback=callback,
    )
    return lbfgs_history_manager.get_history()
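
For context, a small usage sketch of the new helper (not part of the diff): lbfgs runs SciPy's L-BFGS-B on a plain NumPy objective and returns the recorded iterates, values, and gradients as PyTensor constants. The quadratic objective and the concrete maxcor value below are illustrative choices.

import numpy as np

from pymc_experimental.inference.lbfgs import lbfgs

# Quadratic objective with its minimum at x = [1, 2, 3].
target = np.array([1.0, 2.0, 3.0])

def fn(x):
    return 0.5 * np.sum((x - target) ** 2)

def grad_fn(x):
    return x - target

# A concrete maxcor is passed because the helper forwards it directly to
# SciPy's `m` argument without substituting a default.
history = lbfgs(fn, grad_fn, x0=np.zeros(3), maxcor=10, maxiter=100)

# history.x, history.f, history.g hold the initial point plus every iterate
# recorded by the callback, as float64 tensor constants.
print(history.x.eval()[-1])  # approximately [1., 2., 3.]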
