15
15
import io
16
16
import operator
17
17
18
+ from contextlib import nullcontext
19
+
18
20
import cloudpickle
19
21
import numpy as np
20
22
import pytensor
29
31
from pymc .variational .opvi import NotImplementedInference
30
32
from tests import models
31
33
32
- pytestmark = pytest .mark .usefixtures ("strict_float32" , "seeded_test" )
34
+ pytestmark = pytest .mark .usefixtures ("strict_float32" , "seeded_test" , "fail_on_warning" )
33
35
34
36
35
37
@pytest .mark .parametrize ("score" , [True , False ])
@@ -157,7 +159,16 @@ def fit_kwargs(inference, use_minibatch):
157
159
158
160
159
161
def test_fit_oo (inference , fit_kwargs , simple_model_data ):
160
- trace = inference .fit (** fit_kwargs ).sample (10000 )
162
+ # Minibatch data can't be extracted into the `observed_data` group in the final InferenceData
163
+ if getattr (simple_model_data ["data" ], "name" , "" ).startswith ("minibatch" ):
164
+ warn_ctxt = pytest .warns (
165
+ UserWarning , match = "Could not extract data from symbolic observation"
166
+ )
167
+ else :
168
+ warn_ctxt = nullcontext ()
169
+
170
+ with warn_ctxt :
171
+ trace = inference .fit (** fit_kwargs ).sample (10000 )
161
172
mu_post = simple_model_data ["mu_post" ]
162
173
d = simple_model_data ["d" ]
163
174
np .testing .assert_allclose (np .mean (trace .posterior ["mu" ]), mu_post , rtol = 0.05 )
@@ -180,11 +191,21 @@ def test_fit_start(inference_spec, simple_model):
180
191
kw = {"start" : {"mu" : mu_init }}
181
192
if has_start_sigma :
182
193
kw .update ({"start_sigma" : {"mu" : mu_sigma_init }})
183
-
184
194
with simple_model :
185
195
inference = inference_spec (** kw )
196
+
197
+ # Minibatch data can't be extracted into the `observed_data` group in the final InferenceData
198
+ [observed_value ] = [simple_model .rvs_to_values [obs ] for obs in simple_model .observed_RVs ]
199
+ if observed_value .name .startswith ("minibatch" ):
200
+ warn_ctxt = pytest .warns (
201
+ UserWarning , match = "Could not extract data from symbolic observation"
202
+ )
203
+ else :
204
+ warn_ctxt = nullcontext ()
205
+
186
206
try :
187
- trace = inference .fit (n = 0 ).sample (10000 )
207
+ with warn_ctxt :
208
+ trace = inference .fit (n = 0 ).sample (10000 )
188
209
except NotImplementedInference as e :
189
210
pytest .skip (str (e ))
190
211
np .testing .assert_allclose (np .mean (trace .posterior ["mu" ]), mu_init , rtol = 0.05 )
0 commit comments