Skip to content

Commit 1da02c4

Browse files
committed
add ess, remove multiprocessing
1 parent f19079c commit 1da02c4

File tree

2 files changed

+17
-64
lines changed

2 files changed

+17
-64
lines changed

pymc3/smc/sample_smc.py

Lines changed: 5 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -18,19 +18,16 @@
1818

1919

2020
def sample_smc(
21-
draws=1000,
21+
draws=2000,
2222
kernel="metropolis",
2323
n_steps=25,
24-
parallel=False,
2524
start=None,
26-
cores=None,
2725
tune_steps=True,
2826
p_acc_rate=0.99,
2927
threshold=0.5,
3028
epsilon=1.0,
3129
dist_func="gaussian_kernel",
3230
sum_stat="identity",
33-
progressbar=False,
3431
model=None,
3532
random_seed=-1,
3633
):
@@ -49,15 +46,9 @@ def sample_smc(
4946
The number of steps of each Markov Chain. If ``tune_steps == True`` ``n_steps`` will be used
5047
for the first stage and for the others it will be determined automatically based on the
5148
acceptance rate and `p_acc_rate`, the max number of steps is ``n_steps``.
52-
parallel: bool
53-
Distribute computations across cores if the number of cores is larger than 1.
54-
Defaults to False.
5549
start: dict, or array of dict
5650
Starting point in parameter space. It should be a list of dict with length `chains`.
5751
When None (default) the starting point is sampled from the prior distribution.
58-
cores: int
59-
The number of chains to run in parallel. If ``None`` (default), it will be automatically
60-
set to the number of CPUs in the system.
6152
tune_steps: bool
6253
Whether to compute the number of steps automatically or not. Defaults to True
6354
p_acc_rate: float
@@ -75,8 +66,6 @@ def sample_smc(
7566
sum_stat: str or callable
7667
Summary statistics. Available options are ``indentity``, ``sorted``, ``mean``, ``median``.
7768
If a callable is based it should return a number or a 1d numpy array.
78-
progressbar: bool
79-
Flag for displaying a progress bar. Defaults to False.
8069
model: Model (optional if in ``with`` context)).
8170
random_seed: int
8271
random seed
@@ -130,16 +119,13 @@ def sample_smc(
130119
draws=draws,
131120
kernel=kernel,
132121
n_steps=n_steps,
133-
parallel=parallel,
134122
start=start,
135-
cores=cores,
136123
tune_steps=tune_steps,
137124
p_acc_rate=p_acc_rate,
138125
threshold=threshold,
139126
epsilon=epsilon,
140127
dist_func=dist_func,
141128
sum_stat=sum_stat,
142-
progressbar=progressbar,
143129
model=model,
144130
random_seed=random_seed,
145131
)
@@ -159,19 +145,16 @@ def sample_smc(
159145
stage, smc.beta, smc.n_steps, smc.acc_rate
160146
)
161147
)
162-
smc.resample()
163148
smc.update_proposal()
164-
if stage > 0:
149+
smc.resample()
150+
for _ in range(2):
151+
smc.mutate()
165152
smc.tune()
166-
smc.mutate()
167153
stage += 1
168154

169-
if smc.parallel and smc.cores > 1:
170-
smc.pool.close()
171-
smc.pool.join()
172-
173155
trace = smc.posterior_to_trace()
174156
trace.report._n_draws = smc.draws
175157
trace.report._n_tune = 0
176158
trace.report._t_sampling = time.time() - t1
159+
trace.report.ess = smc.ess
177160
return trace

pymc3/smc/smc.py

Lines changed: 12 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -16,19 +16,14 @@
1616

1717
import numpy as np
1818
from scipy.special import logsumexp
19-
from fastprogress.fastprogress import progress_bar
20-
import multiprocessing as mp
2119
import warnings
2220
from theano import function as theano_function
21+
from arviz import psislw
2322

2423
from ..model import modelcontext, Point
2524
from ..parallel_sampling import _cpu_count
26-
from ..theanof import inputvars, make_shared_replacements
27-
from ..vartypes import discrete_types
25+
from ..theanof import floatX, inputvars, make_shared_replacements, join_nonshared_inputs
2826
from ..sampling import sample_prior_predictive
29-
from ..theanof import floatX, join_nonshared_inputs
30-
from ..step_methods.arraystep import metrop_select
31-
from ..step_methods.metropolis import MultivariateNormalProposal
3227
from ..backends.ndarray import NDArray
3328
from ..backends.base import MultiTrace
3429

@@ -41,36 +36,30 @@
4136
class SMC:
4237
def __init__(
4338
self,
44-
draws=1000,
39+
draws=2000,
4540
kernel="metropolis",
4641
n_steps=25,
47-
parallel=False,
4842
start=None,
49-
cores=None,
5043
tune_steps=True,
5144
p_acc_rate=0.99,
5245
threshold=0.5,
5346
epsilon=1.0,
5447
dist_func="absolute_error",
5548
sum_stat="Identity",
56-
progressbar=False,
5749
model=None,
5850
random_seed=-1,
5951
):
6052

6153
self.draws = draws
6254
self.kernel = kernel
6355
self.n_steps = n_steps
64-
self.parallel = parallel
6556
self.start = start
66-
self.cores = cores
6757
self.tune_steps = tune_steps
6858
self.p_acc_rate = p_acc_rate
6959
self.threshold = threshold
7060
self.epsilon = epsilon
7161
self.dist_func = dist_func
7262
self.sum_stat = sum_stat
73-
self.progressbar = progressbar
7463
self.model = model
7564
self.random_seed = random_seed
7665

@@ -79,23 +68,16 @@ def __init__(
7968
if self.random_seed != -1:
8069
np.random.seed(self.random_seed)
8170

82-
if self.cores is None:
83-
self.cores = _cpu_count()
84-
8571
self.beta = 0
8672
self.max_steps = n_steps
8773
self.proposed = draws * n_steps
8874
self.acc_rate = 1
8975
self.acc_per_chain = np.ones(self.draws)
90-
self.model.marginal_log_likelihood = 0
76+
self.model.log_marginal_likelihood = 0
9177
self.variables = inputvars(self.model.vars)
9278
self.dimension = sum(v.dsize for v in self.variables)
93-
self.scalings = np.ones(self.draws) * min(1, 2.38 ** 2 / self.dimension)
94-
self.discrete = np.concatenate(
95-
[[v.dtype in discrete_types] * (v.dsize or 1) for v in self.variables]
96-
)
97-
self.any_discrete = self.discrete.any()
98-
self.all_discrete = self.discrete.all()
79+
self.scalings = np.ones(self.draws) * 2.38 / (self.dimension) ** 0.5
80+
self.weights = np.ones(self.draws) / self.draws
9981

10082
def initialize_population(self):
10183
"""
@@ -153,17 +135,8 @@ def initialize_logp(self):
153135
"""
154136
initialize the prior and likelihood log probabilities
155137
"""
156-
if self.parallel and self.cores > 1:
157-
self.pool = mp.Pool(processes=self.cores)
158-
priors = self.pool.starmap(
159-
self.prior_logp_func, [(sample,) for sample in self.posterior]
160-
)
161-
likelihoods = self.pool.starmap(
162-
self.likelihood_logp_func, [(sample,) for sample in self.posterior]
163-
)
164-
else:
165-
priors = [self.prior_logp_func(sample) for sample in self.posterior]
166-
likelihoods = [self.likelihood_logp_func(sample) for sample in self.posterior]
138+
priors = [self.prior_logp_func(sample) for sample in self.posterior]
139+
likelihoods = [self.likelihood_logp_func(sample) for sample in self.posterior]
167140

168141
self.prior_logp = np.array(priors).squeeze()
169142
self.likelihood_logp = np.array(likelihoods).squeeze()
@@ -192,11 +165,9 @@ def update_weights_beta(self):
192165
new_beta = 1
193166
log_weights_un = (new_beta - old_beta) * self.likelihood_logp
194167
log_weights = log_weights_un - logsumexp(log_weights_un)
168+
self.ess = np.exp(-logsumexp(log_weights * 2))
195169

196-
ll_max = np.max(log_weights_un)
197-
self.model.marginal_log_likelihood += ll_max + np.log(
198-
np.exp(log_weights_un - ll_max).mean()
199-
)
170+
self.model.log_marginal_likelihood += logsumexp(log_weights_un) - np.log(self.draws)
200171
self.beta = new_beta
201172
self.weights = np.exp(log_weights)
202173

@@ -218,13 +189,12 @@ def update_proposal(self):
218189
"""
219190
Update proposal based on the covariance matrix from tempered posterior
220191
"""
221-
cov = np.cov(self.posterior, bias=False, rowvar=0)
192+
cov = np.cov(self.posterior, ddof=0, aweights=self.weights, rowvar=0)
222193
cov = np.atleast_2d(cov)
223194
cov += 1e-6 * np.eye(cov.shape[0])
224195
if np.isnan(cov).any() or np.isinf(cov).any():
225196
raise ValueError('Sample covariances not valid! Likely "draws" is too small!')
226197
self.cov = cov
227-
self.proposal = MultivariateNormalProposal(cov)
228198

229199
def tune(self):
230200
"""
@@ -244,8 +214,8 @@ def tune(self):
244214
self.proposed = self.draws * self.n_steps
245215

246216
def mutate(self):
247-
248217
ac_ = np.empty((self.n_steps, self.draws))
218+
249219
proposals = (
250220
np.random.multivariate_normal(
251221
np.zeros(self.dimension), self.cov, size=(self.n_steps, self.draws)

0 commit comments

Comments
 (0)