Skip to content

Commit 308bd05

Browse files
committed
run multiple chains
1 parent 1da02c4 commit 308bd05

File tree

3 files changed

+150
-37
lines changed

3 files changed

+150
-37
lines changed

pymc3/smc/sample_smc.py

Lines changed: 114 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,19 @@
1414

1515
import time
1616
import logging
17+
import warnings
18+
from collections.abc import Iterable
19+
import multiprocessing as mp
20+
import numpy as np
21+
1722
from .smc import SMC
23+
from ..model import modelcontext
24+
from ..backends.base import MultiTrace
25+
from ..parallel_sampling import _cpu_count
26+
27+
EXPERIMENTAL_WARNING = (
28+
"Warning: SMC-ABC is an experimental step method and not yet recommended for use in PyMC3!"
29+
)
1830

1931

2032
def sample_smc(
@@ -30,6 +42,9 @@ def sample_smc(
3042
sum_stat="identity",
3143
model=None,
3244
random_seed=-1,
45+
parallel=True,
46+
chains=None,
47+
cores=None,
3348
):
3449
r"""
3550
Sequential Monte Carlo based sampling
@@ -69,6 +84,16 @@ def sample_smc(
6984
model: Model (optional if in ``with`` context)).
7085
random_seed: int
7186
random seed
87+
parallel: bool
88+
Distribute computations across cores if the number of cores is larger than 1.
89+
Defaults to True.
90+
cores : int
91+
The number of chains to run in parallel. If ``None``, set to the number of CPUs in the
92+
system, but at most 4.
93+
chains : int
94+
The number of chains to sample. Running independent chains is important for some
95+
convergence statistics. If ``None`` (default), then set to either ``cores`` or 2, whichever
96+
is larger.
7297
7398
Notes
7499
-----
@@ -115,6 +140,89 @@ def sample_smc(
115140
%282007%29133:7%28816%29>`__
116141
"""
117142

143+
_log = logging.getLogger("pymc3")
144+
_log.info("Initializing SMC sampler...")
145+
146+
if cores is None:
147+
cores = _cpu_count()
148+
149+
if chains is None:
150+
chains = max(2, cores)
151+
152+
_log.info(f"Multiprocess sampling ({chains} chains in {cores} jobs)")
153+
154+
if random_seed == -1:
155+
random_seed = None
156+
if chains == 1 and isinstance(random_seed, int):
157+
random_seed = [random_seed]
158+
if random_seed is None or isinstance(random_seed, int):
159+
if random_seed is not None:
160+
np.random.seed(random_seed)
161+
random_seed = [np.random.randint(2 ** 30) for _ in range(chains)]
162+
if not isinstance(random_seed, Iterable):
163+
raise TypeError("Invalid value for `random_seed`. Must be tuple, list or int")
164+
165+
if kernel.lower() == "abc":
166+
warnings.warn(EXPERIMENTAL_WARNING)
167+
if len(modelcontext(model).observed_RVs) != 1:
168+
warnings.warn("SMC-ABC only works properly with models with one observed variable")
169+
170+
params = (
171+
draws,
172+
kernel,
173+
n_steps,
174+
start,
175+
tune_steps,
176+
p_acc_rate,
177+
threshold,
178+
epsilon,
179+
dist_func,
180+
sum_stat,
181+
model,
182+
)
183+
184+
t1 = time.time()
185+
if parallel:
186+
loggers = [_log] + [None] * (chains - 1)
187+
pool = mp.Pool(cores)
188+
results = pool.starmap(
189+
sample_smc_int, [(*params, random_seed[i], i, loggers[i]) for i in range(chains)]
190+
)
191+
192+
pool.close()
193+
pool.join()
194+
else:
195+
results = []
196+
for i in range(chains):
197+
results.append((sample_smc_int(*params, random_seed[i], i, _log)))
198+
199+
traces, log_marginal_likelihoods = zip(*results)
200+
trace = MultiTrace(traces)
201+
trace.report._n_draws = draws
202+
trace.report._n_tune = 0
203+
trace.report._t_sampling = time.time() - t1
204+
trace.report.log_marginal_likelihood = np.array(log_marginal_likelihoods)
205+
206+
return trace
207+
208+
209+
def sample_smc_int(
210+
draws=2000,
211+
kernel="metropolis",
212+
n_steps=25,
213+
start=None,
214+
tune_steps=True,
215+
p_acc_rate=0.99,
216+
threshold=0.5,
217+
epsilon=1.0,
218+
dist_func="gaussian_kernel",
219+
sum_stat="identity",
220+
model=None,
221+
random_seed=-1,
222+
chain=0,
223+
_log=None,
224+
):
225+
118226
smc = SMC(
119227
draws=draws,
120228
kernel=kernel,
@@ -128,33 +236,21 @@ def sample_smc(
128236
sum_stat=sum_stat,
129237
model=model,
130238
random_seed=random_seed,
239+
chain=chain,
131240
)
132-
133-
t1 = time.time()
134-
_log = logging.getLogger("pymc3")
135-
_log.info("Sample initial stage: ...")
136241
stage = 0
137242
smc.initialize_population()
138243
smc.setup_kernel()
139244
smc.initialize_logp()
140245

141246
while smc.beta < 1:
142247
smc.update_weights_beta()
143-
_log.info(
144-
"Stage: {:3d} Beta: {:.3f} Steps: {:3d} Acce: {:.3f}".format(
145-
stage, smc.beta, smc.n_steps, smc.acc_rate
146-
)
147-
)
248+
if _log is not None:
249+
_log.info(f"Stage: {stage:3d} Beta: {smc.beta:.3f}")
148250
smc.update_proposal()
149251
smc.resample()
150-
for _ in range(2):
151-
smc.mutate()
152-
smc.tune()
252+
smc.mutate()
253+
smc.tune()
153254
stage += 1
154255

155-
trace = smc.posterior_to_trace()
156-
trace.report._n_draws = smc.draws
157-
trace.report._n_tune = 0
158-
trace.report._t_sampling = time.time() - t1
159-
trace.report.ess = smc.ess
160-
return trace
256+
return smc.posterior_to_trace()

pymc3/smc/smc.py

Lines changed: 31 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616

1717
import numpy as np
1818
from scipy.special import logsumexp
19-
import warnings
2019
from theano import function as theano_function
2120
from arviz import psislw
2221

@@ -25,12 +24,6 @@
2524
from ..theanof import floatX, inputvars, make_shared_replacements, join_nonshared_inputs
2625
from ..sampling import sample_prior_predictive
2726
from ..backends.ndarray import NDArray
28-
from ..backends.base import MultiTrace
29-
30-
EXPERIMENTAL_WARNING = (
31-
"Warning: SMC-ABC methods are experimental step methods and not yet"
32-
" recommended for use in PyMC3!"
33-
)
3427

3528

3629
class SMC:
@@ -48,6 +41,7 @@ def __init__(
4841
sum_stat="Identity",
4942
model=None,
5043
random_seed=-1,
44+
chain=0,
5145
):
5246

5347
self.draws = draws
@@ -62,6 +56,7 @@ def __init__(
6256
self.sum_stat = sum_stat
6357
self.model = model
6458
self.random_seed = random_seed
59+
self.chain = chain
6560

6661
self.model = modelcontext(model)
6762

@@ -73,11 +68,11 @@ def __init__(
7368
self.proposed = draws * n_steps
7469
self.acc_rate = 1
7570
self.acc_per_chain = np.ones(self.draws)
76-
self.model.log_marginal_likelihood = 0
7771
self.variables = inputvars(self.model.vars)
7872
self.dimension = sum(v.dsize for v in self.variables)
7973
self.scalings = np.ones(self.draws) * 2.38 / (self.dimension) ** 0.5
8074
self.weights = np.ones(self.draws) / self.draws
75+
self.log_marginal_likelihood = 0
8176

8277
def initialize_population(self):
8378
"""
@@ -113,9 +108,6 @@ def setup_kernel(self):
113108
self.prior_logp_func = logp_forw([self.model.varlogpt], self.variables, shared)
114109

115110
if self.kernel.lower() == "abc":
116-
warnings.warn(EXPERIMENTAL_WARNING)
117-
if len(self.model.observed_RVs) != 1:
118-
warnings.warn("SMC-ABC only works properly with models with one observed variable")
119111
simulator = self.model.observed_RVs[0]
120112
self.likelihood_logp_func = PseudoLikelihood(
121113
self.epsilon,
@@ -165,9 +157,8 @@ def update_weights_beta(self):
165157
new_beta = 1
166158
log_weights_un = (new_beta - old_beta) * self.likelihood_logp
167159
log_weights = log_weights_un - logsumexp(log_weights_un)
168-
self.ess = np.exp(-logsumexp(log_weights * 2))
169160

170-
self.model.log_marginal_likelihood += logsumexp(log_weights_un) - np.log(self.draws)
161+
self.log_marginal_likelihood += logsumexp(log_weights_un) - np.log(self.draws)
171162
self.beta = new_beta
172163
self.weights = np.exp(log_weights)
173164

@@ -178,6 +169,7 @@ def resample(self):
178169
resampling_indexes = np.random.choice(
179170
np.arange(self.draws), size=self.draws, p=self.weights
180171
)
172+
181173
self.posterior = self.posterior[resampling_indexes]
182174
self.prior_logp = self.prior_logp[resampling_indexes]
183175
self.likelihood_logp = self.likelihood_logp[resampling_indexes]
@@ -239,6 +231,29 @@ def mutate(self):
239231
self.acc_per_chain = np.mean(ac_, axis=0)
240232
self.acc_rate = np.mean(ac_)
241233

234+
def posterior_to_trace_bk(self):
235+
"""
236+
Save results into a PyMC3 trace
237+
"""
238+
lenght_pos = len(self.posterior)
239+
varnames = [v.name for v in self.variables]
240+
straces = []
241+
with self.model:
242+
chain_lenght = int(lenght_pos / 10)
243+
for chain in range(10):
244+
strace = NDArray(self.model)
245+
strace.setup(chain_lenght, chain)
246+
for i in range(chain_lenght):
247+
value = []
248+
size = 0
249+
for var in varnames:
250+
shape, new_size = self.var_info[var]
251+
value.append(self.posterior[i][size : size + new_size].reshape(shape))
252+
size += new_size
253+
strace.record({k: v for k, v in zip(varnames, value)})
254+
straces.append(strace)
255+
return MultiTrace(straces)
256+
242257
def posterior_to_trace(self):
243258
"""
244259
Save results into a PyMC3 trace
@@ -248,16 +263,16 @@ def posterior_to_trace(self):
248263

249264
with self.model:
250265
strace = NDArray(self.model)
251-
strace.setup(lenght_pos, 0)
266+
strace.setup(lenght_pos, self.chain)
252267
for i in range(lenght_pos):
253268
value = []
254269
size = 0
255270
for var in varnames:
256271
shape, new_size = self.var_info[var]
257272
value.append(self.posterior[i][size : size + new_size].reshape(shape))
258273
size += new_size
259-
strace.record({k: v for k, v in zip(varnames, value)})
260-
return MultiTrace([strace])
274+
strace.record(point={k: v for k, v in zip(varnames, value)})
275+
return strace, self.log_marginal_likelihood
261276

262277

263278
def logp_forw(out_vars, vars, shared):

pymc3/tests/test_smc.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,9 +79,9 @@ def test_ml(self):
7979
a = pm.Beta("a", alpha, beta)
8080
y = pm.Bernoulli("y", a, observed=data)
8181
trace = pm.sample_smc(2000)
82-
marginals.append(model.marginal_log_likelihood)
82+
marginals.append(trace.report.log_marginal_likelihood)
8383
# compare to the analytical result
84-
assert abs(np.exp(marginals[1] - marginals[0]) - 4.0) <= 1
84+
assert abs(np.exp(np.mean(marginals[1]) - np.mean(marginals[0])) - 4.0) <= 1
8585

8686
def test_start(self):
8787
with pm.Model() as model:
@@ -110,7 +110,9 @@ def normal_sim(a, b):
110110

111111
def test_one_gaussian(self):
112112
with self.SMABC_test:
113-
trace = pm.sample_smc(draws=1000, kernel="ABC", sum_stat="sorted", epsilon=1)
113+
trace = pm.sample_smc(
114+
draws=1000, kernel="ABC", sum_stat="sorted", epsilon=1, parallel=False
115+
)
114116

115117
np.testing.assert_almost_equal(self.data.mean(), trace["a"].mean(), decimal=2)
116118
np.testing.assert_almost_equal(self.data.std(), trace["b"].mean(), decimal=1)

0 commit comments

Comments
 (0)