Skip to content

Commit da7c5df

Browse files
ricardoV94twiecki
authored andcommitted
Speedup model compilation in slow sampling tests
Add a specific FAST_COMPILE mode that skips canonicalization and specialization, while keeping rewrites that are required from aeppl and pymc for proper sampling. This mode is used in tests that take a long time to compile and for which numerical accuracy is not important (e.g., because we care only about the shape of the draws or deterministics of observed values)
1 parent 333f7f3 commit da7c5df

File tree

3 files changed

+71
-34
lines changed

3 files changed

+71
-34
lines changed

pymc/tests/helpers.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,10 @@
2121
import numpy.random as nr
2222

2323
from aesara.gradient import verify_grad as at_verify_grad
24+
from aesara.graph.opt import in2out
2425
from aesara.sandbox.rng_mrg import MRG_RandomStream as RandomStream
2526

26-
from pymc.aesaraf import at_rng, set_at_rng
27+
from pymc.aesaraf import at_rng, local_check_parameter_to_ninf_switch, set_at_rng
2728

2829

2930
class SeededTest:
@@ -132,3 +133,18 @@ def assert_random_state_equal(state1, state2):
132133
np.testing.assert_array_equal(field1, field2)
133134
else:
134135
assert field1 == field2
136+
137+
138+
# This mode can be used for tests where model compilations takes the bulk of the runtime
139+
# AND where we don't care about posterior numerical or sampling stability (e.g., when
140+
# all that matters are the shape of the draws or deterministic values of observed data).
141+
# DO NOT USE UNLESS YOU HAVE A GOOD REASON TO!
142+
fast_unstable_sampling_mode = (
143+
aesara.compile.mode.FAST_COMPILE
144+
# Remove slow rewrite phases
145+
.excluding("canonicalize", "specialize")
146+
# Include necessary rewrites for proper logp handling
147+
.including("remove_TransformedVariables").register(
148+
(in2out(local_check_parameter_to_ninf_switch), -1)
149+
)
150+
)

pymc/tests/test_ode.py

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525

2626
from pymc.ode import DifferentialEquation
2727
from pymc.ode.utils import augment_system
28+
from pymc.tests.helpers import fast_unstable_sampling_mode
2829

2930
IS_FLOAT32 = aesara.config.floatX == "float32"
3031
IS_WINDOWS = sys.platform == "win32"
@@ -291,11 +292,13 @@ def system(y, t, p):
291292
sigma = pm.HalfCauchy("sigma", 1)
292293
forward = ode_model(theta=[alpha], y0=[y0])
293294
y = pm.LogNormal("y", mu=pm.math.log(forward), sd=sigma, observed=yobs)
294-
idata = pm.sample(100, tune=0, chains=1)
295295

296-
assert idata.posterior["alpha"].shape == (1, 100)
297-
assert idata.posterior["y0"].shape == (1, 100)
298-
assert idata.posterior["sigma"].shape == (1, 100)
296+
with aesara.config.change_flags(mode=fast_unstable_sampling_mode):
297+
idata = pm.sample(50, tune=0, chains=1)
298+
299+
assert idata.posterior["alpha"].shape == (1, 50)
300+
assert idata.posterior["y0"].shape == (1, 50)
301+
assert idata.posterior["sigma"].shape == (1, 50)
299302

300303
def test_scalar_ode_2_param(self):
301304
"""Test running model for a scalar ODE with 2 parameters"""
@@ -321,12 +324,13 @@ def system(y, t, p):
321324
forward = ode_model(theta=[alpha, beta], y0=[y0])
322325
y = pm.LogNormal("y", mu=pm.math.log(forward), sd=sigma, observed=yobs)
323326

324-
idata = pm.sample(100, tune=0, chains=1)
327+
with aesara.config.change_flags(mode=fast_unstable_sampling_mode):
328+
idata = pm.sample(50, tune=0, chains=1)
325329

326-
assert idata.posterior["alpha"].shape == (1, 100)
327-
assert idata.posterior["beta"].shape == (1, 100)
328-
assert idata.posterior["y0"].shape == (1, 100)
329-
assert idata.posterior["sigma"].shape == (1, 100)
330+
assert idata.posterior["alpha"].shape == (1, 50)
331+
assert idata.posterior["beta"].shape == (1, 50)
332+
assert idata.posterior["y0"].shape == (1, 50)
333+
assert idata.posterior["sigma"].shape == (1, 50)
330334

331335
def test_vector_ode_1_param(self):
332336
"""Test running model for a vector ODE with 1 parameter"""
@@ -362,10 +366,11 @@ def system(y, t, p):
362366
forward = ode_model(theta=[R], y0=[0.99, 0.01])
363367
y = pm.LogNormal("y", mu=pm.math.log(forward), sd=sigma, observed=yobs)
364368

365-
idata = pm.sample(100, tune=0, chains=1)
369+
with aesara.config.change_flags(mode=fast_unstable_sampling_mode):
370+
idata = pm.sample(50, tune=0, chains=1)
366371

367-
assert idata.posterior["R"].shape == (1, 100)
368-
assert idata.posterior["sigma"].shape == (1, 100, 2)
372+
assert idata.posterior["R"].shape == (1, 50)
373+
assert idata.posterior["sigma"].shape == (1, 50, 2)
369374

370375
def test_vector_ode_2_param(self):
371376
"""Test running model for a vector ODE with 2 parameters"""
@@ -402,8 +407,9 @@ def system(y, t, p):
402407
forward = ode_model(theta=[beta, gamma], y0=[0.99, 0.01])
403408
y = pm.LogNormal("y", mu=pm.math.log(forward), sd=sigma, observed=yobs)
404409

405-
idata = pm.sample(100, tune=0, chains=1)
410+
with aesara.config.change_flags(mode=fast_unstable_sampling_mode):
411+
idata = pm.sample(50, tune=0, chains=1)
406412

407-
assert idata.posterior["beta"].shape == (1, 100)
408-
assert idata.posterior["gamma"].shape == (1, 100)
409-
assert idata.posterior["sigma"].shape == (1, 100, 2)
413+
assert idata.posterior["beta"].shape == (1, 50)
414+
assert idata.posterior["gamma"].shape == (1, 50)
415+
assert idata.posterior["sigma"].shape == (1, 50, 2)

pymc/tests/test_sampling.py

Lines changed: 32 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
from pymc.backends.base import MultiTrace
3636
from pymc.backends.ndarray import NDArray
3737
from pymc.exceptions import IncorrectArgumentsError, SamplingError
38-
from pymc.tests.helpers import SeededTest
38+
from pymc.tests.helpers import SeededTest, fast_unstable_sampling_mode
3939
from pymc.tests.models import simple_init
4040

4141

@@ -665,7 +665,8 @@ def test_model_not_drawable_prior(self):
665665
with model:
666666
mu = pm.HalfFlat("sigma")
667667
pm.Poisson("foo", mu=mu, observed=data)
668-
idata = pm.sample(tune=1000)
668+
with aesara.config.change_flags(mode=fast_unstable_sampling_mode):
669+
idata = pm.sample(tune=10, draws=40, chains=1)
669670

670671
with model:
671672
with pytest.raises(NotImplementedError) as excinfo:
@@ -718,12 +719,15 @@ def test_deterministic_of_observed(self):
718719
out_diff = in_1 + in_2
719720
pm.Deterministic("out", out_diff)
720721

721-
trace = pm.sample(
722-
100,
723-
chains=nchains,
724-
return_inferencedata=False,
725-
compute_convergence_checks=False,
726-
)
722+
with aesara.config.change_flags(mode=fast_unstable_sampling_mode):
723+
trace = pm.sample(
724+
tune=100,
725+
draws=100,
726+
chains=nchains,
727+
step=pm.Metropolis(),
728+
return_inferencedata=False,
729+
compute_convergence_checks=False,
730+
)
727731

728732
rtol = 1e-5 if aesara.config.floatX == "float64" else 1e-4
729733

@@ -754,11 +758,14 @@ def test_deterministic_of_observed_modified_interface(self):
754758
out_diff = in_1 + in_2
755759
pm.Deterministic("out", out_diff)
756760

757-
trace = pm.sample(
758-
100,
759-
return_inferencedata=False,
760-
compute_convergence_checks=False,
761-
)
761+
with aesara.config.change_flags(mode=fast_unstable_sampling_mode):
762+
trace = pm.sample(
763+
tune=100,
764+
draws=100,
765+
step=pm.Metropolis(),
766+
return_inferencedata=False,
767+
compute_convergence_checks=False,
768+
)
762769
varnames = [v for v in trace.varnames if v != "out"]
763770
ppc_trace = [
764771
dict(zip(varnames, row)) for row in zip(*(trace.get_values(v) for v in varnames))
@@ -779,7 +786,10 @@ def test_variable_type(self):
779786
mu = pm.HalfNormal("mu", 1)
780787
a = pm.Normal("a", mu=mu, sigma=2, observed=np.array([1, 2]))
781788
b = pm.Poisson("b", mu, observed=np.array([1, 2]))
782-
trace = pm.sample(compute_convergence_checks=False, return_inferencedata=False)
789+
with aesara.config.change_flags(mode=fast_unstable_sampling_mode):
790+
trace = pm.sample(
791+
tune=10, draws=10, compute_convergence_checks=False, return_inferencedata=False
792+
)
783793

784794
with model:
785795
ppc = pm.sample_posterior_predictive(trace, return_inferencedata=False, samples=1)
@@ -998,9 +1008,14 @@ def test_multivariate2(self):
9981008
with pm.Model() as dm_model:
9991009
probs = pm.Dirichlet("probs", a=np.ones(6))
10001010
obs = pm.Multinomial("obs", n=100, p=probs, observed=mn_data)
1001-
burned_trace = pm.sample(
1002-
20, tune=10, cores=1, return_inferencedata=False, compute_convergence_checks=False
1003-
)
1011+
with aesara.config.change_flags(mode=fast_unstable_sampling_mode):
1012+
burned_trace = pm.sample(
1013+
tune=10,
1014+
draws=20,
1015+
chains=1,
1016+
return_inferencedata=False,
1017+
compute_convergence_checks=False,
1018+
)
10041019
sim_priors = pm.sample_prior_predictive(
10051020
return_inferencedata=False, samples=20, model=dm_model
10061021
)

0 commit comments

Comments
 (0)