Commit 9b31950

Add more tests for stats, competence, etc. and fix bug with tune flag not being overridden in compound sub-methods
1 parent ea34371 commit 9b31950

2 files changed: +124 lines, -10 lines

pymc3/step_methods/metropolis.py

Lines changed: 12 additions & 4 deletions
@@ -921,7 +921,10 @@ class MLDA(ArrayStepShared):
     base_scaling : scalar or array
         Initial scale factor for base proposal. Defaults to 1.
     tune : bool
-        Flag for tuning for the base proposal. Defaults to True.
+        Flag for tuning for the base proposal. Defaults to True. Note that
+        this is overridden by the tune parameter in sample(), i.e. when calling
+        step=MLDA(tune=False, ...) and then sample(step=step, tune=200, ...),
+        tuning will be activated for the first 200 steps.
     base_tune_interval : int
         The frequency of tuning for the base proposal. Defaults to 100
         iterations.
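
Since this override is easy to miss, here is a minimal runnable sketch of the documented behaviour. The toy models are placeholders (the tests use an internal simple_2model_continuous() helper), and MLDA is imported from the module path shown in this diff:

import pymc3 as pm
from pymc3.step_methods.metropolis import MLDA  # module path from this diff

# Placeholder fine/coarse models; any pair with matching variable names works.
with pm.Model() as coarse_model:
    pm.Normal('x', 0., 2.2)

with pm.Model():
    pm.Normal('x', 0., 2.)
    step = MLDA(coarse_models=[coarse_model], tune=False)
    # sample()'s tune parameter wins over MLDA's tune=False: tuning is
    # active for the first 200 iterations, then switched off.
    trace = pm.sample(step=step, tune=200, draws=100, chains=1,
                      discard_tuned_samples=False)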
@@ -1108,12 +1111,17 @@ def __init__(self, coarse_models, vars=None, base_S=None, base_proposal_dist=Non
 
     def astep(self, q0):
         """One MLDA step, given current sample q0"""
-        # Check if tuning has been deactivated and if yes,
+        # Check if the tuning flag has been changed and if yes,
         # change the proposal's tuning flag and reset self.accepted
-        # This is initially triggered in the highest-level MLDA step
-        # method (within iter_sample) and then propagates to all levels.
+        # This is triggered by iter_sample while the highest-level MLDA step
+        # method is running. It then propagates to all levels.
         if self.proposal_dist.tune != self.tune:
             self.proposal_dist.tune = self.tune
+            # set tune in sub-methods of compound stepper explicitly because
+            # it is not set within sample() (only the CompoundStep's tune flag is)
+            if isinstance(self.next_step_method, CompoundStep):
+                for method in self.next_step_method.methods:
+                    method.tune = self.tune
             self.accepted = 0
 
         # Convert current sample from numpy array ->
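
The actual fix is the isinstance branch: sample() only toggles the tune flag of the compound stepper itself, never of its sub-methods, so MLDA pushes the flag down by hand. A self-contained sketch of that propagation pattern, using illustrative stand-in classes rather than the real pymc3 ones:

# Illustrative stand-ins for the propagation pattern in astep() above.
class SubMethod:
    def __init__(self):
        self.tune = True

class Compound:
    """Stand-in for CompoundStep: holds a list of sub-methods."""
    def __init__(self, methods):
        self.methods = methods
        self.tune = True

def propagate_tune(parent_tune, next_step_method):
    # Mirror of the fix: the compound stepper's own flag is set elsewhere,
    # so each sub-method must be updated explicitly.
    if isinstance(next_step_method, Compound):
        for method in next_step_method.methods:
            method.tune = parent_tune

compound = Compound([SubMethod(), SubMethod()])
propagate_tune(False, compound)
assert all(not m.tune for m in compound.methods)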

pymc3/tests/test_step.py

Lines changed: 112 additions & 6 deletions
@@ -1095,8 +1095,8 @@ def test_nonparallelized_chains_are_random(self):
         samples = np.array(trace.get_values("x", combine=False))[:, 5]
         assert len(set(samples)) == 2, \
             "Non parallelized {} " "chains are identical.".format(
-            stepper
-        )
+                stepper
+            )
 
     def test_parallelized_chains_are_random(self):
         """Test that parallel chain are
@@ -1115,8 +1115,8 @@ def test_parallelized_chains_are_random(self):
         samples = np.array(trace.get_values("x", combine=False))[:, 5]
         assert len(set(samples)) == 2, \
             "Parallelized {} " "chains are identical.".format(
-            stepper
-        )
+                stepper
+            )
 
     def test_acceptance_rate_against_coarseness(self):
         """Test that the acceptance rate increases
@@ -1160,8 +1160,9 @@ def test_mlda_non_blocked(self):
                                    base_blocked=False).next_step_method,
                           CompoundStep)
 
-    def test_blocked(self):
-        """Test the type of base sampler instantiated when switching base_blocked flag"""
+    def test_mlda_blocked(self):
+        """Test the type of base sampler instantiated
+        when switching base_blocked flag"""
         _, model = simple_2model_continuous()
         _, model_coarse = simple_2model_continuous()
         with model:
@@ -1173,3 +1174,108 @@ def test_blocked(self):
                                    base_blocked=True).next_step_method,
                           Metropolis)
 
+    def test_tuning_and_scaling_on(self):
+        """Test that tune and base_scaling change as expected when
+        tuning is on."""
+        np.random.seed(1234)
+        ts = 100
+        _, model = simple_2model_continuous()
+        _, model_coarse = simple_2model_continuous()
+        with model:
+            trace = sample(
+                tune=ts,
+                draws=20,
+                step=MLDA(coarse_models=[model_coarse],
+                          base_tune_interval=50,
+                          base_scaling=100.),
+                chains=1,
+                discard_tuned_samples=False,
+                random_seed=1234
+            )
+
+        assert trace.get_sampler_stats('tune', chains=0)[0]
+        assert trace.get_sampler_stats('tune', chains=0)[ts - 1]
+        assert not trace.get_sampler_stats('tune', chains=0)[ts]
+        assert not trace.get_sampler_stats('tune', chains=0)[-1]
+        assert trace.get_sampler_stats('base_scaling_x', chains=0)[0] == 100.
+        assert trace.get_sampler_stats('base_scaling_y_logodds__', chains=0)[0] == 100.
+        assert trace.get_sampler_stats('base_scaling_x', chains=0)[-1] < 100.
+        assert trace.get_sampler_stats('base_scaling_y_logodds__', chains=0)[-1] < 100.
+
+    def test_tuning_and_scaling_off(self):
+        """Test that tuning is deactivated when sample()'s tune=0 and that
+        MLDA's tune=False is overridden by sample()'s tune."""
+        np.random.seed(12345)
+        _, model = simple_2model_continuous()
+        _, model_coarse = simple_2model_continuous()
+
+        ts_0 = 0
+        with model:
+            trace_0 = sample(
+                tune=ts_0,
+                draws=100,
+                step=MLDA(coarse_models=[model_coarse],
+                          base_tune_interval=50,
+                          base_scaling=100.,
+                          tune=False),
+                chains=1,
+                discard_tuned_samples=False,
+                random_seed=12345
+            )
+
+        ts_1 = 100
+        with model:
+            trace_1 = sample(
+                tune=ts_1,
+                draws=20,
+                step=MLDA(coarse_models=[model_coarse],
+                          base_tune_interval=50,
+                          base_scaling=100.,
+                          tune=False),
+                chains=1,
+                discard_tuned_samples=False,
+                random_seed=12345
+            )
+
+        assert not trace_0.get_sampler_stats('tune', chains=0)[0]
+        assert not trace_0.get_sampler_stats('tune', chains=0)[-1]
+        assert trace_0.get_sampler_stats('base_scaling_x', chains=0)[0] == \
+            trace_0.get_sampler_stats('base_scaling_x', chains=0)[-1] == 100.
+
+        assert trace_1.get_sampler_stats('tune', chains=0)[0]
+        assert trace_1.get_sampler_stats('tune', chains=0)[ts_1 - 1]
+        assert not trace_1.get_sampler_stats('tune', chains=0)[ts_1]
+        assert not trace_1.get_sampler_stats('tune', chains=0)[-1]
+        assert trace_1.get_sampler_stats('base_scaling_x', chains=0)[0] == 100.
+        assert trace_1.get_sampler_stats('base_scaling_y_logodds__', chains=0)[0] == 100.
+        assert trace_1.get_sampler_stats('base_scaling_x', chains=0)[-1] < 100.
+        assert trace_1.get_sampler_stats('base_scaling_y_logodds__', chains=0)[-1] < 100.
+
+    def test_trace_length(self):
+        """Check if trace length is as expected."""
+        tune = 100
+        draws = 50
+        with Model() as coarse_model:
+            Normal('n', 0, 2.2, shape=(3,))
+        with Model():
+            Normal('n', 0, 2, shape=(3,))
+            step = MLDA(coarse_models=[coarse_model])
+            trace = sample(
+                tune=tune,
+                draws=draws,
+                step=step,
+                chains=1,
+                discard_tuned_samples=False
+            )
+            assert len(trace) == tune + draws
+
+    @pytest.mark.parametrize('variable,has_grad,outcome',
+                             [('n', True, 1), ('n', False, 1), ('b', True, 0), ('b', False, 0)])
+    def test_competence(self, variable, has_grad, outcome):
+        """Test if competence function returns expected
+        results for different models"""
+        with Model() as pmodel:
+            Normal('n', 0, 2, shape=(3,))
+            Binomial('b', n=2, p=0.3)
+        assert MLDA.competence(pmodel[variable], has_grad=has_grad) == outcome
+        pass
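
test_competence pins down MLDA's competence values: 1 for the continuous variable 'n' and 0 for the discrete 'b', independent of has_grad. The same check as a standalone sketch, reusing the toy model from the test:

import pymc3 as pm
from pymc3.step_methods.metropolis import MLDA

# Same toy model as test_competence above.
with pm.Model() as pmodel:
    pm.Normal('n', 0, 2, shape=(3,))
    pm.Binomial('b', n=2, p=0.3)

# MLDA reports competence 1 for the continuous 'n' and 0 for the
# discrete 'b'; the has_grad flag does not change the outcome.
for var, has_grad in [('n', True), ('n', False), ('b', True), ('b', False)]:
    print(var, has_grad, MLDA.competence(pmodel[var], has_grad=has_grad))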
