Skip to content

Commit 84152c7

Browse files
ArmavicaricardoV94
authored andcommitted
Move test_types into test_sampling
1 parent 5c7d972 commit 84152c7

File tree

3 files changed

+45
-67
lines changed

3 files changed

+45
-67
lines changed

.github/workflows/tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ jobs:
6464
pymc/tests/tuning/test_scaling.py
6565
pymc/tests/tuning/test_starting.py
6666
pymc/tests/test_shared.py
67-
pymc/tests/test_types.py
67+
pymc/tests/test_sampling.py
6868
pymc/tests/distributions/test_dist_math.py
6969
pymc/tests/distributions/test_transform.py
7070
pymc/tests/test_parallel_sampling.py

pymc/tests/test_sampling.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import warnings
1818

1919
from contextlib import ExitStack as does_not_raise
20+
from copy import copy
2021
from typing import Tuple
2122

2223
import aesara
@@ -52,6 +53,7 @@
5253
NUTS,
5354
BinaryGibbsMetropolis,
5455
CategoricalGibbsMetropolis,
56+
HamiltonianMC,
5557
Metropolis,
5658
Slice,
5759
)
@@ -2565,3 +2567,45 @@ def test_modify_step_methods(self):
25652567
with aesara.config.change_flags(mode=fast_unstable_sampling_mode):
25662568
steps = assign_step_methods(model, [])
25672569
assert isinstance(steps, NUTS)
2570+
2571+
2572+
class TestType:
2573+
samplers = (Metropolis, Slice, HamiltonianMC, NUTS)
2574+
2575+
def setup_method(self):
2576+
# save Aesara config object
2577+
self.aesara_config = copy(aesara.config)
2578+
2579+
def teardown_method(self):
2580+
# restore aesara config
2581+
aesara.config = self.aesara_config
2582+
2583+
@aesara.config.change_flags({"floatX": "float64", "warn_float64": "ignore"})
2584+
def test_float64(self):
2585+
with pm.Model() as model:
2586+
x = pm.Normal("x", initval=np.array(1.0, dtype="float64"))
2587+
obs = pm.Normal("obs", mu=x, sigma=1.0, observed=np.random.randn(5))
2588+
2589+
assert x.dtype == "float64"
2590+
assert obs.dtype == "float64"
2591+
2592+
for sampler in self.samplers:
2593+
with model:
2594+
with warnings.catch_warnings():
2595+
warnings.filterwarnings("ignore", ".*number of samples.*", UserWarning)
2596+
pm.sample(draws=10, tune=10, chains=1, step=sampler())
2597+
2598+
@aesara.config.change_flags({"floatX": "float32", "warn_float64": "warn"})
2599+
def test_float32(self):
2600+
with pm.Model() as model:
2601+
x = pm.Normal("x", initval=np.array(1.0, dtype="float32"))
2602+
obs = pm.Normal("obs", mu=x, sigma=1.0, observed=np.random.randn(5).astype("float32"))
2603+
2604+
assert x.dtype == "float32"
2605+
assert obs.dtype == "float32"
2606+
2607+
for sampler in self.samplers:
2608+
with model:
2609+
with warnings.catch_warnings():
2610+
warnings.filterwarnings("ignore", ".*number of samples.*", UserWarning)
2611+
pm.sample(draws=10, tune=10, chains=1, step=sampler())

pymc/tests/test_types.py

Lines changed: 0 additions & 66 deletions
This file was deleted.

0 commit comments

Comments
 (0)