Skip to content

Commit d0474ed

Browse files
authored
Merge branch 'pymc-devs:main' into flaky_eulermaruyama_tests
2 parents 504f836 + 4acd98e commit d0474ed

19 files changed

+201
-205
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ repos:
1919
- id: isort
2020
name: isort
2121
- repo: https://github.com/asottile/pyupgrade
22-
rev: v3.2.0
22+
rev: v3.2.2
2323
hooks:
2424
- id: pyupgrade
2525
args: [--py37-plus]

pymc/backends/arviz.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -215,12 +215,11 @@ def __init__(
215215
}
216216

217217
self.dims = {} if dims is None else dims
218-
if hasattr(self.model, "RV_dims"):
219-
model_dims = {
220-
var_name: [dim for dim in dims if dim is not None]
221-
for var_name, dims in self.model.RV_dims.items()
222-
}
223-
self.dims = {**model_dims, **self.dims}
218+
model_dims = {
219+
var_name: [dim for dim in dims if dim is not None]
220+
for var_name, dims in self.model.named_vars_to_dims.items()
221+
}
222+
self.dims = {**model_dims, **self.dims}
224223
if sample_dims is None:
225224
sample_dims = ["chain", "draw"]
226225
self.sample_dims = sample_dims

pymc/data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -718,6 +718,6 @@ def Data(
718718
length=xshape[d],
719719
)
720720

721-
model.add_random_variable(x, dims=dims)
721+
model.add_named_variable(x, dims=dims)
722722

723723
return x

pymc/initial_point.py

Lines changed: 3 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -52,29 +52,11 @@ def convert_str_to_rv_dict(
5252
return initvals
5353

5454

55-
def filter_rvs_to_jitter(step) -> Set[TensorVariable]:
56-
"""Find the set of RVs for which the responsible step methods ask for
57-
the addition of jitter to the initial point.
58-
59-
Parameters
60-
----------
61-
step : BlockedStep or CompoundStep
62-
One or many step methods that were assigned model variables.
63-
64-
Returns
65-
-------
66-
rvs_to_jitter : set
67-
The random variables for which jitter should be added.
68-
"""
69-
# TODO: implement this
70-
return set()
71-
72-
7355
def make_initial_point_fns_per_chain(
7456
*,
7557
model,
7658
overrides: Optional[Union[StartDict, Sequence[Optional[StartDict]]]],
77-
jitter_rvs: Set[TensorVariable],
59+
jitter_rvs: Optional[Set[TensorVariable]] = None,
7860
chains: int,
7961
) -> List[Callable]:
8062
"""Create an initial point function for each chain, as defined by initvals
@@ -87,7 +69,7 @@ def make_initial_point_fns_per_chain(
8769
overrides : optional, list or dict
8870
Initial value strategy overrides that should take precedence over the defaults from the model.
8971
A sequence of None or dicts will be treated as chain-wise strategies and must have the same length as `seeds`.
90-
jitter_rvs : set
72+
jitter_rvs : set, optional
9173
Random variable tensors for which U(-1, 1) jitter shall be applied.
9274
(To the transformed space if applicable.)
9375
@@ -151,7 +133,7 @@ def make_initial_point_fn(
151133

152134
sdict_overrides = convert_str_to_rv_dict(model, overrides or {})
153135
initval_strats = {
154-
**model.initial_values,
136+
**model.rvs_to_initial_values,
155137
**sdict_overrides,
156138
}
157139

pymc/model.py

Lines changed: 49 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -550,35 +550,33 @@ def __init__(
550550
self.name = self._validate_name(name)
551551
self.check_bounds = check_bounds
552552

553-
self._initial_values: Dict[TensorVariable, Optional[Union[np.ndarray, Variable, str]]] = {}
554-
555553
if self.parent is not None:
556554
self.named_vars = treedict(parent=self.parent.named_vars)
555+
self.named_vars_to_dims = treedict(parent=self.parent.named_vars_to_dims)
557556
self.values_to_rvs = treedict(parent=self.parent.values_to_rvs)
558557
self.rvs_to_values = treedict(parent=self.parent.rvs_to_values)
559558
self.rvs_to_transforms = treedict(parent=self.parent.rvs_to_transforms)
560559
self.rvs_to_total_sizes = treedict(parent=self.parent.rvs_to_total_sizes)
560+
self.rvs_to_initial_values = treedict(parent=self.parent.rvs_to_initial_values)
561561
self.free_RVs = treelist(parent=self.parent.free_RVs)
562562
self.observed_RVs = treelist(parent=self.parent.observed_RVs)
563-
self.auto_deterministics = treelist(parent=self.parent.auto_deterministics)
564563
self.deterministics = treelist(parent=self.parent.deterministics)
565564
self.potentials = treelist(parent=self.parent.potentials)
566565
self._coords = self.parent._coords
567-
self._RV_dims = treedict(parent=self.parent._RV_dims)
568566
self._dim_lengths = self.parent._dim_lengths
569567
else:
570568
self.named_vars = treedict()
569+
self.named_vars_to_dims = treedict()
571570
self.values_to_rvs = treedict()
572571
self.rvs_to_values = treedict()
573572
self.rvs_to_transforms = treedict()
574573
self.rvs_to_total_sizes = treedict()
574+
self.rvs_to_initial_values = treedict()
575575
self.free_RVs = treelist()
576576
self.observed_RVs = treelist()
577-
self.auto_deterministics = treelist()
578577
self.deterministics = treelist()
579578
self.potentials = treelist()
580579
self._coords = {}
581-
self._RV_dims = treedict()
582580
self._dim_lengths = {}
583581
self.add_coords(coords)
584582

@@ -972,7 +970,11 @@ def RV_dims(self) -> Dict[str, Tuple[Union[str, None], ...]]:
972970
973971
Entries in the tuples may be ``None``, if the RV dimension was not given a name.
974972
"""
975-
return self._RV_dims
973+
warnings.warn(
974+
"Model.RV_dims is deprecated. User Model.named_vars_to_dims instead.",
975+
FutureWarning,
976+
)
977+
return self.named_vars_to_dims
976978

977979
@property
978980
def coords(self) -> Dict[str, Union[Tuple, None]]:
@@ -1124,15 +1126,18 @@ def initial_values(self) -> Dict[TensorVariable, Optional[Union[np.ndarray, Vari
11241126
Keys are the random variables (as returned by e.g. ``pm.Uniform()``) and
11251127
values are the numeric/symbolic initial values, strings denoting the strategy to get them, or None.
11261128
"""
1127-
return self._initial_values
1129+
warnings.warn(
1130+
"Model.initial_values is deprecated. Use Model.rvs_to_initial_values instead."
1131+
)
1132+
return self.rvs_to_initial_values
11281133

11291134
def set_initval(self, rv_var, initval):
11301135
"""Sets an initial value (strategy) for a random variable."""
11311136
if initval is not None and not isinstance(initval, (Variable, str)):
11321137
# Convert scalars or array-like inputs to ndarrays
11331138
initval = rv_var.type.filter(initval)
11341139

1135-
self.initial_values[rv_var] = initval
1140+
self.rvs_to_initial_values[rv_var] = initval
11361141

11371142
def set_data(
11381143
self,
@@ -1167,7 +1172,7 @@ def set_data(
11671172
if isinstance(values, list):
11681173
values = np.array(values)
11691174
values = convert_observed_data(values)
1170-
dims = self.RV_dims.get(name, None) or ()
1175+
dims = self.named_vars_to_dims.get(name, None) or ()
11711176
coords = coords or {}
11721177

11731178
if values.ndim != shared_object.ndim:
@@ -1257,7 +1262,7 @@ def set_data(
12571262
shared_object.set_value(values)
12581263

12591264
def register_rv(
1260-
self, rv_var, name, data=None, total_size=None, dims=None, transform=UNSET, initval=None
1265+
self, rv_var, name, observed=None, total_size=None, dims=None, transform=UNSET, initval=None
12611266
):
12621267
"""Register an (un)observed random variable with the model.
12631268
@@ -1266,9 +1271,8 @@ def register_rv(
12661271
rv_var: TensorVariable
12671272
name: str
12681273
Intended name for the model variable.
1269-
data: array_like (optional)
1270-
If data is provided, the variable is observed. If None,
1271-
the variable is unobserved.
1274+
observed: array_like (optional)
1275+
Data values for observed variables.
12721276
total_size: scalar
12731277
upscales logp of variable with ``coef = total_size/var.shape[0]``
12741278
dims: tuple
@@ -1295,31 +1299,31 @@ def register_rv(
12951299
if dname not in self.dim_lengths:
12961300
self.add_coord(dname, values=None, length=rv_var.shape[d])
12971301

1298-
if data is None:
1302+
if observed is None:
12991303
self.free_RVs.append(rv_var)
13001304
self.create_value_var(rv_var, transform)
1301-
self.add_random_variable(rv_var, dims)
1305+
self.add_named_variable(rv_var, dims)
13021306
self.set_initval(rv_var, initval)
13031307
else:
13041308
if (
1305-
isinstance(data, Variable)
1306-
and not isinstance(data, (GenTensorVariable, Minibatch))
1307-
and data.owner is not None
1309+
isinstance(observed, Variable)
1310+
and not isinstance(observed, (GenTensorVariable, Minibatch))
1311+
and observed.owner is not None
13081312
# The only Aesara operation we allow on observed data is type casting
13091313
# Although we could allow for any graph that does not depend on other RVs
13101314
and not (
1311-
isinstance(data.owner.op, Elemwise)
1312-
and isinstance(data.owner.op.scalar_op, Cast)
1315+
isinstance(observed.owner.op, Elemwise)
1316+
and isinstance(observed.owner.op.scalar_op, Cast)
13131317
)
13141318
):
13151319
raise TypeError(
13161320
"Variables that depend on other nodes cannot be used for observed data."
1317-
f"The data variable was: {data}"
1321+
f"The data variable was: {observed}"
13181322
)
13191323

13201324
# `rv_var` is potentially changed by `make_obs_var`,
13211325
# for example into a new graph for imputation of missing data.
1322-
rv_var = self.make_obs_var(rv_var, data, dims, transform)
1326+
rv_var = self.make_obs_var(rv_var, observed, dims, transform)
13231327

13241328
return rv_var
13251329

@@ -1425,14 +1429,15 @@ def make_obs_var(
14251429
observed_rv_var.tag.observations = nonmissing_data
14261430

14271431
self.create_value_var(observed_rv_var, transform=None, value_var=nonmissing_data)
1428-
self.add_random_variable(observed_rv_var)
1432+
self.add_named_variable(observed_rv_var)
14291433
self.observed_RVs.append(observed_rv_var)
14301434

14311435
# Create deterministic that combines observed and missing
1436+
# Note: This can widely increase memory consumption during sampling for large datasets
14321437
rv_var = at.zeros(data.shape)
14331438
rv_var = at.set_subtensor(rv_var[mask.nonzero()], missing_rv_var)
14341439
rv_var = at.set_subtensor(rv_var[antimask_idx], observed_rv_var)
1435-
rv_var = Deterministic(name, rv_var, self, dims, auto=True)
1440+
rv_var = Deterministic(name, rv_var, self, dims)
14361441

14371442
else:
14381443
if sps.issparse(data):
@@ -1441,7 +1446,7 @@ def make_obs_var(
14411446
data = at.as_tensor_variable(data, name=name)
14421447
rv_var.tag.observations = data
14431448
self.create_value_var(rv_var, transform=None, value_var=data)
1444-
self.add_random_variable(rv_var, dims)
1449+
self.add_named_variable(rv_var, dims)
14451450
self.observed_RVs.append(rv_var)
14461451

14471452
return rv_var
@@ -1481,15 +1486,18 @@ def create_value_var(
14811486
value_var.tag.test_value = transform.forward(
14821487
value_var, *rv_var.owner.inputs
14831488
).tag.test_value
1484-
self.named_vars[value_var.name] = value_var
14851489
self.rvs_to_transforms[rv_var] = transform
14861490
self.rvs_to_values[rv_var] = value_var
14871491
self.values_to_rvs[value_var] = rv_var
14881492

14891493
return value_var
14901494

1491-
def add_random_variable(self, var, dims: Optional[Tuple[Union[str, None], ...]] = None):
1492-
"""Add a random variable to the named variables of the model."""
1495+
def add_named_variable(self, var, dims: Optional[Tuple[Union[str, None], ...]] = None):
1496+
"""Add a random graph variable to the named variables of the model.
1497+
1498+
This can include several types of variables such basic_RVs, Data, Deterministics,
1499+
and Potentials.
1500+
"""
14931501
if self.named_vars.tree_contains(var.name):
14941502
raise ValueError(f"Variable name {var.name} already exists.")
14951503

@@ -1501,7 +1509,7 @@ def add_random_variable(self, var, dims: Optional[Tuple[Union[str, None], ...]]
15011509
raise ValueError(f"Dimension {dim} is not specified in `coords`.")
15021510
if any(var.name == dim for dim in dims):
15031511
raise ValueError(f"Variable `{var.name}` has the same name as its dimension label.")
1504-
self._RV_dims[var.name] = dims
1512+
self.named_vars_to_dims[var.name] = dims
15051513

15061514
self.named_vars[var.name] = var
15071515
if not hasattr(self, self.name_of(var.name)):
@@ -1705,14 +1713,17 @@ def check_start_vals(self, start):
17051713
None
17061714
"""
17071715
start_points = [start] if isinstance(start, dict) else start
1716+
1717+
value_names_to_dtypes = {value.name: value.dtype for value in self.value_vars}
1718+
value_names_set = set(value_names_to_dtypes.keys())
17081719
for elem in start_points:
17091720

17101721
for k, v in elem.items():
1711-
elem[k] = np.asarray(v, dtype=self[k].dtype)
1722+
elem[k] = np.asarray(v, dtype=value_names_to_dtypes[k])
17121723

1713-
if not set(elem.keys()).issubset(self.named_vars.keys()):
1714-
extra_keys = ", ".join(set(elem.keys()) - set(self.named_vars.keys()))
1715-
valid_keys = ", ".join(self.named_vars.keys())
1724+
if not set(elem.keys()).issubset(value_names_set):
1725+
extra_keys = ", ".join(set(elem.keys()) - value_names_set)
1726+
valid_keys = ", ".join(value_names_set)
17161727
raise KeyError(
17171728
"Some start parameters do not appear in the model!\n"
17181729
f"Valid keys are: {valid_keys}, but {extra_keys} was supplied"
@@ -1899,7 +1910,7 @@ def Point(*args, filter_model_vars=False, **kwargs) -> Dict[str, np.ndarray]:
18991910
}
19001911

19011912

1902-
def Deterministic(name, var, model=None, dims=None, auto=False):
1913+
def Deterministic(name, var, model=None, dims=None):
19031914
"""Create a named deterministic variable.
19041915
19051916
Deterministic nodes are only deterministic given all of their inputs, i.e.
@@ -1962,11 +1973,8 @@ def Deterministic(name, var, model=None, dims=None, auto=False):
19621973
"""
19631974
model = modelcontext(model)
19641975
var = var.copy(model.name_for(name))
1965-
if auto:
1966-
model.auto_deterministics.append(var)
1967-
else:
1968-
model.deterministics.append(var)
1969-
model.add_random_variable(var, dims)
1976+
model.deterministics.append(var)
1977+
model.add_named_variable(var, dims)
19701978

19711979
from pymc.printing import str_for_potential_or_deterministic
19721980

@@ -1998,7 +2006,7 @@ def Potential(name, var, model=None):
19982006
model = modelcontext(model)
19992007
var.name = model.name_for(name)
20002008
model.potentials.append(var)
2001-
model.add_random_variable(var)
2009+
model.add_named_variable(var)
20022010

20032011
from pymc.printing import str_for_potential_or_deterministic
20042012

pymc/model_graph.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -200,10 +200,10 @@ def get_plates(self, var_names: Optional[Iterable[VarName]] = None) -> Dict[str,
200200

201201
for var_name in self.vars_to_plot(var_names):
202202
v = self.model[var_name]
203-
if var_name in self.model.RV_dims:
203+
if var_name in self.model.named_vars_to_dims:
204204
plate_label = " x ".join(
205205
f"{d} ({self._eval(self.model.dim_lengths[d])})"
206-
for d in self.model.RV_dims[var_name]
206+
for d in self.model.named_vars_to_dims[var_name]
207207
)
208208
else:
209209
plate_label = " x ".join(map(str, self._eval(v.shape)))

0 commit comments

Comments
 (0)