Skip to content

Commit f2a70fd

Browse files
authored
Merge branch 'pymc-devs:main' into flaky_eulermaruyama_tests
2 parents eadfd75 + 5d7283e commit f2a70fd

29 files changed

+829
-392
lines changed

pymc/aesaraf.py

Lines changed: 82 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import warnings
15+
1416
from typing import (
1517
Callable,
1618
Dict,
@@ -31,6 +33,7 @@
3133
import scipy.sparse as sps
3234

3335
from aeppl.logprob import CheckParameterValue
36+
from aeppl.transforms import RVTransform
3437
from aesara import scalar
3538
from aesara.compile.mode import Mode, get_mode
3639
from aesara.gradient import grad
@@ -205,10 +208,9 @@ def expand(var):
205208
yield from walk(graphs, expand, bfs=False)
206209

207210

208-
def replace_rvs_in_graphs(
211+
def _replace_rvs_in_graphs(
209212
graphs: Iterable[TensorVariable],
210213
replacement_fn: Callable[[TensorVariable], Dict[TensorVariable, TensorVariable]],
211-
initial_replacements: Optional[Dict[TensorVariable, TensorVariable]] = None,
212214
**kwargs,
213215
) -> Tuple[List[TensorVariable], Dict[TensorVariable, TensorVariable]]:
214216
"""Replace random variables in graphs
@@ -226,8 +228,6 @@ def replace_rvs_in_graphs(
226228
that were made.
227229
"""
228230
replacements = {}
229-
if initial_replacements:
230-
replacements.update(initial_replacements)
231231

232232
def expand_replace(var):
233233
new_nodes = []
@@ -239,6 +239,7 @@ def expand_replace(var):
239239
new_nodes.extend(replacement_fn(var, replacements))
240240
return new_nodes
241241

242+
# This iteration populates the replacements
242243
for var in walk_model(graphs, expand_fn=expand_replace, **kwargs):
243244
pass
244245

@@ -253,7 +254,15 @@ def expand_replace(var):
253254
clone=False,
254255
)
255256

256-
fg.replace_all(replacements.items(), import_missing=True)
257+
# replacements have to be done in reverse topological order so that nested
258+
# expressions get recursively replaced correctly
259+
toposort = fg.toposort()
260+
sorted_replacements = sorted(
261+
tuple(replacements.items()),
262+
key=lambda pair: toposort.index(pair[0].owner),
263+
reverse=True,
264+
)
265+
fg.replace_all(sorted_replacements, import_missing=True)
257266

258267
graphs = list(fg.outputs)
259268

@@ -263,7 +272,6 @@ def expand_replace(var):
263272
def rvs_to_value_vars(
264273
graphs: Iterable[Variable],
265274
apply_transforms: bool = True,
266-
initial_replacements: Optional[Dict[Variable, Variable]] = None,
267275
**kwargs,
268276
) -> List[Variable]:
269277
"""Clone and replace random variables in graphs with their value variables.
@@ -276,10 +284,11 @@ def rvs_to_value_vars(
276284
The graphs in which to perform the replacements.
277285
apply_transforms
278286
If ``True``, apply each value variable's transform.
279-
initial_replacements
280-
A ``dict`` containing the initial replacements to be made.
281-
282287
"""
288+
warnings.warn(
289+
"rvs_to_value_vars is deprecated. Use model.replace_rvs_by_values instead",
290+
FutureWarning,
291+
)
283292

284293
def populate_replacements(
285294
random_var: TensorVariable, replacements: Dict[TensorVariable, TensorVariable]
@@ -311,15 +320,72 @@ def populate_replacements(
311320
equiv = clone_get_equiv(inputs, graphs, False, False, {})
312321
graphs = [equiv[n] for n in graphs]
313322

314-
if initial_replacements:
315-
initial_replacements = {
316-
equiv.get(k, k): equiv.get(v, v) for k, v in initial_replacements.items()
317-
}
318-
319-
graphs, _ = replace_rvs_in_graphs(
323+
graphs, _ = _replace_rvs_in_graphs(
320324
graphs,
321325
replacement_fn=populate_replacements,
322-
initial_replacements=initial_replacements,
326+
**kwargs,
327+
)
328+
329+
return graphs
330+
331+
332+
def replace_rvs_by_values(
333+
graphs: Sequence[TensorVariable],
334+
*,
335+
rvs_to_values: Dict[TensorVariable, TensorVariable],
336+
rvs_to_transforms: Dict[TensorVariable, RVTransform],
337+
**kwargs,
338+
) -> List[TensorVariable]:
339+
"""Clone and replace random variables in graphs with their value variables.
340+
341+
This will *not* recompute test values in the resulting graphs.
342+
343+
Parameters
344+
----------
345+
graphs
346+
The graphs in which to perform the replacements.
347+
rvs_to_values
348+
Mapping between the original graph RVs and respective value variables
349+
rvs_to_transforms
350+
Mapping between the original graph RVs and respective value transforms
351+
"""
352+
353+
# Clone original graphs so that we don't modify variables in place
354+
inputs = [i for i in graph_inputs(graphs) if not isinstance(i, Constant)]
355+
equiv = clone_get_equiv(inputs, graphs, False, False, {})
356+
graphs = [equiv[n] for n in graphs]
357+
358+
# Get needed mappings for equivalent cloned variables
359+
equiv_rvs_to_values = {}
360+
equiv_rvs_to_transforms = {}
361+
for rv, value in rvs_to_values.items():
362+
equiv_rv = equiv.get(rv, rv)
363+
equiv_rvs_to_values[equiv_rv] = equiv.get(value, value)
364+
equiv_rvs_to_transforms[equiv_rv] = rvs_to_transforms[rv]
365+
366+
def poulate_replacements(rv, replacements):
367+
# Populate replacements dict with {rv: value} pairs indicating which graph
368+
# RVs should be replaced by what value variables.
369+
370+
# No value variable to replace RV with
371+
value = equiv_rvs_to_values.get(rv, None)
372+
if value is None:
373+
return []
374+
375+
transform = equiv_rvs_to_transforms.get(rv, None)
376+
if transform is not None:
377+
# We want to replace uses of the RV by the back-transformation of its value
378+
value = transform.backward(value, *rv.owner.inputs)
379+
value.name = rv.name
380+
381+
replacements[rv] = value
382+
# Also walk the graph of the value variable to make any additional
383+
# replacements if that is not a simple input variable
384+
return [value]
385+
386+
graphs, _ = _replace_rvs_in_graphs(
387+
graphs,
388+
replacement_fn=poulate_replacements,
323389
**kwargs,
324390
)
325391

pymc/backends/arviz.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def find_observations(model: "Model") -> Dict[str, Var]:
4747
"""If there are observations available, return them as a dictionary."""
4848
observations = {}
4949
for obs in model.observed_RVs:
50-
aux_obs = getattr(obs.tag, "observations", None)
50+
aux_obs = model.rvs_to_values.get(obs, None)
5151
if aux_obs is not None:
5252
try:
5353
obs_data = extract_obs_data(aux_obs)
@@ -261,7 +261,7 @@ def log_likelihood_vals_point(self, point, var, log_like_fun):
261261

262262
if isinstance(var.owner.op, (AdvancedIncSubtensor, AdvancedIncSubtensor1)):
263263
try:
264-
obs_data = extract_obs_data(var.tag.observations)
264+
obs_data = extract_obs_data(self.model.rvs_to_values[var])
265265
except TypeError:
266266
warnings.warn(f"Could not extract data from symbolic observation {var}")
267267

pymc/distributions/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
logcdf,
1717
logp,
1818
joint_logp,
19-
joint_logpt,
2019
)
2120

2221
from pymc.distributions.bound import Bound
@@ -199,7 +198,6 @@
199198
"Censored",
200199
"CAR",
201200
"PolyaGamma",
202-
"joint_logpt",
203201
"joint_logp",
204202
"logp",
205203
"logcdf",

pymc/distributions/distribution.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@
4949
shape_from_dims,
5050
)
5151
from pymc.printing import str_for_dist
52-
from pymc.util import UNSET
52+
from pymc.util import UNSET, _add_future_warning_tag
5353
from pymc.vartypes import string_types
5454

5555
__all__ = [
@@ -371,6 +371,7 @@ def dist(
371371
rv_out.logp = _make_nice_attr_error("rv.logp(x)", "pm.logp(rv, x)")
372372
rv_out.logcdf = _make_nice_attr_error("rv.logcdf(x)", "pm.logcdf(rv, x)")
373373
rv_out.random = _make_nice_attr_error("rv.random()", "pm.draw(rv)")
374+
_add_future_warning_tag(rv_out)
374375
return rv_out
375376

376377

pymc/distributions/logprob.py

Lines changed: 75 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -25,26 +25,18 @@
2525
from aeppl.logprob import logcdf as logcdf_aeppl
2626
from aeppl.logprob import logprob as logp_aeppl
2727
from aeppl.tensor import MeasurableJoin
28-
from aeppl.transforms import TransformValuesRewrite
28+
from aeppl.transforms import RVTransform, TransformValuesRewrite
2929
from aesara import tensor as at
3030
from aesara.graph.basic import graph_inputs, io_toposort
3131
from aesara.tensor.random.op import RandomVariable
32-
from aesara.tensor.subtensor import (
33-
AdvancedIncSubtensor,
34-
AdvancedIncSubtensor1,
35-
AdvancedSubtensor,
36-
AdvancedSubtensor1,
37-
IncSubtensor,
38-
Subtensor,
39-
)
4032
from aesara.tensor.var import TensorVariable
4133

4234
from pymc.aesaraf import constant_fold, floatX
4335

36+
TOTAL_SIZE = Union[int, Sequence[int], None]
4437

45-
def _get_scaling(
46-
total_size: Optional[Union[int, Sequence[int]]], shape, ndim: int
47-
) -> TensorVariable:
38+
39+
def _get_scaling(total_size: TOTAL_SIZE, shape, ndim: int) -> TensorVariable:
4840
"""
4941
Gets scaling constant for logp.
5042
@@ -112,22 +104,26 @@ def _get_scaling(
112104
return at.as_tensor(coef, dtype=aesara.config.floatX)
113105

114106

115-
subtensor_types = (
116-
AdvancedIncSubtensor,
117-
AdvancedIncSubtensor1,
118-
AdvancedSubtensor,
119-
AdvancedSubtensor1,
120-
IncSubtensor,
121-
Subtensor,
122-
)
123-
107+
def _check_no_rvs(logp_terms: Sequence[TensorVariable]):
108+
# Raise if there are unexpected RandomVariables in the logp graph
109+
# Only SimulatorRVs are allowed
110+
from pymc.distributions.simulator import SimulatorRV
124111

125-
def joint_logpt(*args, **kwargs):
126-
warnings.warn(
127-
"joint_logpt has been deprecated. Use joint_logp instead.",
128-
FutureWarning,
129-
)
130-
return joint_logp(*args, **kwargs)
112+
unexpected_rv_nodes = [
113+
node
114+
for node in aesara.graph.ancestors(logp_terms)
115+
if (
116+
node.owner
117+
and isinstance(node.owner.op, RandomVariable)
118+
and not isinstance(node.owner.op, SimulatorRV)
119+
)
120+
]
121+
if unexpected_rv_nodes:
122+
raise ValueError(
123+
f"Random variables detected in the logp graph: {unexpected_rv_nodes}.\n"
124+
"This can happen when DensityDist logp or Interval transform functions "
125+
"reference nonlocal variables."
126+
)
131127

132128

133129
def joint_logp(
@@ -169,6 +165,10 @@ def joint_logp(
169165
Sum the log-likelihood or return each term as a separate list item.
170166
171167
"""
168+
warnings.warn(
169+
"joint_logp has been deprecated, use model.logp instead",
170+
FutureWarning,
171+
)
172172
# TODO: In future when we drop support for tag.value_var most of the following
173173
# logic can be removed and logp can just be a wrapper function that calls aeppl's
174174
# joint_logprob directly.
@@ -241,33 +241,15 @@ def joint_logp(
241241
**kwargs,
242242
)
243243

244-
# Raise if there are unexpected RandomVariables in the logp graph
245-
# Only SimulatorRVs are allowed
246-
from pymc.distributions.simulator import SimulatorRV
247-
248-
unexpected_rv_nodes = [
249-
node
250-
for node in aesara.graph.ancestors(list(temp_logp_var_dict.values()))
251-
if (
252-
node.owner
253-
and isinstance(node.owner.op, RandomVariable)
254-
and not isinstance(node.owner.op, SimulatorRV)
255-
)
256-
]
257-
if unexpected_rv_nodes:
258-
raise ValueError(
259-
f"Random variables detected in the logp graph: {unexpected_rv_nodes}.\n"
260-
"This can happen when DensityDist logp or Interval transform functions "
261-
"reference nonlocal variables."
262-
)
263-
264244
# aeppl returns the logp for every single value term we provided to it. This includes
265245
# the extra values we plugged in above, so we filter those we actually wanted in the
266246
# same order they were given in.
267247
logp_var_dict = {}
268248
for value_var in rv_values.values():
269249
logp_var_dict[value_var] = temp_logp_var_dict[value_var]
270250

251+
_check_no_rvs(list(logp_var_dict.values()))
252+
271253
if scaling:
272254
for value_var in logp_var_dict.keys():
273255
if value_var in rv_scalings:
@@ -281,6 +263,52 @@ def joint_logp(
281263
return logp_var
282264

283265

266+
def _joint_logp(
267+
rvs: Sequence[TensorVariable],
268+
*,
269+
rvs_to_values: Dict[TensorVariable, TensorVariable],
270+
rvs_to_transforms: Dict[TensorVariable, RVTransform],
271+
jacobian: bool = True,
272+
rvs_to_total_sizes: Dict[TensorVariable, TOTAL_SIZE],
273+
**kwargs,
274+
) -> List[TensorVariable]:
275+
"""Thin wrapper around aeppl.factorized_joint_logprob, extended with PyMC specific
276+
concerns such as transforms, jacobian, and scaling"""
277+
278+
transform_rewrite = None
279+
values_to_transforms = {
280+
rvs_to_values[rv]: transform
281+
for rv, transform in rvs_to_transforms.items()
282+
if transform is not None
283+
}
284+
if values_to_transforms:
285+
# There seems to be an incorrect type hint in TransformValuesRewrite
286+
transform_rewrite = TransformValuesRewrite(values_to_transforms) # type: ignore
287+
288+
temp_logp_terms = factorized_joint_logprob(
289+
rvs_to_values,
290+
extra_rewrites=transform_rewrite,
291+
use_jacobian=jacobian,
292+
**kwargs,
293+
)
294+
295+
# aeppl returns the logp for every single value term we provided to it. This includes
296+
# the extra values we plugged in above, so we filter those we actually wanted in the
297+
# same order they were given in.
298+
logp_terms = {}
299+
for rv in rvs:
300+
value_var = rvs_to_values[rv]
301+
logp_term = temp_logp_terms[value_var]
302+
total_size = rvs_to_total_sizes.get(rv, None)
303+
if total_size is not None:
304+
scaling = _get_scaling(total_size, value_var.shape, value_var.ndim)
305+
logp_term *= scaling
306+
logp_terms[value_var] = logp_term
307+
308+
_check_no_rvs(list(logp_terms.values()))
309+
return list(logp_terms.values())
310+
311+
284312
def logp(rv: TensorVariable, value) -> TensorVariable:
285313
"""Return the log-probability graph of a Random Variable"""
286314

0 commit comments

Comments
 (0)