Skip to content

Commit b7cb362

Browse files
Make initial_values a property and add/fix docstrings and type hints
1 parent a5044fc commit b7cb362

File tree

2 files changed

+20
-7
lines changed

2 files changed

+20
-7
lines changed

pymc3/model.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -647,7 +647,7 @@ def __init__(
647647

648648
# The sequence of model-generated RNGs
649649
self.rng_seq = []
650-
self.initial_values = {}
650+
self._initial_values = {}
651651

652652
if self.parent is not None:
653653
self.named_vars = treedict(parent=self.parent.named_vars)
@@ -911,17 +911,28 @@ def independent_vars(self):
911911
return inputvars(self.unobserved_RVs)
912912

913913
@property
914-
def test_point(self):
914+
def test_point(self) -> Dict[str, np.ndarray]:
915+
"""Deprecated alias for `Model.initial_point`."""
915916
warnings.warn(
916917
"`Model.test_point` has been deprecated. Use `Model.initial_point` instead.",
917918
DeprecationWarning,
918919
)
919920
return self.initial_point
920921

921922
@property
922-
def initial_point(self):
923+
def initial_point(self) -> Dict[str, np.ndarray]:
924+
"""Maps names of variables to initial values."""
923925
return Point(list(self.initial_values.items()), model=self)
924926

927+
@property
928+
def initial_values(self) -> Dict[TensorVariable, np.ndarray]:
929+
"""Maps transformed variables to initial values.
930+
931+
⚠ The keys are NOT the objects returned by, `pm.Normal(...)`.
932+
For a name-based dictionary use the `initial_point` property.
933+
"""
934+
return self._initial_values
935+
925936
@property
926937
def disc_vars(self):
927938
"""All the discrete variables in the model"""
@@ -1529,7 +1540,7 @@ def update_start_vals(self, a: Dict[str, np.ndarray], b: Dict[str, np.ndarray]):
15291540
r"""Update point `a` with `b`, without overwriting existing keys.
15301541
15311542
Values specified for transformed variables in `a` will be recomputed
1532-
conditional on the valures of `b` and stored in `b`.
1543+
conditional on the values of `b` and stored in `b`.
15331544
15341545
"""
15351546
# TODO FIXME XXX: If we're going to incrementally update transformed
@@ -1716,14 +1727,16 @@ def fastfn(outs, mode=None, model=None):
17161727
return model.fastfn(outs, mode)
17171728

17181729

1719-
def Point(*args, filter_model_vars=False, **kwargs):
1730+
def Point(*args, filter_model_vars=False, **kwargs) -> Dict[str, np.ndarray]:
17201731
"""Build a point. Uses same args as dict() does.
17211732
Filters out variables not in the model. All keys are strings.
17221733
17231734
Parameters
17241735
----------
17251736
args, kwargs
17261737
arguments to build a dict
1738+
filter_model_vars : bool
1739+
If `True`, only model variables are included in the result.
17271740
"""
17281741
model = modelcontext(kwargs.pop("model", None))
17291742
args = list(args)

pymc3/util.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -191,9 +191,9 @@ def get_default_varnames(var_iterator, include_transformed):
191191
return [var for var in var_iterator if not is_transformed_name(get_var_name(var))]
192192

193193

194-
def get_var_name(var):
194+
def get_var_name(var) -> str:
195195
"""Get an appropriate, plain variable name for a variable."""
196-
return getattr(var, "name", str(var))
196+
return str(getattr(var, "name", var))
197197

198198

199199
def get_transformed(z):

0 commit comments

Comments
 (0)