|
17 | 17 | import warnings
|
18 | 18 |
|
19 | 19 | from contextlib import ExitStack as does_not_raise
|
| 20 | +from copy import copy |
20 | 21 | from typing import Tuple
|
21 | 22 |
|
22 | 23 | import aesara
|
|
52 | 53 | NUTS,
|
53 | 54 | BinaryGibbsMetropolis,
|
54 | 55 | CategoricalGibbsMetropolis,
|
| 56 | + HamiltonianMC, |
55 | 57 | Metropolis,
|
56 | 58 | Slice,
|
57 | 59 | )
|
@@ -2565,3 +2567,45 @@ def test_modify_step_methods(self):
|
2565 | 2567 | with aesara.config.change_flags(mode=fast_unstable_sampling_mode):
|
2566 | 2568 | steps = assign_step_methods(model, [])
|
2567 | 2569 | assert isinstance(steps, NUTS)
|
| 2570 | + |
| 2571 | + |
| 2572 | +class TestType: |
| 2573 | + samplers = (Metropolis, Slice, HamiltonianMC, NUTS) |
| 2574 | + |
| 2575 | + def setup_method(self): |
| 2576 | + # save Aesara config object |
| 2577 | + self.aesara_config = copy(aesara.config) |
| 2578 | + |
| 2579 | + def teardown_method(self): |
| 2580 | + # restore aesara config |
| 2581 | + aesara.config = self.aesara_config |
| 2582 | + |
| 2583 | + @aesara.config.change_flags({"floatX": "float64", "warn_float64": "ignore"}) |
| 2584 | + def test_float64(self): |
| 2585 | + with pm.Model() as model: |
| 2586 | + x = pm.Normal("x", initval=np.array(1.0, dtype="float64")) |
| 2587 | + obs = pm.Normal("obs", mu=x, sigma=1.0, observed=np.random.randn(5)) |
| 2588 | + |
| 2589 | + assert x.dtype == "float64" |
| 2590 | + assert obs.dtype == "float64" |
| 2591 | + |
| 2592 | + for sampler in self.samplers: |
| 2593 | + with model: |
| 2594 | + with warnings.catch_warnings(): |
| 2595 | + warnings.filterwarnings("ignore", ".*number of samples.*", UserWarning) |
| 2596 | + pm.sample(draws=10, tune=10, chains=1, step=sampler()) |
| 2597 | + |
| 2598 | + @aesara.config.change_flags({"floatX": "float32", "warn_float64": "warn"}) |
| 2599 | + def test_float32(self): |
| 2600 | + with pm.Model() as model: |
| 2601 | + x = pm.Normal("x", initval=np.array(1.0, dtype="float32")) |
| 2602 | + obs = pm.Normal("obs", mu=x, sigma=1.0, observed=np.random.randn(5).astype("float32")) |
| 2603 | + |
| 2604 | + assert x.dtype == "float32" |
| 2605 | + assert obs.dtype == "float32" |
| 2606 | + |
| 2607 | + for sampler in self.samplers: |
| 2608 | + with model: |
| 2609 | + with warnings.catch_warnings(): |
| 2610 | + warnings.filterwarnings("ignore", ".*number of samples.*", UserWarning) |
| 2611 | + pm.sample(draws=10, tune=10, chains=1, step=sampler()) |
0 commit comments