|
59 | 59 | from pymc3.blocking import DictToArrayBijection, RaveledVars
|
60 | 60 | from pymc3.data import GenTensorVariable, Minibatch
|
61 | 61 | from pymc3.distributions import logp_transform, logpt, logpt_sum
|
| 62 | +from pymc3.distributions.transforms import Transform |
62 | 63 | from pymc3.exceptions import ImputationWarning, SamplingError, ShapeError
|
63 | 64 | from pymc3.math import flatten_list
|
64 | 65 | from pymc3.util import UNSET, WithMemoization, get_var_name, treedict, treelist
|
@@ -954,46 +955,77 @@ def set_initval(self, rv_var, initval):
|
954 | 955 | transform = getattr(rv_value_var.tag, "transform", None)
|
955 | 956 |
|
956 | 957 | if initval is None or transform:
|
957 |
| - # Sample/evaluate this using the existing initial values, and |
958 |
| - # with the least effect on the RNGs involved (i.e. no in-placing) |
| 958 | + initval = self._eval_initval(rv_var, initval, test_value, transform) |
959 | 959 |
|
960 |
| - mode = get_mode(None) |
961 |
| - opt_qry = mode.provided_optimizer.excluding("random_make_inplace") |
962 |
| - mode = Mode(linker=mode.linker, optimizer=opt_qry) |
| 960 | + self.initial_values[rv_value_var] = initval |
| 961 | + |
| 962 | + def _eval_initval( |
| 963 | + self, |
| 964 | + rv_var: TensorVariable, |
| 965 | + initval: Optional[Variable], |
| 966 | + test_value: Optional[np.ndarray], |
| 967 | + transform: Optional[Transform], |
| 968 | + ) -> np.ndarray: |
| 969 | + """Sample/evaluate an initial value using the existing initial values, |
| 970 | + and with the least effect on the RNGs involved (i.e. no in-placing). |
| 971 | +
|
| 972 | + Parameters |
| 973 | + ---------- |
| 974 | + rv_var : TensorVariable |
| 975 | + The model variable the initival belongs to. |
| 976 | + initval : Variable or None |
| 977 | + The initial value to be evaluated. |
| 978 | + If `None` a random draw will be made. |
| 979 | + test_value : optional, ndarray |
| 980 | + Fallback option if initval is None and random draws are not implemented. |
| 981 | + This is relevant for pm.Flat or pm.HalfFlat distributions and is subject |
| 982 | + to ongoing refactoring of the initval API. |
| 983 | + transform : optional, Transform |
| 984 | + A transformation associated with the random variable. |
| 985 | + Transformations are automatically applied to initial values. |
| 986 | +
|
| 987 | + Returns |
| 988 | + ------- |
| 989 | + initval : np.ndarray |
| 990 | + Numeric (transformed) initial value. |
| 991 | + """ |
| 992 | + mode = get_mode(None) |
| 993 | + opt_qry = mode.provided_optimizer.excluding("random_make_inplace") |
| 994 | + mode = Mode(linker=mode.linker, optimizer=opt_qry) |
| 995 | + |
| 996 | + if transform: |
| 997 | + if initval is not None: |
| 998 | + value = initval |
| 999 | + else: |
| 1000 | + value = rv_var |
| 1001 | + rv_var = at.as_tensor_variable(transform.forward(rv_var, value)) |
963 | 1002 |
|
| 1003 | + def initval_to_rvval(value_var, value): |
| 1004 | + rv_var = self.values_to_rvs[value_var] |
| 1005 | + initval = value_var.type.make_constant(value) |
| 1006 | + transform = getattr(value_var.tag, "transform", None) |
964 | 1007 | if transform:
|
965 |
| - value = initval if initval is not None else rv_var |
966 |
| - rv_var = at.as_tensor_variable(transform.forward(rv_var, value)) |
| 1008 | + return transform.backward(rv_var, initval) |
| 1009 | + else: |
| 1010 | + return initval |
967 | 1011 |
|
968 |
| - def initval_to_rvval(value_var, value): |
969 |
| - rv_var = self.values_to_rvs[value_var] |
970 |
| - initval = value_var.type.make_constant(value) |
971 |
| - transform = getattr(value_var.tag, "transform", None) |
972 |
| - if transform: |
973 |
| - return transform.backward(rv_var, initval) |
974 |
| - else: |
975 |
| - return initval |
976 |
| - |
977 |
| - givens = { |
978 |
| - self.values_to_rvs[k]: initval_to_rvval(k, v) |
979 |
| - for k, v in self.initial_values.items() |
980 |
| - } |
981 |
| - initval_fn = aesara.function( |
982 |
| - [], rv_var, mode=mode, givens=givens, on_unused_input="ignore" |
983 |
| - ) |
984 |
| - try: |
985 |
| - initval = initval_fn() |
986 |
| - except NotImplementedError as ex: |
987 |
| - if "Cannot sample from" in ex.args[0]: |
988 |
| - # The RV does not have a random number generator. |
989 |
| - # Our last chance is to take the test_value. |
990 |
| - # Note that this is a workaround for Flat and HalfFlat |
991 |
| - # until an initval default mechanism is implemented (#4752). |
992 |
| - initval = test_value |
993 |
| - else: |
994 |
| - raise |
| 1012 | + givens = { |
| 1013 | + self.values_to_rvs[k]: initval_to_rvval(k, v) for k, v in self.initial_values.items() |
| 1014 | + } |
| 1015 | + initval_fn = aesara.function([], rv_var, mode=mode, givens=givens, on_unused_input="ignore") |
| 1016 | + try: |
| 1017 | + initval = initval_fn() |
| 1018 | + except NotImplementedError as ex: |
| 1019 | + if "Cannot sample from" in ex.args[0]: |
| 1020 | + # The RV does not have a random number generator. |
| 1021 | + # Our last chance is to take the test_value. |
| 1022 | + # Note that this is a workaround for Flat and HalfFlat |
| 1023 | + # until an initval default mechanism is implemented (#4752). |
| 1024 | + initval = test_value |
| 1025 | + else: |
| 1026 | + raise |
995 | 1027 |
|
996 |
| - self.initial_values[rv_value_var] = initval |
| 1028 | + return initval |
997 | 1029 |
|
998 | 1030 | def next_rng(self) -> RandomStateSharedVariable:
|
999 | 1031 | """Generate a new ``RandomStateSharedVariable``.
|
|
0 commit comments