Commit 42482fd

ricardoV94 authored and twiecki committed
Refactor sampling_jax
sampling_jax is no longer wrapped in a symbolic Op, making sampling more performant. SharedVariables are now replaced by constants, and a ValueError is raised if any of them contains a default_update.
1 parent 1d2f592 commit 42482fd
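
The gist of the change, as a minimal sketch (the shared variable below is illustrative, not taken from the commit):

    import aesara

    from pymc.sampling_jax import replace_shared_variables

    # A shared variable feeding a graph is frozen into a constant holding its
    # current value, so the graph can be compiled to JAX without shared inputs.
    x = aesara.shared(5.0, name="x")
    (y,) = replace_shared_variables([x * 2.0])

    # Shared variables carrying a default_update (e.g. RNG state) cannot be
    # frozen safely, so the new code raises instead of silently dropping the update:
    x.default_update = x + 1
    replace_shared_variables([x])  # raises ValueError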

File tree: 2 files changed, +109 -169 lines


pymc/sampling_jax.py

Lines changed: 90 additions & 167 deletions
@@ -15,144 +15,38 @@
 import pandas as pd

 from aesara.compile import SharedVariable
-from aesara.graph.basic import Apply, Constant, clone, graph_inputs
+from aesara.graph.basic import clone_replace, graph_inputs
 from aesara.graph.fg import FunctionGraph
-from aesara.graph.op import Op
 from aesara.graph.opt import MergeOptimizer
 from aesara.link.jax.dispatch import jax_funcify
-from aesara.tensor.type import TensorType

 from pymc import modelcontext
 from pymc.aesaraf import compile_rv_inplace

 warnings.warn("This module is experimental.")


-class NumPyroNUTS(Op):
-    def __init__(
-        self,
-        inputs,
-        outputs,
-        target_accept=0.8,
-        draws=1000,
-        tune=1000,
-        chains=4,
-        seed=None,
-        progress_bar=True,
-    ):
-        self.draws = draws
-        self.tune = tune
-        self.chains = chains
-        self.target_accept = target_accept
-        self.progress_bar = progress_bar
-        self.seed = seed
+def replace_shared_variables(graph):
+    """Replace shared variables in graph by their constant values

-        self.inputs, self.outputs = clone(inputs, outputs, copy_inputs=False)
-        self.inputs_type = tuple(input.type for input in inputs)
-        self.outputs_type = tuple(output.type for output in outputs)
-        self.nin = len(inputs)
-        self.nout = len(outputs)
-        self.nshared = len([v for v in inputs if isinstance(v, SharedVariable)])
-        self.samples_bcast = [self.chains == 1, self.draws == 1]
+    Raises
+    ------
+    ValueError
+        If any shared variable contains default_updates
+    """

-        self.fgraph = FunctionGraph(self.inputs, self.outputs, clone=False)
-        MergeOptimizer().optimize(self.fgraph)
+    shared_variables = [var for var in graph_inputs(graph) if isinstance(var, SharedVariable)]

-        super().__init__()
-
-    def make_node(self, *inputs):
-
-        # The samples for each variable
-        outputs = [
-            TensorType(v.dtype, self.samples_bcast + list(v.broadcastable))() for v in inputs
-        ]
-
-        # The leapfrog statistics
-        outputs += [TensorType("int64", self.samples_bcast)()]
-
-        all_inputs = list(inputs)
-        if self.nshared > 0:
-            all_inputs += self.inputs[-self.nshared :]
-
-        return Apply(self, all_inputs, outputs)
-
-    def do_constant_folding(self, *args):
-        return False
-
-    def perform(self, node, inputs, outputs):
-        raise NotImplementedError()
-
-
-@jax_funcify.register(NumPyroNUTS)
-def jax_funcify_NumPyroNUTS(op, node, **kwargs):
-    from numpyro.infer import MCMC, NUTS
-
-    draws = op.draws
-    tune = op.tune
-    chains = op.chains
-    target_accept = op.target_accept
-    progress_bar = op.progress_bar
-    seed = op.seed
-
-    # Compile the "inner" log-likelihood function. This will have extra shared
-    # variable inputs as the last arguments
-    logp_fn = jax_funcify(op.fgraph, **kwargs)
-
-    if isinstance(logp_fn, (list, tuple)):
-        # This handles the new JAX backend, which always returns a tuple
-        logp_fn = logp_fn[0]
-
-    def _sample(*inputs):
-
-        if op.nshared > 0:
-            current_state = inputs[: -op.nshared]
-            shared_inputs = tuple(op.fgraph.inputs[-op.nshared :])
-        else:
-            current_state = inputs
-            shared_inputs = ()
-
-        def log_fn_wrap(x):
-            res = logp_fn(
-                *(
-                    x
-                    # We manually obtain the shared values and added them
-                    # as arguments to our compiled "inner" function
-                    + tuple(
-                        v.get_value(borrow=True, return_internal_type=True) for v in shared_inputs
-                    )
-                )
-            )
-
-            if isinstance(res, (list, tuple)):
-                # This handles the new JAX backend, which always returns a tuple
-                res = res[0]
-
-            return -res
-
-        nuts_kernel = NUTS(
-            potential_fn=log_fn_wrap,
-            target_accept_prob=target_accept,
-            adapt_step_size=True,
-            adapt_mass_matrix=True,
-            dense_mass=False,
-        )
-
-        pmap_numpyro = MCMC(
-            nuts_kernel,
-            num_warmup=tune,
-            num_samples=draws,
-            num_chains=chains,
-            postprocess_fn=None,
-            chain_method="parallel",
-            progress_bar=progress_bar,
+    if any(hasattr(var, "default_update") for var in shared_variables):
+        raise ValueError(
+            "Graph contains shared variables with default_update which cannot "
+            "be safely replaced."
         )

-        pmap_numpyro.run(seed, init_params=current_state, extra_fields=("num_steps",))
-        samples = pmap_numpyro.get_samples(group_by_chain=True)
-        leapfrogs_taken = pmap_numpyro.get_extra_fields(group_by_chain=True)["num_steps"]
-        return tuple(samples) + (leapfrogs_taken,)
+    replacements = {var: at.constant(var.get_value(borrow=True)) for var in shared_variables}

-    return _sample
+    new_graph = clone_replace(graph, replace=replacements)
+    return new_graph


 def sample_numpyro_nuts(
@@ -165,72 +59,101 @@ def sample_numpyro_nuts(
     progress_bar=True,
     keep_untransformed=False,
 ):
+    from numpyro.infer import MCMC, NUTS
+
     model = modelcontext(model)

-    seed = jax.random.PRNGKey(random_seed)
+    tic1 = pd.Timestamp.now()
+    print("Compiling...", file=sys.stdout)

     rv_names = [rv.name for rv in model.value_vars]
     init_state = [model.initial_point[rv_name] for rv_name in rv_names]
     init_state_batched = jax.tree_map(lambda x: np.repeat(x[None, ...], chains, axis=0), init_state)
-    init_state_batched_at = [at.as_tensor(v) for v in init_state_batched]

-    nuts_inputs = sorted(
-        (v for v in graph_inputs([model.logpt]) if not isinstance(v, Constant)),
-        key=lambda x: isinstance(x, SharedVariable),
-    )
-    map_seed = jax.random.split(seed, chains)
-    numpyro_samples = NumPyroNUTS(
-        nuts_inputs,
-        [model.logpt],
-        target_accept=target_accept,
-        draws=draws,
-        tune=tune,
-        chains=chains,
-        seed=map_seed,
-        progress_bar=progress_bar,
-    )(*init_state_batched_at)
+    logpt = replace_shared_variables([model.logpt])[0]
+    logpt_fgraph = FunctionGraph(outputs=[logpt], clone=False)
+    MergeOptimizer().optimize(logpt_fgraph)
+    logp_fn = jax_funcify(logpt_fgraph)

-    # Un-transform the transformed variables in JAX
-    sample_outputs = []
-    for i, (value_var, rv_samples) in enumerate(zip(model.value_vars, numpyro_samples[:-1])):
-        rv = model.values_to_rvs[value_var]
-        transform = getattr(value_var.tag, "transform", None)
-        if transform is not None:
-            untrans_value_var = transform.backward(rv, rv_samples)
-            untrans_value_var.name = rv.name
-            sample_outputs.append(untrans_value_var)
+    if isinstance(logp_fn, (list, tuple)):
+        # This handles the new JAX backend, which always returns a tuple
+        logp_fn = logp_fn[0]

-            if keep_untransformed:
-                rv_samples.name = value_var.name
-                sample_outputs.append(rv_samples)
-        else:
-            rv_samples.name = rv.name
-            sample_outputs.append(rv_samples)
+    def logp_fn_wrap(x):
+        res = logp_fn(*x)

-    print("Compiling...", file=sys.stdout)
+        if isinstance(res, (list, tuple)):
+            # This handles the new JAX backend, which always returns a tuple
+            res = res[0]

-    tic1 = pd.Timestamp.now()
-    _sample = compile_rv_inplace(
-        [],
-        sample_outputs + [numpyro_samples[-1]],
-        allow_input_downcast=True,
-        on_unused_input="ignore",
-        accept_inplace=True,
-        mode="JAX",
+        # Jax expects a potential with the opposite sign of model.logpt
+        return -res
+
+    nuts_kernel = NUTS(
+        potential_fn=logp_fn_wrap,
+        target_accept_prob=target_accept,
+        adapt_step_size=True,
+        adapt_mass_matrix=True,
+        dense_mass=False,
+    )
+
+    pmap_numpyro = MCMC(
+        nuts_kernel,
+        num_warmup=tune,
+        num_samples=draws,
+        num_chains=chains,
+        postprocess_fn=None,
+        chain_method="parallel",
+        progress_bar=progress_bar,
     )
-    tic2 = pd.Timestamp.now()

+    tic2 = pd.Timestamp.now()
     print("Compilation time = ", tic2 - tic1, file=sys.stdout)

     print("Sampling...", file=sys.stdout)

-    *mcmc_samples, leapfrogs_taken = _sample()
-    tic3 = pd.Timestamp.now()
+    seed = jax.random.PRNGKey(random_seed)
+    map_seed = jax.random.split(seed, chains)

+    pmap_numpyro.run(map_seed, init_params=init_state_batched, extra_fields=("num_steps",))
+    raw_mcmc_samples = pmap_numpyro.get_samples(group_by_chain=True)
+
+    tic3 = pd.Timestamp.now()
     print("Sampling time = ", tic3 - tic2, file=sys.stdout)

-    posterior = {k.name: v for k, v in zip(sample_outputs, mcmc_samples)}
+    print("Transforming variables...", file=sys.stdout)
+    mcmc_samples = []
+    for i, (value_var, raw_samples) in enumerate(zip(model.value_vars, raw_mcmc_samples)):
+        raw_samples = at.constant(np.asarray(raw_samples))
+
+        rv = model.values_to_rvs[value_var]
+        transform = getattr(value_var.tag, "transform", None)
+
+        if transform is not None:
+            # TODO: This will fail when the transformation depends on another variable
+            # such as in interval transform with RVs as edges
+            trans_samples = transform.backward(rv, raw_samples)
+            trans_samples.name = rv.name
+            mcmc_samples.append(trans_samples)
+
+            if keep_untransformed:
+                raw_samples.name = value_var.name
+                mcmc_samples.append(raw_samples)
+        else:
+            raw_samples.name = rv.name
+            mcmc_samples.append(raw_samples)
+
+    mcmc_varnames = [var.name for var in mcmc_samples]
+    mcmc_samples = compile_rv_inplace(
+        [],
+        mcmc_samples,
+        mode="JAX",
+    )()
+
+    tic4 = pd.Timestamp.now()
+    print("Transformation time = ", tic4 - tic3, file=sys.stdout)

+    posterior = {k: v for k, v in zip(mcmc_varnames, mcmc_samples)}
     az_trace = az.from_dict(posterior=posterior)

     return az_trace
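
For orientation, a hypothetical end-to-end call of the refactored sampler; the model and data below are illustrative (loosely mirroring the test file), not part of the commit:

    import numpy as np
    import pymc as pm

    from pymc.sampling_jax import sample_numpyro_nuts

    # Synthetic data; true values chosen so the posterior is easy to eyeball.
    data = np.random.normal(-10.0, 2.0, size=1000)

    with pm.Model():
        a = pm.Normal("a", 0.0, 10.0)
        sigma = pm.HalfNormal("sigma", 10.0)
        pm.Normal("obs", a, sigma, observed=data)

        # model.logpt is now jax_funcify'd directly and handed to NumPyro's NUTS;
        # no symbolic NumPyroNUTS Op sits in between anymore.
        trace = sample_numpyro_nuts(chains=2, random_seed=1322)

    print(trace.posterior["a"].mean())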

pymc/tests/test_sampling_jax.py

Lines changed: 19 additions & 2 deletions
@@ -1,9 +1,13 @@
 import aesara
 import numpy as np
+import pytest
+
+from aesara.compile import SharedVariable
+from aesara.graph import graph_inputs

 import pymc as pm

-from pymc.sampling_jax import sample_numpyro_nuts
+from pymc.sampling_jax import replace_shared_variables, sample_numpyro_nuts


 def test_transform_samples():
@@ -29,7 +33,20 @@ def test_transform_samples():

     obs_at.set_value(-obs)
     with model:
-        trace = sample_numpyro_nuts(chains=1, random_seed=1322, keep_untransformed=False)
+        trace = sample_numpyro_nuts(chains=2, random_seed=1322, keep_untransformed=False)

     assert -11 < trace.posterior["a"].mean() < -8
     assert 1.5 < trace.posterior["sigma"].mean() < 2.5
+
+
+def test_replace_shared_variables():
+
+    x = aesara.shared(5, name="shared_x")
+
+    new_x = replace_shared_variables([x])
+    shared_variables = [var for var in graph_inputs(new_x) if isinstance(var, SharedVariable)]
+    assert not shared_variables
+
+    x.default_update = x + 1
+    with pytest.raises(ValueError, match="shared variables with default_update"):
+        replace_shared_variables([x])
