
Implement a minimizer for INLA #513


Draft · wants to merge 8 commits into base: main
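Background sketch (not text from the PR): the inner step of INLA is a Laplace approximation, which needs the mode of the negative joint log-density and its curvature,

x* = argmin_x [-log p(x, y)],    H = ∇²_x [-log p(x, y)],    p(x | y) ≈ N(x*, H⁻¹).

The new find_mode helper below computes these two quantities: the mode via pytensor.tensor.optimize.minimize and a Hessian of -model.logp().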
75 changes: 75 additions & 0 deletions pymc_extras/inference/laplace.py
@@ -39,6 +39,8 @@
from pymc.model.transform.conditioning import remove_value_transforms
from pymc.model.transform.optimization import freeze_dims_and_data
from pymc.util import get_default_varnames
from pytensor.tensor import TensorVariable
from pytensor.tensor.optimize import minimize
from scipy import stats

from pymc_extras.inference.find_map import (
@@ -415,6 +417,79 @@ def sample_laplace_posterior(
return idata


def find_mode(
x: TensorVariable,
args: dict,
inputs: list[TensorVariable] | None = None,
x0: np.ndarray | None = None,
model: pm.Model | None = None,
method: minimize_method = "BFGS",
use_jac: bool = True,
use_hess: bool = False,
optimizer_kwargs: dict | None = None,
) -> list[np.ndarray]:
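"""
Minimize the negative joint log-probability -model.logp() with respect to x and return the mode together with the Hessian of the negative log-probability.
args maps the names of the model's remaining symbolic inputs (e.g. data arrays) to numeric values; x0 is the starting point passed to the compiled function.
(Draft docstring describing the current behaviour of this helper.)
"""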
model = pm.modelcontext(model)

# if x0 is None:
# #TODO Issue with X not being an RV
# print(model.initial_point())

# from pymc.initial_point import make_initial_point_fn
# frozen_model = freeze_dims_and_data(model)
# ipfn = make_initial_point_fn(
# model=model,
# jitter_rvs=set(),#(jitter_rvs),
# return_transformed=True,
# overrides=args,
# )

# random_seed = None
# start_dict = ipfn(random_seed)
# vars_dict = {var.name: var for var in frozen_model.continuous_value_vars}
# initial_params = DictToArrayBijection.map(
# {var_name: value for var_name, value in start_dict.items() if var_name in vars_dict}
# )
# print(initial_params)

# Minimise the negative joint log-probability (log-prior + log-likelihood)
nll = -model.logp()
soln, _ = minimize(
objective=nll,
x=x,
method=method,
jac=use_jac,
hess=use_hess,
optimizer_kwargs=optimizer_kwargs,
)

# Get input variables: look each name in args up in the graph of model.basic_RVs[1] and map any RVs found to their value variables
# TODO issue when the graph searched is nll instead of an RV
if inputs is None:
inputs = [
pytensor.graph.basic.get_var_by_name(model.basic_RVs[1], target_var_id=var)[0]
for var in args
]
for i, var in enumerate(inputs):
try:
inputs[i] = model.rvs_to_values[var]
except KeyError:
pass
inputs.insert(0, x)

# Obtain the Hessian of the negative log-probability (re-use the graph if it was already computed inside minimize)
if use_hess:
hess = soln.owner.op.inner_outputs[-1]
hess = pytensor.graph.replace.graph_replace(
hess, {x: soln}
)  # TODO: x here is the value variable ('beta') and soln is the output of a MinimizeOp; there is no instance of MinimizeOp in the Hessian graph
else:
hess = pytensor.gradient.hessian(nll, x)

get_mode_and_hessian = pytensor.function(inputs, [soln, hess])
return get_mode_and_hessian(x0, **args)


def fit_laplace(
optimize_method: minimize_method | Literal["basinhopping"] = "BFGS",
*,
159 changes: 159 additions & 0 deletions tests/test_laplace.py
@@ -15,12 +15,14 @@

import numpy as np
import pymc as pm
import pytensor.tensor as pt
import pytest

import pymc_extras as pmx

from pymc_extras.inference.find_map import GradientBackend, find_MAP
from pymc_extras.inference.laplace import (
find_mode,
fit_laplace,
fit_mvn_at_MAP,
sample_laplace_posterior,
@@ -279,3 +281,160 @@ def test_laplace_scalar():
assert idata_laplace.fit.covariance_matrix.shape == (1, 1)

np.testing.assert_allclose(idata_laplace.fit.mean_vector.values.item(), data.mean(), atol=0.1)


def test_find_mode():
k = 10
N = 10000
y = pt.vector("y", dtype="int64")
X = pt.matrix("X", shape=(N, k))

# Hard-coded reference Hessian (array formatting by pre-commit). TODO: compute this analytically instead of storing a literal array (see the sketch after this test).
true_hess = np.array(
[
[
2.50100000e03,
-1.78838742e00,
1.59484217e01,
-9.78343803e00,
2.86125467e01,
-7.38071788e00,
-4.97729126e01,
3.53243810e01,
1.69071769e01,
-1.30755942e01,
],
[
-1.78838742e00,
2.54687995e03,
8.99456512e-02,
-1.33603390e01,
-2.37641179e01,
4.57780742e01,
-1.22640681e01,
2.70879664e01,
4.04435512e01,
2.08826556e00,
],
[
1.59484217e01,
8.99456512e-02,
2.46908384e03,
-1.80358232e01,
1.14131535e01,
2.21632317e01,
1.25443469e00,
1.50344618e01,
-3.59940488e01,
-1.05191328e01,
],
[
-9.78343803e00,
-1.33603390e01,
-1.80358232e01,
2.50546496e03,
3.27545028e01,
-3.33517501e01,
-2.68735672e01,
-2.69114305e01,
-1.20464337e01,
9.02338622e00,
],
[
2.86125467e01,
-2.37641179e01,
1.14131535e01,
3.27545028e01,
2.49959736e03,
-3.98220135e00,
-4.09495199e00,
-1.51115257e01,
-5.77436126e01,
-2.98600447e00,
],
[
-7.38071788e00,
4.57780742e01,
2.21632317e01,
-3.33517501e01,
-3.98220135e00,
2.48169432e03,
-1.26885014e01,
-3.53524089e01,
5.89656794e00,
1.67164400e01,
],
[
-4.97729126e01,
-1.22640681e01,
1.25443469e00,
-2.68735672e01,
-4.09495199e00,
-1.26885014e01,
2.47216241e03,
8.16935659e00,
-4.89399152e01,
-1.11646138e01,
],
[
3.53243810e01,
2.70879664e01,
1.50344618e01,
-2.69114305e01,
-1.51115257e01,
-3.53524089e01,
8.16935659e00,
2.52940405e03,
3.07751540e00,
-8.60023392e00,
],
[
1.69071769e01,
4.04435512e01,
-3.59940488e01,
-1.20464337e01,
-5.77436126e01,
5.89656794e00,
-4.89399152e01,
3.07751540e00,
2.49452594e03,
6.06984410e01,
],
[
-1.30755942e01,
2.08826556e00,
-1.05191328e01,
9.02338622e00,
-2.98600447e00,
1.67164400e01,
-1.11646138e01,
-8.60023392e00,
6.06984410e01,
2.49290175e03,
],
]
)

with pm.Model() as model:
beta = pm.MvNormal("beta", mu=np.zeros(k), cov=np.identity(k), shape=(k,))
p = pm.math.invlogit(beta @ X.T)
y = pm.Bernoulli("y", p)

rng = np.random.default_rng(123)
Xval = rng.normal(size=(N, k - 1))
Xval = np.c_[np.ones(N), Xval]

true_beta = rng.normal(scale=0.1, size=(k,))
true_p = pm.math.invlogit(Xval @ true_beta).eval()
ynum = rng.binomial(1, true_p)

beta_val = model.rvs_to_values[beta]
x0 = np.zeros(k)
args = {"y": ynum, "X": Xval}

beta_mode, beta_hess = find_mode(
x=beta_val, x0=x0, args=args, method="BFGS", optimizer_kwargs={"tol": 1e-8}
)

np.testing.assert_allclose(beta_mode, true_beta, atol=0.1, rtol=0.1)
np.testing.assert_allclose(beta_hess, true_hess, atol=0.1, rtol=0.1)
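
As the TODO above suggests, the reference Hessian could be computed rather than hard-coded. For this Bernoulli-logit model with a standard-normal prior on beta, the Hessian of the negative joint log-probability is X.T @ diag(p * (1 - p)) @ X + I_k. A minimal sketch (a suggestion, not part of the diff; evaluated at beta = 0, where p = 0.5, consistent with the leading diagonal entry 0.25 * N + 1 = 2501 above):

def analytic_hessian(X, beta):
    # Hessian of -logp for Bernoulli-logit observations with a N(0, I) prior on beta:
    # X.T @ diag(p * (1 - p)) @ X + identity
    p = 1.0 / (1.0 + np.exp(-X @ beta))
    w = p * (1.0 - p)
    return (X * w[:, None]).T @ X + np.eye(X.shape[1])

# analytic_hessian(Xval, np.zeros(k)) should reproduce true_hess above.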