Skip to content

Commit d2b1ed2

Browse files
maresbmichaelosthege
authored andcommitted
Fix return_inferencedata warnings in tests
1 parent 135e510 commit d2b1ed2

19 files changed

+205
-83
lines changed

benchmarks/benchmarks/benchmarks.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ def time_overhead_sample(self, step):
9898
random_seed=1,
9999
progressbar=False,
100100
compute_convergence_checks=False,
101+
return_inferencedata=False,
101102
)
102103

103104

@@ -150,13 +151,23 @@ def time_drug_evaluation(self):
150151
"effect size", diff_of_means / np.sqrt((group1_std ** 2 + group2_std ** 2) / 2)
151152
)
152153
pm.sample(
153-
draws=20000, cores=4, chains=4, progressbar=False, compute_convergence_checks=False
154+
draws=20000,
155+
cores=4,
156+
chains=4,
157+
progressbar=False,
158+
compute_convergence_checks=False,
159+
return_inferencedata=False,
154160
)
155161

156162
def time_glm_hierarchical(self):
157163
with glm_hierarchical_model():
158164
pm.sample(
159-
draws=20000, cores=4, chains=4, progressbar=False, compute_convergence_checks=False
165+
draws=20000,
166+
cores=4,
167+
chains=4,
168+
progressbar=False,
169+
compute_convergence_checks=False,
170+
return_inferencedata=False,
160171
)
161172

162173

@@ -190,6 +201,7 @@ def track_glm_hierarchical_ess(self, init):
190201
random_seed=100,
191202
progressbar=False,
192203
compute_convergence_checks=False,
204+
return_inferencedata=False,
193205
)
194206
tot = time.time() - t0
195207
ess = float(az.ess(trace, var_names=["mu_a"])["mu_a"].values)
@@ -212,6 +224,7 @@ def track_marginal_mixture_model_ess(self, init):
212224
random_seed=100,
213225
progressbar=False,
214226
compute_convergence_checks=False,
227+
return_inferencedata=False,
215228
)
216229
tot = time.time() - t0
217230
ess = az.ess(trace, var_names=["mu"])["mu"].values.min() # worst case
@@ -243,6 +256,7 @@ def track_glm_hierarchical_ess(self, step):
243256
random_seed=100,
244257
progressbar=False,
245258
compute_convergence_checks=False,
259+
return_inferencedata=False,
246260
)
247261
tot = time.time() - t0
248262
ess = float(az.ess(trace, var_names=["mu_a"])["mu_a"].values)
@@ -302,7 +316,9 @@ def freefall(y, t, p):
302316
Y = pm.Normal("Y", mu=ode_solution, sd=sigma, observed=y)
303317

304318
t0 = time.time()
305-
trace = pm.sample(500, tune=1000, chains=2, cores=2, random_seed=0)
319+
trace = pm.sample(
320+
500, tune=1000, chains=2, cores=2, random_seed=0, return_inferencedata=False
321+
)
306322
tot = time.time() - t0
307323
ess = az.ess(trace)
308324
return np.mean([ess.sigma, ess.gamma]) / tot

pymc3/tests/sampler_fixtures.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,13 @@ def setup_class(cls):
140140
cls.model = cls.make_model()
141141
with cls.model:
142142
cls.step = cls.make_step()
143-
cls.trace = pm.sample(cls.n_samples, tune=cls.tune, step=cls.step, cores=cls.chains)
143+
cls.trace = pm.sample(
144+
cls.n_samples,
145+
tune=cls.tune,
146+
step=cls.step,
147+
cores=cls.chains,
148+
return_inferencedata=False,
149+
)
144150
cls.samples = {}
145151
for var in cls.model.unobserved_RVs:
146152
cls.samples[get_var_name(var)] = cls.trace.get_values(var, burn=cls.burn)

pymc3/tests/test_bart.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,14 +98,14 @@ def test_model():
9898
sigma = pm.HalfNormal("sigma", 1)
9999
mu = pm.BART("mu", X, Y, m=50)
100100
y = pm.Normal("y", mu, sigma, observed=Y)
101-
trace = pm.sample(1000, random_seed=212480)
101+
trace = pm.sample(1000, random_seed=212480, return_inferencedata=False)
102102

103103
np.testing.assert_allclose(trace[mu].mean(0), Y, 0.5)
104104

105105
Y = np.repeat([0, 1], 50)
106106
with pm.Model() as model:
107107
mu = pm.BART("mu", X, Y, m=50, inv_link="logistic")
108108
y = pm.Bernoulli("y", mu, observed=Y)
109-
trace = pm.sample(1000, random_seed=212480)
109+
trace = pm.sample(1000, random_seed=212480, return_inferencedata=False)
110110

111111
np.testing.assert_allclose(trace[mu].mean(0), Y, atol=0.5)

pymc3/tests/test_data_container.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def test_sample(self):
4444
pm.Normal("obs", b * x_shared, np.sqrt(1e-2), observed=y)
4545

4646
prior_trace0 = pm.sample_prior_predictive(1000)
47-
trace = pm.sample(1000, init=None, tune=1000, chains=1)
47+
trace = pm.sample(1000, init=None, tune=1000, chains=1, return_inferencedata=False)
4848
pp_trace0 = pm.sample_posterior_predictive(trace, 1000)
4949
pp_trace01 = pm.fast_sample_posterior_predictive(trace, 1000)
5050

@@ -75,7 +75,7 @@ def test_sample_posterior_predictive_after_set_data(self):
7575
y = pm.Data("y", [1.0, 2.0, 3.0])
7676
beta = pm.Normal("beta", 0, 10.0)
7777
pm.Normal("obs", beta * x, np.sqrt(1e-2), observed=y)
78-
trace = pm.sample(1000, tune=1000, chains=1)
78+
trace = pm.sample(1000, tune=1000, chains=1, return_inferencedata=False)
7979
# Predict on new data.
8080
with model:
8181
x_test = [5.0, 6.0, 9.0]
@@ -94,13 +94,13 @@ def test_sample_after_set_data(self):
9494
y = pm.Data("y", [1.0, 2.0, 3.0])
9595
beta = pm.Normal("beta", 0, 10.0)
9696
pm.Normal("obs", beta * x, np.sqrt(1e-2), observed=y)
97-
pm.sample(1000, init=None, tune=1000, chains=1)
97+
pm.sample(1000, init=None, tune=1000, chains=1, return_inferencedata=False)
9898
# Predict on new data.
9999
new_x = [5.0, 6.0, 9.0]
100100
new_y = [5.0, 6.0, 9.0]
101101
with model:
102102
pm.set_data(new_data={"x": new_x, "y": new_y})
103-
new_trace = pm.sample(1000, init=None, tune=1000, chains=1)
103+
new_trace = pm.sample(1000, init=None, tune=1000, chains=1, return_inferencedata=False)
104104
pp_trace = pm.sample_posterior_predictive(new_trace, 1000)
105105
pp_tracef = pm.fast_sample_posterior_predictive(new_trace, 1000)
106106

@@ -121,7 +121,7 @@ def test_shared_data_as_index(self):
121121
pm.Normal("obs", alpha[index], np.sqrt(1e-2), observed=y)
122122

123123
prior_trace = pm.sample_prior_predictive(1000, var_names=["alpha"])
124-
trace = pm.sample(1000, init=None, tune=1000, chains=1)
124+
trace = pm.sample(1000, init=None, tune=1000, chains=1, return_inferencedata=False)
125125

126126
# Predict on new data
127127
new_index = np.array([0, 1, 2])
@@ -146,14 +146,14 @@ def test_shared_data_as_rv_input(self):
146146
with pm.Model() as m:
147147
x = pm.Data("x", [1.0, 2.0, 3.0])
148148
_ = pm.Normal("y", mu=x, shape=3)
149-
trace = pm.sample(chains=1)
149+
trace = pm.sample(chains=1, return_inferencedata=False)
150150

151151
np.testing.assert_allclose(np.array([1.0, 2.0, 3.0]), x.get_value(), atol=1e-1)
152152
np.testing.assert_allclose(np.array([1.0, 2.0, 3.0]), trace["y"].mean(0), atol=1e-1)
153153

154154
with m:
155155
pm.set_data({"x": np.array([2.0, 4.0, 6.0])})
156-
trace = pm.sample(chains=1)
156+
trace = pm.sample(chains=1, return_inferencedata=False)
157157

158158
np.testing.assert_allclose(np.array([2.0, 4.0, 6.0]), x.get_value(), atol=1e-1)
159159
np.testing.assert_allclose(np.array([2.0, 4.0, 6.0]), trace["y"].mean(0), atol=1e-1)
@@ -189,7 +189,7 @@ def test_set_data_to_non_data_container_variables(self):
189189
y = np.array([1.0, 2.0, 3.0])
190190
beta = pm.Normal("beta", 0, 10.0)
191191
pm.Normal("obs", beta * x, np.sqrt(1e-2), observed=y)
192-
pm.sample(1000, init=None, tune=1000, chains=1)
192+
pm.sample(1000, init=None, tune=1000, chains=1, return_inferencedata=False)
193193
with pytest.raises(TypeError) as error:
194194
pm.set_data({"beta": [1.1, 2.2, 3.3]}, model=model)
195195
error.match("defined as `pymc3.Data` inside the model")
@@ -201,7 +201,7 @@ def test_model_to_graphviz_for_model_with_data_container(self):
201201
beta = pm.Normal("beta", 0, 10.0)
202202
obs_sigma = floatX(np.sqrt(1e-2))
203203
pm.Normal("obs", beta * x, obs_sigma, observed=y)
204-
pm.sample(1000, init=None, tune=1000, chains=1)
204+
pm.sample(1000, init=None, tune=1000, chains=1, return_inferencedata=False)
205205

206206
for formatting in {"latex", "latex_with_params"}:
207207
with pytest.raises(ValueError, match="Unsupported formatting"):

pymc3/tests/test_distributions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2596,7 +2596,7 @@ def func(x):
25962596
with pm.Model():
25972597
pm.Normal("x")
25982598
y = pm.DensityDist("y", func)
2599-
pm.sample(draws=5, tune=1, mp_ctx="spawn")
2599+
pm.sample(draws=5, tune=1, mp_ctx="spawn", return_inferencedata=False)
26002600

26012601
import pickle
26022602

pymc3/tests/test_distributions_random.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1296,7 +1296,7 @@ def test_density_dist_with_random_sampleable(self, shape):
12961296
shape=shape,
12971297
random=normal_dist.random,
12981298
)
1299-
trace = pm.sample(100, cores=1)
1299+
trace = pm.sample(100, cores=1, return_inferencedata=False)
13001300

13011301
samples = 500
13021302
size = 100
@@ -1319,7 +1319,7 @@ def test_density_dist_with_random_sampleable_failure(self, shape):
13191319
random=normal_dist.random,
13201320
wrap_random_with_dist_shape=False,
13211321
)
1322-
trace = pm.sample(100, cores=1)
1322+
trace = pm.sample(100, cores=1, return_inferencedata=False)
13231323

13241324
samples = 500
13251325
with pytest.raises(RuntimeError):
@@ -1342,7 +1342,7 @@ def test_density_dist_with_random_sampleable_hidden_error(self, shape):
13421342
wrap_random_with_dist_shape=False,
13431343
check_shape_in_random=False,
13441344
)
1345-
trace = pm.sample(100, cores=1)
1345+
trace = pm.sample(100, cores=1, return_inferencedata=False)
13461346

13471347
samples = 500
13481348
ppc = pm.sample_posterior_predictive(trace, samples=samples, model=model)
@@ -1365,7 +1365,7 @@ def test_density_dist_with_random_sampleable_handcrafted_success(self):
13651365
random=rvs,
13661366
wrap_random_with_dist_shape=False,
13671367
)
1368-
trace = pm.sample(100, cores=1)
1368+
trace = pm.sample(100, cores=1, return_inferencedata=False)
13691369

13701370
samples = 500
13711371
size = 100
@@ -1385,7 +1385,7 @@ def test_density_dist_with_random_sampleable_handcrafted_success_fast(self):
13851385
random=rvs,
13861386
wrap_random_with_dist_shape=False,
13871387
)
1388-
trace = pm.sample(100, cores=1)
1388+
trace = pm.sample(100, cores=1, return_inferencedata=False)
13891389

13901390
samples = 500
13911391
size = 100
@@ -1398,7 +1398,7 @@ def test_density_dist_without_random_not_sampleable(self):
13981398
mu = pm.Normal("mu", 0, 1)
13991399
normal_dist = pm.Normal.dist(mu, 1)
14001400
pm.DensityDist("density_dist", normal_dist.logp, observed=np.random.randn(100))
1401-
trace = pm.sample(100, cores=1)
1401+
trace = pm.sample(100, cores=1, return_inferencedata=False)
14021402

14031403
samples = 500
14041404
with pytest.raises(ValueError):

pymc3/tests/test_examples.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def build_model(self):
7575
def test_run(self):
7676
model = self.build_model()
7777
with model:
78-
pm.sample(50, tune=50)
78+
pm.sample(50, tune=50, return_inferencedata=False)
7979

8080

8181
class TestARM12_6(SeededTest):
@@ -112,7 +112,7 @@ def too_slow(self):
112112
vars=[model["groupmean"], model["sd_interval__"], model["floor_m"]],
113113
)
114114
step = pm.NUTS(model.vars, scaling=start)
115-
pm.sample(50, step=step, start=start)
115+
pm.sample(50, step=step, start=start, return_inferencedata=False)
116116

117117

118118
class TestARM12_6Uranium(SeededTest):
@@ -159,7 +159,7 @@ def too_slow(self):
159159
h = np.diag(H(start))
160160

161161
step = pm.HamiltonianMC(model.vars, h)
162-
pm.sample(50, step=step, start=start)
162+
pm.sample(50, step=step, start=start, return_inferencedata=False)
163163

164164

165165
def build_disaster_model(masked=False):
@@ -202,7 +202,9 @@ def test_disaster_model(self):
202202
start = {"early_mean": 2.0, "late_mean": 3.0}
203203
# Use slice sampler for means (other variables auto-selected)
204204
step = pm.Slice([model.early_mean_log__, model.late_mean_log__])
205-
tr = pm.sample(500, tune=50, start=start, step=step, chains=2)
205+
tr = pm.sample(
206+
500, tune=50, start=start, step=step, chains=2, return_inferencedata=False
207+
)
206208
az.summary(tr)
207209

208210
def test_disaster_model_missing(self):
@@ -212,7 +214,9 @@ def test_disaster_model_missing(self):
212214
start = {"early_mean": 2.0, "late_mean": 3.0}
213215
# Use slice sampler for means (other variables auto-selected)
214216
step = pm.Slice([model.early_mean_log__, model.late_mean_log__])
215-
tr = pm.sample(500, tune=50, start=start, step=step, chains=2)
217+
tr = pm.sample(
218+
500, tune=50, start=start, step=step, chains=2, return_inferencedata=False
219+
)
216220
az.summary(tr)
217221

218222

@@ -231,7 +235,7 @@ def build_model(self):
231235
def test_run(self):
232236
with self.build_model():
233237
start = pm.find_MAP(method="Powell")
234-
pm.sample(50, pm.Slice(), start=start)
238+
pm.sample(50, pm.Slice(), start=start, return_inferencedata=False)
235239

236240

237241
class TestLatentOccupancy(SeededTest):
@@ -290,7 +294,9 @@ def test_run(self):
290294
}
291295
step_one = pm.Metropolis([model.theta_interval__, model.psi_logodds__])
292296
step_two = pm.BinaryMetropolis([model.z])
293-
pm.sample(50, step=[step_one, step_two], start=start, chains=1)
297+
pm.sample(
298+
50, step=[step_one, step_two], start=start, chains=1, return_inferencedata=False
299+
)
294300

295301

296302
@pytest.mark.xfail(
@@ -332,7 +338,7 @@ def build_model(self):
332338

333339
def test_run(self):
334340
with self.build_model():
335-
pm.sample(50, step=[pm.NUTS(), pm.Metropolis()])
341+
pm.sample(50, step=[pm.NUTS(), pm.Metropolis()], return_inferencedata=False)
336342

337343

338344
class TestMultilevelNormal(SeededTest):
@@ -375,9 +381,9 @@ def test_run(self):
375381

376382
with model:
377383
step = pm.MLDA(subsampling_rates=2, coarse_models=coarse_models)
378-
pm.sample(draws=50, chains=2, tune=50, step=step)
384+
pm.sample(draws=50, chains=2, tune=50, step=step, return_inferencedata=False)
379385

380386
step = pm.MLDA(
381387
subsampling_rates=2, coarse_models=coarse_models, base_sampler="Metropolis"
382388
)
383-
pm.sample(draws=50, chains=2, tune=50, step=step)
389+
pm.sample(draws=50, chains=2, tune=50, step=step, return_inferencedata=False)

pymc3/tests/test_hmc.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,9 @@ def test_nuts_tuning():
5353
with model:
5454
pymc3.Normal("mu", mu=0, sigma=1)
5555
step = pymc3.NUTS()
56-
trace = pymc3.sample(10, step=step, tune=5, progressbar=False, chains=1)
56+
trace = pymc3.sample(
57+
10, step=step, tune=5, progressbar=False, chains=1, return_inferencedata=False
58+
)
5759

5860
assert not step.tune
5961
assert np.all(trace["step_size"][5:] == trace["step_size"][5])

pymc3/tests/test_mixture.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -375,7 +375,7 @@ def build_toy_dataset(N, K):
375375

376376
pm.Mixture("x_obs", pi, comp_dist, observed=X)
377377
with model:
378-
trace = pm.sample(30, tune=10, chains=1)
378+
trace = pm.sample(30, tune=10, chains=1, return_inferencedata=False)
379379

380380
n_samples = 20
381381
with model:

pymc3/tests/test_model.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,13 @@ def test_matrix_multiplication():
194194
rv_det = pm.Deterministic("rv_det", matrix @ rv_rv)
195195
det_rv = pm.Deterministic("det_rv", rv_rv @ transformed)
196196

197-
posterior = pm.sample(10, tune=0, compute_convergence_checks=False, progressbar=False)
197+
posterior = pm.sample(
198+
10,
199+
tune=0,
200+
compute_convergence_checks=False,
201+
progressbar=False,
202+
return_inferencedata=False,
203+
)
198204
decimal = select_by_precision(7, 5)
199205
for point in posterior.points():
200206
npt.assert_almost_equal(

pymc3/tests/test_ndarray_backend.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ def model():
219219
@classmethod
220220
def setup_class(cls):
221221
with TestSaveLoad.model():
222-
cls.trace = pm.sample()
222+
cls.trace = pm.sample(return_inferencedata=False)
223223

224224
def test_save_new_model(self, tmpdir_factory):
225225
directory = str(tmpdir_factory.mktemp("data"))
@@ -228,7 +228,7 @@ def test_save_new_model(self, tmpdir_factory):
228228
assert save_dir == directory
229229
with pm.Model() as model:
230230
w = pm.Normal("w", 0, 1)
231-
new_trace = pm.sample()
231+
new_trace = pm.sample(return_inferencedata=False)
232232

233233
with pytest.raises(OSError):
234234
_ = pm.save_trace(new_trace, directory)

0 commit comments

Comments
 (0)