Skip to content

Commit c3c7d84

Browse files
ArmavicaricardoV94
authored andcommitted
Split test_shared into test_model & test_sampling
1 parent 84152c7 commit c3c7d84

File tree

4 files changed

+45
-65
lines changed

4 files changed

+45
-65
lines changed

.github/workflows/tests.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,6 @@ jobs:
6363
- |
6464
pymc/tests/tuning/test_scaling.py
6565
pymc/tests/tuning/test_starting.py
66-
pymc/tests/test_shared.py
6766
pymc/tests/test_sampling.py
6867
pymc/tests/distributions/test_dist_math.py
6968
pymc/tests/distributions/test_transform.py

pymc/tests/test_model.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1373,3 +1373,14 @@ def test_missing_symmetric():
13731373
logp_inputs = list(graph_inputs([logp]))
13741374
assert x_obs_vv in logp_inputs
13751375
assert x_unobs_vv in logp_inputs
1376+
1377+
1378+
class TestShared(SeededTest):
1379+
def test_deterministic(self):
1380+
with pm.Model() as model:
1381+
data_values = np.array([0.5, 0.4, 5, 2])
1382+
X = aesara.shared(np.asarray(data_values, dtype=aesara.config.floatX), borrow=True)
1383+
pm.Normal("y", 0, 1, observed=X)
1384+
assert np.all(
1385+
np.isclose(model.compile_logp(sum=False)({}), st.norm().logpdf(data_values))
1386+
)

pymc/tests/test_sampling.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2609,3 +2609,37 @@ def test_float32(self):
26092609
with warnings.catch_warnings():
26102610
warnings.filterwarnings("ignore", ".*number of samples.*", UserWarning)
26112611
pm.sample(draws=10, tune=10, chains=1, step=sampler())
2612+
2613+
2614+
class TestShared(SeededTest):
2615+
def test_sample(self):
2616+
x = np.random.normal(size=100)
2617+
y = x + np.random.normal(scale=1e-2, size=100)
2618+
2619+
x_pred = np.linspace(-3, 3, 200)
2620+
2621+
x_shared = aesara.shared(x)
2622+
2623+
with pm.Model() as model:
2624+
b = pm.Normal("b", 0.0, 10.0)
2625+
pm.Normal("obs", b * x_shared, np.sqrt(1e-2), observed=y, shape=x_shared.shape)
2626+
prior_trace0 = pm.sample_prior_predictive(1000)
2627+
2628+
idata = pm.sample(1000, tune=1000, chains=1)
2629+
pp_trace0 = pm.sample_posterior_predictive(idata)
2630+
2631+
x_shared.set_value(x_pred)
2632+
prior_trace1 = pm.sample_prior_predictive(1000)
2633+
pp_trace1 = pm.sample_posterior_predictive(idata)
2634+
2635+
assert prior_trace0.prior["b"].shape == (1, 1000)
2636+
assert prior_trace0.prior_predictive["obs"].shape == (1, 1000, 100)
2637+
np.testing.assert_allclose(
2638+
x, pp_trace0.posterior_predictive["obs"].mean(("chain", "draw")), atol=1e-1
2639+
)
2640+
2641+
assert prior_trace1.prior["b"].shape == (1, 1000)
2642+
assert prior_trace1.prior_predictive["obs"].shape == (1, 1000, 200)
2643+
np.testing.assert_allclose(
2644+
x_pred, pp_trace1.posterior_predictive["obs"].mean(("chain", "draw")), atol=1e-1
2645+
)

pymc/tests/test_shared.py

Lines changed: 0 additions & 64 deletions
This file was deleted.

0 commit comments

Comments
 (0)