@@ -647,7 +647,7 @@ def __init__(
647
647
648
648
# The sequence of model-generated RNGs
649
649
self .rng_seq = []
650
- self .initial_values = {}
650
+ self ._initial_values = {}
651
651
652
652
if self .parent is not None :
653
653
self .named_vars = treedict (parent = self .parent .named_vars )
@@ -911,17 +911,28 @@ def independent_vars(self):
911
911
return inputvars (self .unobserved_RVs )
912
912
913
913
@property
914
- def test_point (self ):
914
+ def test_point (self ) -> Dict [str , np .ndarray ]:
915
+ """Deprecated alias for `Model.initial_point`."""
915
916
warnings .warn (
916
917
"`Model.test_point` has been deprecated. Use `Model.initial_point` instead." ,
917
918
DeprecationWarning ,
918
919
)
919
920
return self .initial_point
920
921
921
922
@property
922
- def initial_point (self ):
923
+ def initial_point (self ) -> Dict [str , np .ndarray ]:
924
+ """Maps names of variables to initial values."""
923
925
return Point (list (self .initial_values .items ()), model = self )
924
926
927
+ @property
928
+ def initial_values (self ) -> Dict [TensorVariable , np .ndarray ]:
929
+ """Maps transformed variables to initial values.
930
+
931
+ ⚠ The keys are NOT the objects returned by, `pm.Normal(...)`.
932
+ For a name-based dictionary use the `initial_point` property.
933
+ """
934
+ return self ._initial_values
935
+
925
936
@property
926
937
def disc_vars (self ):
927
938
"""All the discrete variables in the model"""
@@ -1529,7 +1540,7 @@ def update_start_vals(self, a: Dict[str, np.ndarray], b: Dict[str, np.ndarray]):
1529
1540
r"""Update point `a` with `b`, without overwriting existing keys.
1530
1541
1531
1542
Values specified for transformed variables in `a` will be recomputed
1532
- conditional on the valures of `b` and stored in `b`.
1543
+ conditional on the values of `b` and stored in `b`.
1533
1544
1534
1545
"""
1535
1546
# TODO FIXME XXX: If we're going to incrementally update transformed
@@ -1716,14 +1727,16 @@ def fastfn(outs, mode=None, model=None):
1716
1727
return model .fastfn (outs , mode )
1717
1728
1718
1729
1719
- def Point (* args , filter_model_vars = False , ** kwargs ):
1730
+ def Point (* args , filter_model_vars = False , ** kwargs ) -> Dict [ str , np . ndarray ] :
1720
1731
"""Build a point. Uses same args as dict() does.
1721
1732
Filters out variables not in the model. All keys are strings.
1722
1733
1723
1734
Parameters
1724
1735
----------
1725
1736
args, kwargs
1726
1737
arguments to build a dict
1738
+ filter_model_vars : bool
1739
+ If `True`, only model variables are included in the result.
1727
1740
"""
1728
1741
model = modelcontext (kwargs .pop ("model" , None ))
1729
1742
args = list (args )
0 commit comments