Skip to content

Commit f9e6418

Browse files
committed
Make test_inference tests fail on unexpected warnings
1 parent e3961fc commit f9e6418

File tree

2 files changed

+33
-4
lines changed

2 files changed

+33
-4
lines changed

tests/conftest.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import warnings
1415

1516
import numpy as np
1617
import pytensor
@@ -45,3 +46,10 @@ def strict_float32():
4546
def seeded_test():
4647
# TODO: use this instead of SeededTest
4748
np.random.seed(42)
49+
50+
51+
@pytest.fixture
52+
def fail_on_warning():
53+
with warnings.catch_warnings():
54+
warnings.simplefilter("error")
55+
yield

tests/variational/test_inference.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
import io
1616
import operator
1717

18+
from contextlib import nullcontext
19+
1820
import cloudpickle
1921
import numpy as np
2022
import pytensor
@@ -29,7 +31,7 @@
2931
from pymc.variational.opvi import NotImplementedInference
3032
from tests import models
3133

32-
pytestmark = pytest.mark.usefixtures("strict_float32", "seeded_test")
34+
pytestmark = pytest.mark.usefixtures("strict_float32", "seeded_test", "fail_on_warning")
3335

3436

3537
@pytest.mark.parametrize("score", [True, False])
@@ -157,7 +159,16 @@ def fit_kwargs(inference, use_minibatch):
157159

158160

159161
def test_fit_oo(inference, fit_kwargs, simple_model_data):
160-
trace = inference.fit(**fit_kwargs).sample(10000)
162+
# Minibatch data can't be extracted into the `observed_data` group in the final InferenceData
163+
if getattr(simple_model_data["data"], "name", "").startswith("minibatch"):
164+
warn_ctxt = pytest.warns(
165+
UserWarning, match="Could not extract data from symbolic observation"
166+
)
167+
else:
168+
warn_ctxt = nullcontext()
169+
170+
with warn_ctxt:
171+
trace = inference.fit(**fit_kwargs).sample(10000)
161172
mu_post = simple_model_data["mu_post"]
162173
d = simple_model_data["d"]
163174
np.testing.assert_allclose(np.mean(trace.posterior["mu"]), mu_post, rtol=0.05)
@@ -180,11 +191,21 @@ def test_fit_start(inference_spec, simple_model):
180191
kw = {"start": {"mu": mu_init}}
181192
if has_start_sigma:
182193
kw.update({"start_sigma": {"mu": mu_sigma_init}})
183-
184194
with simple_model:
185195
inference = inference_spec(**kw)
196+
197+
# Minibatch data can't be extracted into the `observed_data` group in the final InferenceData
198+
[observed_value] = [simple_model.rvs_to_values[obs] for obs in simple_model.observed_RVs]
199+
if observed_value.name.startswith("minibatch"):
200+
warn_ctxt = pytest.warns(
201+
UserWarning, match="Could not extract data from symbolic observation"
202+
)
203+
else:
204+
warn_ctxt = nullcontext()
205+
186206
try:
187-
trace = inference.fit(n=0).sample(10000)
207+
with warn_ctxt:
208+
trace = inference.fit(n=0).sample(10000)
188209
except NotImplementedInference as e:
189210
pytest.skip(str(e))
190211
np.testing.assert_allclose(np.mean(trace.posterior["mu"]), mu_init, rtol=0.05)

0 commit comments

Comments
 (0)