Skip to content

Commit 4ef4f97

Browse files
canyon289Spaak
authored andcommitted
Add eight schools pickle test
1 parent 3ae43cf commit 4ef4f97

File tree

1 file changed

+21
-0
lines changed

1 file changed

+21
-0
lines changed

pymc3/tests/test_model.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -421,3 +421,24 @@ def test_tempered_logp_dlogp():
421421

422422
npt.assert_allclose(func_nograd(x), func(x)[0])
423423
npt.assert_allclose(func_temp_nograd(x), func_temp(x)[0])
424+
425+
426+
import pickle
427+
def test_model_pickle(tmpdir):
428+
"""Tests that PyMC3 models are pickleable"""
429+
430+
# Data of the Eight Schools Model
431+
J = 8
432+
y = np.array([28., 8., -3., 7., -1., 1., 18., 12.])
433+
sigma = np.array([15., 10., 16., 11., 9., 11., 10., 18.])
434+
435+
with pm.Model() as model:
436+
mu = pm.Normal('mu', mu=0, sigma=5)
437+
tau = pm.HalfCauchy('tau', beta=5)
438+
theta = pm.Normal('theta', mu=mu, sigma=tau, shape=J)
439+
obs = pm.Normal('obs', mu=theta, sigma=sigma, observed=y)
440+
# t = pm.sample(draws=100)
441+
442+
file_path = tmpdir.join("model.p")
443+
with open(file_path, 'wb') as buff:
444+
pickle.dump(model, buff)

0 commit comments

Comments
 (0)