Commit f19079c

first attempt to vectorize smc kernel
1 parent d0de763


pymc3/smc/smc.py

Lines changed: 43 additions & 118 deletions
@@ -89,8 +89,8 @@ def __init__(
         self.acc_per_chain = np.ones(self.draws)
         self.model.marginal_log_likelihood = 0
         self.variables = inputvars(self.model.vars)
-        dimension = sum(v.dsize for v in self.variables)
-        self.scalings = np.ones(self.draws) * min(1, 2.38 ** 2 / dimension)
+        self.dimension = sum(v.dsize for v in self.variables)
+        self.scalings = np.ones(self.draws) * min(1, 2.38 ** 2 / self.dimension)
         self.discrete = np.concatenate(
             [[v.dtype in discrete_types] * (v.dsize or 1) for v in self.variables]
         )
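
The min(1, 2.38 ** 2 / dimension) factor is the classic optimal random-walk
Metropolis scaling of Roberts, Gelman and Gilks (2.38 / sqrt(d)), capped at 1
so low-dimensional models are not over-scaled. Promoting dimension to
self.dimension lets the new vectorized mutate() reuse it when drawing
proposals. A standalone illustration (not part of the diff):

    # Initial per-chain scaling as a function of model dimension.
    for dimension in (1, 2, 5, 10, 50):
        print(dimension, min(1, 2.38 ** 2 / dimension))
    # dimensions 1, 2 and 5 are capped at 1; 10 gives ~0.566, 50 gives ~0.113
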
@@ -128,14 +128,14 @@ def setup_kernel(self):
         Set up the likelihood logp function based on the chosen kernel
         """
         shared = make_shared_replacements(self.variables, self.model)
-        self.prior_logp = logp_forw([self.model.varlogpt], self.variables, shared)
+        self.prior_logp_func = logp_forw([self.model.varlogpt], self.variables, shared)

         if self.kernel.lower() == "abc":
             warnings.warn(EXPERIMENTAL_WARNING)
             if len(self.model.observed_RVs) != 1:
                 warnings.warn("SMC-ABC only works properly with models with one observed variable")
             simulator = self.model.observed_RVs[0]
-            self.likelihood_logp = PseudoLikelihood(
+            self.likelihood_logp_func = PseudoLikelihood(
                 self.epsilon,
                 simulator.observations,
                 simulator.distribution.function,
@@ -147,24 +147,26 @@ def setup_kernel(self):
                 self.sum_stat,
             )
         elif self.kernel.lower() == "metropolis":
-            self.likelihood_logp = logp_forw([self.model.datalogpt], self.variables, shared)
+            self.likelihood_logp_func = logp_forw([self.model.datalogpt], self.variables, shared)

     def initialize_logp(self):
         """
         initialize the prior and likelihood log probabilities
         """
         if self.parallel and self.cores > 1:
             self.pool = mp.Pool(processes=self.cores)
-            priors = self.pool.starmap(self.prior_logp, [(sample,) for sample in self.posterior])
+            priors = self.pool.starmap(
+                self.prior_logp_func, [(sample,) for sample in self.posterior]
+            )
             likelihoods = self.pool.starmap(
-                self.likelihood_logp, [(sample,) for sample in self.posterior]
+                self.likelihood_logp_func, [(sample,) for sample in self.posterior]
             )
         else:
-            priors = [self.prior_logp(sample) for sample in self.posterior]
-            likelihoods = [self.likelihood_logp(sample) for sample in self.posterior]
+            priors = [self.prior_logp_func(sample) for sample in self.posterior]
+            likelihoods = [self.likelihood_logp_func(sample) for sample in self.posterior]

-        self.priors = np.array(priors).squeeze()
-        self.likelihoods = np.array(likelihoods).squeeze()
+        self.prior_logp = np.array(priors).squeeze()
+        self.likelihood_logp = np.array(likelihoods).squeeze()

     def update_weights_beta(self):
         """
@@ -173,11 +175,11 @@ def update_weights_beta(self):
         """
         low_beta = old_beta = self.beta
         up_beta = 2.0
-        rN = int(len(self.likelihoods) * self.threshold)
+        rN = int(len(self.likelihood_logp) * self.threshold)

         while up_beta - low_beta > 1e-6:
             new_beta = (low_beta + up_beta) / 2.0
-            log_weights_un = (new_beta - old_beta) * self.likelihoods
+            log_weights_un = (new_beta - old_beta) * self.likelihood_logp
             log_weights = log_weights_un - logsumexp(log_weights_un)
             ESS = int(np.exp(-logsumexp(log_weights * 2)))
             if ESS == rN:
@@ -188,7 +190,7 @@ def update_weights_beta(self):
                 low_beta = new_beta
         if new_beta >= 1:
             new_beta = 1
-            log_weights_un = (new_beta - old_beta) * self.likelihoods
+            log_weights_un = (new_beta - old_beta) * self.likelihood_logp
             log_weights = log_weights_un - logsumexp(log_weights_un)

         ll_max = np.max(log_weights_un)
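
Both hunks above rely on the same identity: with normalized importance weights
w_i, the effective sample size is ESS = 1 / sum(w_i ** 2), computed stably in
log space as exp(-logsumexp(2 * log_weights)). The bisection then searches for
the beta increment whose reweighting keeps ESS at a fixed fraction (threshold)
of the particle count. A standalone sketch of the quantity being bisected over:

    import numpy as np
    from scipy.special import logsumexp

    def ess_for_beta(likelihood_logp, old_beta, new_beta):
        # Tempering from old_beta to new_beta reweights each particle by
        # likelihood ** (new_beta - old_beta); these are the log-weights.
        log_weights_un = (new_beta - old_beta) * likelihood_logp
        log_weights = log_weights_un - logsumexp(log_weights_un)
        return np.exp(-logsumexp(log_weights * 2))  # 1 / sum(w_i ** 2)

    likelihood_logp = np.random.randn(1000)  # toy log-likelihood values
    print(ess_for_beta(likelihood_logp, old_beta=0.1, new_beta=0.3))
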
@@ -206,9 +208,9 @@ def resample(self):
             np.arange(self.draws), size=self.draws, p=self.weights
         )
         self.posterior = self.posterior[resampling_indexes]
-        self.priors = self.priors[resampling_indexes]
-        self.likelihoods = self.likelihoods[resampling_indexes]
-        self.tempered_logp = self.priors + self.likelihoods * self.beta
+        self.prior_logp = self.prior_logp[resampling_indexes]
+        self.likelihood_logp = self.likelihood_logp[resampling_indexes]
+        self.posterior_logp = self.prior_logp + self.likelihood_logp * self.beta
         self.acc_per_chain = self.acc_per_chain[resampling_indexes]
         self.scalings = self.scalings[resampling_indexes]
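
resample() is plain multinomial resampling; the point of the hunk is that
every per-particle quantity (posterior, both logp arrays, acceptance rates,
scalings) is reindexed with the same indices so the arrays stay aligned. A
minimal sketch of that invariant:

    import numpy as np

    draws = 1000
    weights = np.random.dirichlet(np.ones(draws))  # stand-in for self.weights
    posterior = np.random.randn(draws, 3)
    prior_logp = np.random.randn(draws)

    resampling_indexes = np.random.choice(np.arange(draws), size=draws, p=weights)
    posterior = posterior[resampling_indexes]
    prior_logp = prior_logp[resampling_indexes]  # same indices everywhere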

@@ -221,6 +223,7 @@ def update_proposal(self):
         cov += 1e-6 * np.eye(cov.shape[0])
         if np.isnan(cov).any() or np.isinf(cov).any():
             raise ValueError('Sample covariances not valid! Likely "draws" is too small!')
+        self.cov = cov
         self.proposal = MultivariateNormalProposal(cov)

     def tune(self):
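
Caching self.cov is what enables the vectorized mutate(): it samples all
n_steps * draws proposal increments with a single np.random.multivariate_normal
call, so it needs the raw covariance matrix rather than only the
MultivariateNormalProposal wrapper. A sketch of the jitter-and-cache step,
assuming posterior holds the particle array (the np.cov call sits just above
this hunk):

    import numpy as np

    posterior = np.random.randn(500, 3)
    cov = np.atleast_2d(np.cov(posterior, rowvar=0))
    cov += 1e-6 * np.eye(cov.shape[0])  # jitter keeps cov numerically positive definite
    # cov is now safe to hand to np.random.multivariate_normal.
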
@@ -241,56 +244,30 @@ def tune(self):
         self.proposed = self.draws * self.n_steps

     def mutate(self):
-        """
-        Perform mutation step, i.e. apply selected kernel
-        """
-        parameters = (
-            self.proposal,
-            self.scalings,
-            self.any_discrete,
-            self.all_discrete,
-            self.discrete,
-            self.n_steps,
-            self.prior_logp,
-            self.likelihood_logp,
-            self.beta,
-        )
-        if self.parallel and self.cores > 1:
-            results = self.pool.starmap(
-                metrop_kernel,
-                [
-                    (
-                        self.posterior[draw],
-                        self.tempered_logp[draw],
-                        self.priors[draw],
-                        self.likelihoods[draw],
-                        draw,
-                        *parameters,
-                    )
-                    for draw in range(self.draws)
-                ],
+
+        ac_ = np.empty((self.n_steps, self.draws))
+        proposals = (
+            np.random.multivariate_normal(
+                np.zeros(self.dimension), self.cov, size=(self.n_steps, self.draws)
             )
-        else:
-            iterator = range(self.draws)
-            if self.progressbar:
-                iterator = progress_bar(iterator, display=self.progressbar)
-            results = [
-                metrop_kernel(
-                    self.posterior[draw],
-                    self.tempered_logp[draw],
-                    self.priors[draw],
-                    self.likelihoods[draw],
-                    draw,
-                    *parameters,
-                )
-                for draw in iterator
-            ]
-        posterior, acc_list, priors, likelihoods = zip(*results)
-        self.posterior = np.array(posterior)
-        self.priors = np.array(priors)
-        self.likelihoods = np.array(likelihoods)
-        self.acc_per_chain = np.array(acc_list)
-        self.acc_rate = np.mean(acc_list)
+            * self.scalings[:, None]
+        )
+        log_R = np.log(np.random.rand(self.n_steps, self.draws))
+
+        for n_step in range(self.n_steps):
+            proposal = self.posterior + proposals[n_step]
+            ll = np.array([self.likelihood_logp_func(prop) for prop in proposal])
+            pl = np.array([self.prior_logp_func(prop) for prop in proposal])
+            proposal_logp = pl + ll * self.beta
+            accepted = log_R[n_step] < (proposal_logp - self.posterior_logp)
+            ac_[n_step] = accepted
+            self.posterior[accepted] = proposal[accepted]
+            self.posterior_logp[accepted] = proposal_logp[accepted]
+            self.prior_logp[accepted] = pl[accepted]
+            self.likelihood_logp[accepted] = ll[accepted]
+
+        self.acc_per_chain = np.mean(ac_, axis=0)
+        self.acc_rate = np.mean(ac_)

     def posterior_to_trace(self):
         """
@@ -313,58 +290,6 @@ def posterior_to_trace(self):
         return MultiTrace([strace])


-def metrop_kernel(
-    q_old,
-    old_tempered_logp,
-    old_prior,
-    old_likelihood,
-    draw,
-    proposal,
-    scalings,
-    any_discrete,
-    all_discrete,
-    discrete,
-    n_steps,
-    prior_logp,
-    likelihood_logp,
-    beta,
-):
-    """
-    Metropolis kernel
-    """
-    deltas = np.squeeze(proposal(n_steps) * scalings[draw])
-
-    accepted = 0
-    for n_step in range(n_steps):
-        delta = deltas[n_step]
-
-        if any_discrete:
-            if all_discrete:
-                delta = np.round(delta, 0).astype("int64")
-                q_old = q_old.astype("int64")
-                q_new = (q_old + delta).astype("int64")
-            else:
-                delta[discrete] = np.round(delta[discrete], 0)
-                q_new = floatX(q_old + delta)
-        else:
-            q_new = floatX(q_old + delta)
-
-        ll = likelihood_logp(q_new)
-        pl = prior_logp(q_new)
-
-        new_tempered_logp = pl + ll * beta
-
-        q_old, accept = metrop_select(new_tempered_logp - old_tempered_logp, q_new, q_old)
-
-        if accept:
-            accepted += 1
-            old_prior = pl
-            old_likelihood = ll
-            old_tempered_logp = new_tempered_logp
-
-    return q_old, accepted / n_steps, old_prior, old_likelihood
-
-
 def logp_forw(out_vars, vars, shared):
     """Compile Theano function of the model and the input and output variables.
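
One thing the removed metrop_kernel did that the vectorized mutate() does not
yet replicate is the discrete-variable handling: proposal increments for
discrete dimensions were rounded (and everything cast to int64 when all
variables were discrete) before being added to the current point. For
reference, the rounding behaved like this:

    import numpy as np

    delta = np.array([0.7, -1.2, 0.3])
    discrete = np.array([True, False, True])  # per-dimension flags
    delta[discrete] = np.round(delta[discrete], 0)
    # delta is now [1. , -1.2, 0. ]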
