Skip to content

Commit 00ffb3b

Browse files
michaelosthegetwiecki
authored andcommitted
Extract initval evaluation into its own method
1 parent a6a1dce commit 00ffb3b

File tree

1 file changed

+67
-35
lines changed

1 file changed

+67
-35
lines changed

pymc3/model.py

Lines changed: 67 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@
5959
from pymc3.blocking import DictToArrayBijection, RaveledVars
6060
from pymc3.data import GenTensorVariable, Minibatch
6161
from pymc3.distributions import logp_transform, logpt, logpt_sum
62+
from pymc3.distributions.transforms import Transform
6263
from pymc3.exceptions import ImputationWarning, SamplingError, ShapeError
6364
from pymc3.math import flatten_list
6465
from pymc3.util import UNSET, WithMemoization, get_var_name, treedict, treelist
@@ -954,46 +955,77 @@ def set_initval(self, rv_var, initval):
954955
transform = getattr(rv_value_var.tag, "transform", None)
955956

956957
if initval is None or transform:
957-
# Sample/evaluate this using the existing initial values, and
958-
# with the least effect on the RNGs involved (i.e. no in-placing)
958+
initval = self._eval_initval(rv_var, initval, test_value, transform)
959959

960-
mode = get_mode(None)
961-
opt_qry = mode.provided_optimizer.excluding("random_make_inplace")
962-
mode = Mode(linker=mode.linker, optimizer=opt_qry)
960+
self.initial_values[rv_value_var] = initval
961+
962+
def _eval_initval(
963+
self,
964+
rv_var: TensorVariable,
965+
initval: Optional[Variable],
966+
test_value: Optional[np.ndarray],
967+
transform: Optional[Transform],
968+
) -> np.ndarray:
969+
"""Sample/evaluate an initial value using the existing initial values,
970+
and with the least effect on the RNGs involved (i.e. no in-placing).
971+
972+
Parameters
973+
----------
974+
rv_var : TensorVariable
975+
The model variable the initival belongs to.
976+
initval : Variable or None
977+
The initial value to be evaluated.
978+
If `None` a random draw will be made.
979+
test_value : optional, ndarray
980+
Fallback option if initval is None and random draws are not implemented.
981+
This is relevant for pm.Flat or pm.HalfFlat distributions and is subject
982+
to ongoing refactoring of the initval API.
983+
transform : optional, Transform
984+
A transformation associated with the random variable.
985+
Transformations are automatically applied to initial values.
986+
987+
Returns
988+
-------
989+
initval : np.ndarray
990+
Numeric (transformed) initial value.
991+
"""
992+
mode = get_mode(None)
993+
opt_qry = mode.provided_optimizer.excluding("random_make_inplace")
994+
mode = Mode(linker=mode.linker, optimizer=opt_qry)
995+
996+
if transform:
997+
if initval is not None:
998+
value = initval
999+
else:
1000+
value = rv_var
1001+
rv_var = at.as_tensor_variable(transform.forward(rv_var, value))
9631002

1003+
def initval_to_rvval(value_var, value):
1004+
rv_var = self.values_to_rvs[value_var]
1005+
initval = value_var.type.make_constant(value)
1006+
transform = getattr(value_var.tag, "transform", None)
9641007
if transform:
965-
value = initval if initval is not None else rv_var
966-
rv_var = at.as_tensor_variable(transform.forward(rv_var, value))
1008+
return transform.backward(rv_var, initval)
1009+
else:
1010+
return initval
9671011

968-
def initval_to_rvval(value_var, value):
969-
rv_var = self.values_to_rvs[value_var]
970-
initval = value_var.type.make_constant(value)
971-
transform = getattr(value_var.tag, "transform", None)
972-
if transform:
973-
return transform.backward(rv_var, initval)
974-
else:
975-
return initval
976-
977-
givens = {
978-
self.values_to_rvs[k]: initval_to_rvval(k, v)
979-
for k, v in self.initial_values.items()
980-
}
981-
initval_fn = aesara.function(
982-
[], rv_var, mode=mode, givens=givens, on_unused_input="ignore"
983-
)
984-
try:
985-
initval = initval_fn()
986-
except NotImplementedError as ex:
987-
if "Cannot sample from" in ex.args[0]:
988-
# The RV does not have a random number generator.
989-
# Our last chance is to take the test_value.
990-
# Note that this is a workaround for Flat and HalfFlat
991-
# until an initval default mechanism is implemented (#4752).
992-
initval = test_value
993-
else:
994-
raise
1012+
givens = {
1013+
self.values_to_rvs[k]: initval_to_rvval(k, v) for k, v in self.initial_values.items()
1014+
}
1015+
initval_fn = aesara.function([], rv_var, mode=mode, givens=givens, on_unused_input="ignore")
1016+
try:
1017+
initval = initval_fn()
1018+
except NotImplementedError as ex:
1019+
if "Cannot sample from" in ex.args[0]:
1020+
# The RV does not have a random number generator.
1021+
# Our last chance is to take the test_value.
1022+
# Note that this is a workaround for Flat and HalfFlat
1023+
# until an initval default mechanism is implemented (#4752).
1024+
initval = test_value
1025+
else:
1026+
raise
9951027

996-
self.initial_values[rv_value_var] = initval
1028+
return initval
9971029

9981030
def next_rng(self) -> RandomStateSharedVariable:
9991031
"""Generate a new ``RandomStateSharedVariable``.

0 commit comments

Comments
 (0)